This document is relevant for: Inf2, Trn1, Trn1n
nki.FrameworkKernel
- class nki.FrameworkKernel
NKI kernels are represented as XLA CustomCall instructions in HLO. This class facilitates HLO generation for NKI kernels.
For example, a kernel that reads from its first two arguments and writes to its last argument in Python,

def example_kernel(in1, in2, out):
    # Actual kernel content omitted
    pass

should be mapped to the following HLO instruction,
%custom-call.2 = f32[16,8,128,512]{3,2,1,0} custom-call(
    f32[16,8,128,512]{3,2,1,0} %p2.2,
    f32[16,8,128,512]{3,2,1,0} %p1.2
  ),
  custom_call_target="AwsNeuronCustomNativeKernel",
  api_version=API_VERSION_UNSPECIFIED,
  metadata={op_type="xla___op_NkiKernelCallImpl" op_name="xla___op_NkiKernelCallImpl"},
  backend_config=  # ...omitted
It is important to note that although NKI kernels use pass-by-reference semantics in Python, the corresponding HLO instruction returns the output tensor.
The field api_version is optional. The field metadata carries optional debug information: developers may pass op_type and op_name, and that information will show up in profiles captured with neuron-profiler. The custom_call_target must always be "AwsNeuronCustomNativeKernel".
Framework developers should inherit this class and implement the following methods.
translate_to_neuron_dtype
is_framework_tensor
map_framework_tensor
Then backend_config can be obtained by calling dump_config(*args, **kwargs).
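As a sketch of what such a subclass might look like, the following uses NumPy arrays as a stand-in framework. The base class here is a minimal placeholder (real code would subclass nki.FrameworkKernel), and the method bodies are illustrative assumptions, not the actual PyTorch implementation.

```python
import numpy as np

# Minimal stand-in so this sketch is self-contained; real code would
# inherit from nki.FrameworkKernel instead.
class FrameworkKernel:
    pass

class NumpyFrameworkKernel(FrameworkKernel):
    """Hypothetical FrameworkKernel for a NumPy-backed framework."""

    def translate_to_neuron_dtype(self, dtype):
        # NumPy dtypes already serve as the numpy-level representation here.
        return np.dtype(dtype)

    def is_framework_tensor(self, t):
        # Only ndarrays are treated as tensors; everything else must be
        # a constant known at compile time.
        return isinstance(t, np.ndarray)

    def map_framework_tensor(self, t):
        # Return a (shape, dtype) tuple, per the FrameworkKernel contract.
        return (t.shape, t.dtype)
```

A real PyTorch variant would make the analogous checks against torch.Tensor and torch dtypes.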
As an example, suppose we have correctly implemented a PyTorch variant of this class, i.e. PyTorchFrameworkKernel(FrameworkKernel); then we can generate the backend_config for the example HLO instruction as follows.
in1 = torch.rand((16, 8, 128, 512), dtype=torch.float32)
in2 = torch.rand((16, 8, 128, 512), dtype=torch.float32)
out = torch.rand((16, 8, 128, 512), dtype=torch.float32)
kernel = PyTorchFrameworkKernel(func_name=example_kernel.__name__, func=example_kernel, grid=(16, 8))
kernel.dump_config(in1, in2, out)  # Dump config based on inputs
# Omitted, config string specialized for (16, 8, 128, 512)

in3 = torch.rand((16, 8, 64, 1024), dtype=torch.float32)
in4 = torch.rand((16, 8, 64, 1024), dtype=torch.float32)
out = torch.rand((16, 8, 64, 1024), dtype=torch.float32)
kernel = PyTorchFrameworkKernel(func_name=example_kernel.__name__, func=example_kernel, grid=(16, 8))
kernel.dump_config(in3, in4, out=out)  # Dump config based on inputs
# Omitted, config string specialized for (16, 8, 64, 1024)
dump_config should be called once for each distinct configuration of input tensor shapes.
Methods
dump_config — Returns the backend_config, the list of input names, and the list of output names, based on the given arguments.
is_framework_tensor — Returns true if and only if t should be treated as a tensor.
map_framework_tensor — Takes a framework tensor and returns its shape and dtype as a tuple.
translate_to_neuron_dtype — Translates a framework dtype to a Neuron-specific dtype representation in NumPy or a Neuron-specific dtype.
- dump_config(*args, **kwargs)
Returns the backend_config, the list of input names, and the list of output names, based on the given arguments.
If self.enable_cache is True, dump_config will try to retrieve the result from the cache, using args, kwargs, the SPMD launch grid, and other kernel attributes as the key that identifies an identical backend_config.
Otherwise, dump_config always generates a new backend_config.
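The caching behavior can be pictured with a small memoization sketch. The key construction below (tensor shapes, dtypes, and the launch grid) is an assumption chosen for illustration, not the exact key the real dump_config uses.

```python
# Hypothetical illustration of dump_config-style caching: results are
# keyed by the tensor shapes/dtypes and the SPMD launch grid, so a
# repeated configuration reuses the previously generated config.
_config_cache = {}

def dump_config_cached(tensors, grid):
    # tensors: list of {"shape": tuple, "dtype": str} descriptors (hypothetical)
    key = (tuple((t["shape"], t["dtype"]) for t in tensors), grid)
    if key in _config_cache:
        return _config_cache[key]
    # Stand-in for the real, expensive config generation step.
    config = f"backend_config for shapes {key[0]} on grid {grid}"
    _config_cache[key] = config
    return config
```

Calling it twice with identical shapes and grid returns the cached object; a new shape configuration produces a new config.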
- is_framework_tensor(t)
Returns true if and only if t should be treated as a tensor. Arguments for which this returns false must be constants known at compile time.
As an example, for PyTorch,
>>> is_framework_tensor(torch.rand((2, 3)))
True
>>> is_framework_tensor("this is not a tensor")
False
- map_framework_tensor(t)
Takes a framework tensor and returns its shape and dtype as a tuple. This function should only be called on t where is_framework_tensor(t) returns True.
As an example, for PyTorch,
>>> map_framework_tensor(torch.rand((2, 3), dtype=torch.bfloat16))
(torch.Size([2, 3]), torch.bfloat16)