This document is relevant for: Inf2, Trn1, Trn2

nki.language.ds

nki.language.ds(start, size)

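Construct a dynamic slice for simple tensor indexing. `nl.ds(start, size)` selects `size` consecutive elements beginning at index `start`, equivalent to the native slice `start:(start + size)`, as shown in the example below.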

from neuronxcc import nki
import neuronxcc.nki.language as nl
...



@nki.jit(mode="simulation")
def example_kernel(in_tensor):
  """Copy ``in_tensor`` into a new HBM tensor, one 512-column tile at a time.

  Demonstrates that ``nl.ds(start, size)`` selects the same elements as the
  native slice ``start:(start + size)``.

  Args:
    in_tensor: input tensor, indexed here as 2-D (``[:, cols]``).
      NOTE(review): presumably ``in_tensor.shape[0]`` must fit in a single
      tile's partition dimension — TODO confirm against NKI tile limits.

  Returns:
    ``out_tensor``, with the same shape and dtype as ``in_tensor``. Only the
    first ``(shape[1] // 512) * 512`` columns are written; if ``shape[1]`` is
    not a multiple of 512, the trailing columns are left unwritten.
  """
  out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype,
                          buffer=nl.shared_hbm)
  # Integer division: only full 512-wide column tiles are processed.
  for i in nl.affine_range(in_tensor.shape[1] // 512):
    # Native slice syntax: columns [i * 512, (i + 1) * 512).
    tile = nl.load(in_tensor[:, (i * 512):((i + 1) * 512)])
    # Same as above but use ds (dynamic slice) instead of the native
    # slice syntax
    tile = nl.load(in_tensor[:, nl.ds(i * 512, 512)])
    # Write the tile back to the matching columns so the kernel has an output.
    nl.store(out_tensor[:, nl.ds(i * 512, 512)], value=tile)
  return out_tensor

This document is relevant for: Inf2, Trn1, Trn2