This document is relevant for: Inf2, Trn1, Trn1n

Neuron Custom C++ Operators Performance Optimization#

In this tutorial, we build on the small MLP model shown in Neuron Custom C++ Operators in MLP Training and demonstrate methods to optimize the performance of a custom C++ operator. We take advantage of the TCM accessor as well as multiple GPSIMD cores to improve performance.

This tutorial assumes you have read Neuron Custom C++ Operators in MLP Training and set up the environment described there.

Download Examples#

To download the source code for this tutorial, do:

git clone https://github.com/aws-neuron/aws-neuron-samples.git
cd aws-neuron-samples/torch-neuronx/inference/customop_mlp

Note

We will be using an inference example in this tutorial in order to adhere to certain Custom C++ operator restrictions when using multiple GPSIMD cores (see Custom Operators API Reference Guide [Experimental] for details on current restrictions).

Model Configuration Adjustment#

For this tutorial, we enlarge the hidden layer sizes from [120, 84] to [4096, 2048].

import torch
import torch.nn as nn
from torch.nn import functional as F
import my_ops

# Declare 3-layer MLP for MNIST dataset
class MLP(nn.Module):
    def __init__(self, input_size = 28 * 28, output_size = 10, layers = [4096, 2048]):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, layers[0])
        self.fc2 = nn.Linear(layers[0], layers[1])
        self.fc3 = nn.Linear(layers[1], output_size)

    def forward(self, x):
        f1 = self.fc1(x)
        r1 = my_ops.Relu.apply(f1)
        f2 = self.fc2(r1)
        r2 = my_ops.Relu.apply(f2)
        f3 = self.fc3(r2)
        return torch.log_softmax(f3, dim=1)

Performance with Element-wise Accessor#

The neuron directory contains the same code shown in Neuron Custom C++ Operators in MLP Training, where relu_forward is implemented with the element-wise accessor (a sketch is shown below for reference). Go to the neuron directory and run inference.py; the expected output on a trn1 instance is:

Inf throughput (iter/sec): 8.098649744235592
----------End Inference ---------------
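
For reference, the element-wise implementation in the neuron directory follows the pattern sketched below. This is a minimal sketch assuming a 2D float input, as in the base tutorial; the exact code ships with the sample. Every element is read and written individually, which is what the TCM version in the next section avoids.

torch::Tensor relu_forward(const torch::Tensor& t_in) {
    torch::Tensor t_out = torch::zeros(t_in.sizes(), torch::kFloat);

    // Element-wise accessors: each [i][j] access is an individual read or write.
    auto t_in_acc = t_in.accessor<float, 2>();
    auto t_out_acc = t_out.accessor<float, 2>();
    auto shape = t_in.sizes();
    for (int64_t i = 0; i < shape[0]; i++) {
        for (int64_t j = 0; j < shape[1]; j++) {
            t_out_acc[i][j] = t_in_acc[i][j] > 0.0f ? t_in_acc[i][j] : 0.0f;
        }
    }
    return t_out;
}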

Performance with TCM Accessor#

Now we switch to the neuron-tcm folder. As mentioned in the Custom Operators API Reference Guide [Experimental], TCM accessors provide faster read and write performance. We implement relu_forward using the TCM accessor:

torch::Tensor relu_forward(const torch::Tensor& t_in) {
    size_t num_elem = t_in.numel();
    torch::Tensor t_out = torch::zeros(t_in.sizes(), torch::kFloat);

    // Stage the data through a small buffer in fast TCM memory.
    static constexpr size_t buffer_size = 1024;
    float *tcm_buffer = (float*)torch::neuron::tcm_malloc(sizeof(float) * buffer_size);

    if (tcm_buffer != nullptr) {
        auto t_in_tcm_acc = t_in.tcm_accessor();
        auto t_out_tcm_acc = t_out.tcm_accessor();

        // Process the tensor in chunks of buffer_size elements.
        for (size_t i = 0; i < num_elem; i += buffer_size) {
            size_t remaining_elem = num_elem - i;
            size_t copy_size = (remaining_elem > buffer_size) ? buffer_size : remaining_elem;

            // Bulk copy a chunk into TCM, apply relu, and bulk copy it back out.
            t_in_tcm_acc.tensor_to_tcm<float>(tcm_buffer, i, copy_size);
            for (size_t j = 0; j < copy_size; j++) {
                tcm_buffer[j] = tcm_buffer[j] > 0.0 ? tcm_buffer[j] : 0.0;
            }
            t_out_tcm_acc.tcm_to_tensor<float>(tcm_buffer, i, copy_size);
        }
    }
    torch::neuron::tcm_free(tcm_buffer);
    return t_out;
}

Running build.py and inference.py, the expected output on a trn1 instance is:

Inf throughput (iter/sec): 220.73800131604054
----------End Inference ---------------

Extending the example to utilize multiple GPSIMD cores#

Now we switch to the neuron-multicore folder. We first enable the use of multiple GPSIMD cores by setting multicore=True in build.py.

custom_op.load(
    name='relu',
    compute_srcs=['relu.cpp'],
    shape_srcs=['shape.cpp'],
    build_directory=os.getcwd(),
    multicore=True,
    verbose=True
)
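
The shape_srcs entry is unchanged from the base tutorial: it provides the host-side shape function that describes the operator's output to the compiler. For reference, a minimal sketch is shown below; the library name, header path, and function signature are assumptions carried over from Neuron Custom C++ Operators in MLP Training rather than something specific to this sample.

#include <stdint.h>
#include <torch/torch.h>
#include "torchneuron/register.h"  // assumed header, as in the base MLP tutorial

// Host-side shape function: only describes the output shape and dtype.
torch::Tensor relu_fwd_shape(torch::Tensor t_in) {
    return torch::zeros(t_in.sizes(), torch::kFloat);
}

// Register relu_forward under the my_ops library namespace (assumed name).
NEURON_LIBRARY(my_ops, m) {
    m.def("relu_forward", &relu_fwd_shape, "relu_forward");
}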

After passing this flag, the kernel function relu_forward executes on all GPSIMD cores, so we use cpu_id to partition the workload among them.

torch::Tensor relu_forward(const torch::Tensor& t_in) {
    size_t num_elem = t_in.numel();
    // Write results directly into the destination tensor shared by all cores.
    torch::Tensor t_out = get_dst_tensor();

    // Each core processes a contiguous slice of `partition` elements;
    // the last core also picks up any remainder.
    uint32_t cpu_id = get_cpu_id();
    uint32_t cpu_count = get_cpu_count();
    uint32_t partition = num_elem / cpu_count;
    if (cpu_id == cpu_count - 1) {
        partition = num_elem - partition * (cpu_count - 1);
    }

    static constexpr size_t buffer_size = 1024;
    float *tcm_buffer = (float*)torch::neuron::tcm_malloc(sizeof(float) * buffer_size);

    if (tcm_buffer != nullptr) {
        auto t_in_tcm_acc = t_in.tcm_accessor();
        auto t_out_tcm_acc = t_out.tcm_accessor();

        for (size_t i = 0; i < partition; i += buffer_size) {
            size_t remaining_elem = partition - i;
            size_t copy_size = (remaining_elem > buffer_size) ? buffer_size : remaining_elem;

            // This core's slice starts at element partition * cpu_id.
            t_in_tcm_acc.tensor_to_tcm<float>(tcm_buffer, partition * cpu_id + i, copy_size);
            for (size_t j = 0; j < copy_size; j++) {
                tcm_buffer[j] = tcm_buffer[j] > 0.0 ? tcm_buffer[j] : 0.0;
            }
            t_out_tcm_acc.tcm_to_tensor<float>(tcm_buffer, partition * cpu_id + i, copy_size);
        }
    }
    torch::neuron::tcm_free(tcm_buffer);
    return t_out;
}

There are two noteworthy things in this code:

  1. We use cpu_id and cpu_count to distribute the workload among all cores. In particular, each core performs relu on its own partition of the tensor, and the offset of that partition is computed from cpu_id (see the sketch after this list).

  2. The output of the operator is written directly to the tensor obtained from get_dst_tensor(). The return t_out; statement is ignored during execution.
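
To make the partitioning arithmetic concrete, the standalone sketch below (not part of the sample; the element count and core count are illustrative values) prints the slice of the tensor each core would process:

#include <cstdio>
#include <cstddef>
#include <cstdint>

int main() {
    const size_t num_elem = 2048 * 8;  // illustrative element count
    const uint32_t cpu_count = 8;      // illustrative GPSIMD core count
    for (uint32_t cpu_id = 0; cpu_id < cpu_count; cpu_id++) {
        size_t partition = num_elem / cpu_count;
        size_t offset = partition * cpu_id;
        if (cpu_id == cpu_count - 1) {
            // Last core picks up any remainder, as in relu_forward above.
            partition = num_elem - partition * (cpu_count - 1);
        }
        printf("core %u processes elements [%zu, %zu)\n", cpu_id, offset, offset + partition);
    }
    return 0;
}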

Running build.py and inference.py, the expected output on a trn1 instance is:

Inf throughput (iter/sec): 269.936119707143
----------End Inference ---------------

Details of the APIs used in this sample can be found in the Custom Operators API Reference Guide [Experimental].

This document is relevant for: Inf2, Trn1, Trn1n