This document is relevant for: Inf2, Trn1, Trn1n

nki.FrameworkKernel

class nki.FrameworkKernel

NKI kernels are represented as XLA CustomCall instructions in HLO. This class facilitates HLO generation for NKI kernels.

For example, a kernel that reads from its first two arguments and writes to its last argument in Python,

def example_kernel(in1, in2, out):
    # Actual kernel content omitted
    pass

should be mapped to the following HLO instruction,

%custom-call.2 = f32[16,8,128,512]{3,2,1,0} custom-call(
    f32[16,8,128,512]{3,2,1,0} %p2.2, f32[16,8,128,512]{3,2,1,0} %p1.2),
    custom_call_target="AwsNeuronCustomNativeKernel",
    api_version=API_VERSION_UNSPECIFIED,
    metadata={op_type="xla___op_NkiKernelCallImpl" op_name="xla___op_NkiKernelCallImpl"},
    backend_config= # ...omitted

It is important to note that although NKI kernels use pass-by-reference semantics in Python, the corresponding HLO instruction returns the output tensor.

The field api_version is optional. The field metadata carries optional debug information: developers may pass op_type and op_name, and this information will show up in profiles captured with neuron-profiler. The custom_call_target must always be "AwsNeuronCustomNativeKernel".

Framework developers should inherit this class and implement the following methods.

  1. translate_to_neuron_dtype

  2. is_framework_tensor

  3. map_framework_tensor

Then backend_config can be obtained by calling dump_config(*args, **kwargs).
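For concreteness, here is a minimal sketch of what a PyTorch subclass might look like. The import path for FrameworkKernel and the dtype table below are assumptions for illustration only; the torch.bfloat16 mapping follows the translate_to_neuron_dtype example further down.

import numpy as np
import torch
import neuronxcc.nki.language as nl
from neuronxcc.nki import FrameworkKernel  # assumed import path

class PyTorchFrameworkKernel(FrameworkKernel):
    # Illustrative torch -> Neuron dtype table; a complete implementation
    # would cover every dtype the framework may pass to a kernel.
    _DTYPE_MAP = {
        torch.float32: np.float32,    # numpy representation
        torch.bfloat16: nl.bfloat16,  # Neuron-specific representation
    }

    def translate_to_neuron_dtype(self, dtype):
        return self._DTYPE_MAP[dtype]

    def is_framework_tensor(self, t):
        # Anything that is not a torch.Tensor must be a compile-time constant.
        return isinstance(t, torch.Tensor)

    def map_framework_tensor(self, t):
        # Shape and dtype of the tensor, as a tuple.
        return t.shape, t.dtype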

As an example, suppose we have correctly implemented a PyTorch variant of this class along these lines, i.e. PyTorchFrameworkKernel(FrameworkKernel); then we can generate the backend_config for the HLO instruction example as follows.

in1 = torch.rand((16, 8, 128, 512), dtype=torch.float32)
in2 = torch.rand((16, 8, 128, 512), dtype=torch.float32)
out = torch.rand((16, 8, 128, 512), dtype=torch.float32)
kernel = PyTorchFrameworkKernel(func_name=example_kernel.__name__, func=example_kernel, grid=(16, 8))
kernel.dump_config(in1, in2, out) # Dump config based on inputs
# Omitted, config string specialized for (16, 8, 128, 512)
in3 = torch.rand((16, 8, 64, 1024), dtype=torch.float32)
in4 = torch.rand((16, 8, 64, 1024), dtype=torch.float32)
out = torch.rand((16, 8, 64, 1024), dtype=torch.float32)
kernel = PyTorchFrameworkKernel(func_name=example_kernel.__name__, func=example_kernel, grid=(16, 8))
kernel.dump_config(in3, in4, out=out) # Dump config based on inputs
# Omitted, config string specialized for (16, 8, 64, 1024)

dump_config should be called once for each distinct configuration of input tensor shapes.

Methods

dump_config

Returns the backend_config, the list of input names, and the list of output names, based on the given arguments.

is_framework_tensor

Returns True if and only if t should be treated as a tensor.

map_framework_tensor

Takes a framework tensor and returns its shape and dtype as a tuple.

translate_to_neuron_dtype

Translates a framework dtype to a Neuron-specific dtype representation (either a numpy dtype or a Neuron-specific dtype).

dump_config(*args, **kwargs)

Returns the backend_config, the list of input names, and the list of output names, based on the given arguments.

If self.enable_cache is True, dump_config first tries to retrieve the result from its cache, using args, kwargs, the SPMD launch grid, and other kernel attributes as the key, so that identical calls reuse the same backend_config.

Otherwise, dump_config always generates a new backend_config.
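As an illustration of the caching behavior, here is a sketch assuming the hypothetical PyTorchFrameworkKernel from above and that enable_cache can be set as a plain attribute (the description above only states that it is checked as self.enable_cache):

kernel = PyTorchFrameworkKernel(func_name=example_kernel.__name__, func=example_kernel, grid=(16, 8))
kernel.enable_cache = True  # assumption: settable attribute
a = torch.rand((16, 8, 128, 512), dtype=torch.float32)
b = torch.rand((16, 8, 128, 512), dtype=torch.float32)
out = torch.rand((16, 8, 128, 512), dtype=torch.float32)
first = kernel.dump_config(a, b, out)   # generated, then cached
second = kernel.dump_config(a, b, out)  # identical shapes, dtypes, and grid: served from the cache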

is_framework_tensor(t)

Returns True if and only if t should be treated as a tensor. Arguments for which this returns False must be constants known at compile time.

As an example, for PyTorch,

>>> is_framework_tensor(torch.rand((2, 3)))
True
>>> is_framework_tensor("this is not a tensor")
False
map_framework_tensor(t)

Takes a framework tensor and returns its shape and dtype as a tuple. This function should only be called on arguments t for which is_framework_tensor(t) returns True.

As an example, for PyTorch,

>>> map_framework_tensor(torch.rand((2, 3), dtype=torch.bfloat16))
(torch.Size([2, 3]), torch.bfloat16)
translate_to_neuron_dtype(_dtype)

Translates a framework dtype to a Neuron-specific dtype representation (either a numpy dtype or a Neuron-specific dtype).

As an example, for PyTorch,

>>> result = translate_to_neuron_dtype(torch.bfloat16)
>>> result == neuronxcc.nki.language.bfloat16
True

This document is relevant for: Inf2, Trn1, Trn1n