This document is relevant for: Inf2, Trn1, Trn2
nki.language.ds
- nki.language.ds(start, size)
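Construct a dynamic slice for simple tensor indexing. As the example below shows, nl.ds(start, size) selects the same elements as the native slice start:(start + size).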
import neuronxcc.nki.language as nl
...
@nki.jit(mode="simulation")
def example_kernel(in_tensor):
    out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype,
                            buffer=nl.shared_hbm)
    for i in nl.affine_range(in_tensor.shape[1] // 512):
        tile = nl.load(in_tensor[:, (i * 512):((i + 1) * 512)])
        # Same as above but use ds (dynamic slice) instead of the native
        # slice syntax
        tile = nl.load(in_tensor[:, nl.ds(i * 512, 512)])
This document is relevant for: Inf2, Trn1, Trn2