This document is relevant for: Inf2, Trn1, Trn2
nki.language.spmd_dim
- nki.language.spmd_dim = Ellipsis
Create a dimension in the SPMD launch grid of a NKI kernel with sub-dimension tiling.
A key use case for spmd_dim is to shard an existing NKI kernel over multiple NeuronCores without modifying the internal kernel implementation. Suppose we have a kernel, nki_spmd_kernel, which is launched with a 2D SPMD grid, (4, 2). We can shard the first dimension of the launch grid (size 4) over two physical NeuronCores by directly manipulating the launch grid as follows:

    from neuronxcc import nki
    import neuronxcc.nki.language as nl

    @nki.jit
    def nki_spmd_kernel(a):
        b = nl.ndarray(a.shape, dtype=a.dtype, buffer=nl.shared_hbm)
        # Each kernel instance handles the tile selected by its grid position.
        i = nl.program_id(0)
        j = nl.program_id(1)

        a_tile = nl.load(a[i, j])
        nl.store(b[i, j], a_tile)

        return b

    # src is the input tensor, prepared by the caller (not shown here).

    ############################################################################
    # Example 1: Let compiler decide how to distribute the instances of spmd kernel
    ############################################################################
    dst = nki_spmd_kernel[4, 2](src)

    ############################################################################
    # Example 2: Distribute SPMD kernel instances to physical NeuronCores with
    # explicit annotations. Expected physical NeuronCore assignments:
    #   Physical NC [0]: kernel[0, 0], kernel[0, 1], kernel[1, 0], kernel[1, 1]
    #   Physical NC [1]: kernel[2, 0], kernel[2, 1], kernel[3, 0], kernel[3, 1]
    ############################################################################
    dst = nki_spmd_kernel[nl.spmd_dim(nl.nc(2), 2), 2](src)
    dst = nki_spmd_kernel[nl.nc(2) * 2, 2](src)  # syntactic sugar

    ############################################################################
    # Example 3: Distribute SPMD kernel instances to physical NeuronCores with
    # explicit annotations. Expected physical NeuronCore assignments:
    #   Physical NC [0]: kernel[0, 0], kernel[0, 1], kernel[2, 0], kernel[2, 1]
    #   Physical NC [1]: kernel[1, 0], kernel[1, 1], kernel[3, 0], kernel[3, 1]
    ############################################################################
    dst = nki_spmd_kernel[nl.spmd_dim(2, nl.nc(2)), 2](src)
    dst = nki_spmd_kernel[2 * nl.nc(2), 2](src)  # syntactic sugar
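The difference between Examples 2 and 3 comes down to which sub-dimension carries the NeuronCore annotation. The sketch below is a pure-Python model that reproduces the expected assignments listed in the comments above; it does not call NKI and is not a description of how the compiler actually places instances. The helper names (core_for_index, assignments) and the nc_is_outer flag are hypothetical, introduced only for illustration, and the row-major split of the index is an assumption inferred from the listed assignments.

```python
from itertools import product


def core_for_index(i, outer, inner, nc_is_outer):
    """Model which physical core owns index i of a dimension of size outer * inner.

    Assumption: the index is split row-major as i = outer_idx * inner + inner_idx,
    and the core is whichever of the two parts carries the nc(...) annotation.
    """
    if not 0 <= i < outer * inner:
        raise IndexError(f"index {i} outside dimension of size {outer * inner}")
    outer_idx, inner_idx = divmod(i, inner)
    return outer_idx if nc_is_outer else inner_idx


def assignments(outer, inner, nc_is_outer, second_dim):
    """Group every (i, j) grid position by the core that would own it in this model."""
    num_cores = outer if nc_is_outer else inner
    table = {core: [] for core in range(num_cores)}
    for i, j in product(range(outer * inner), range(second_dim)):
        table[core_for_index(i, outer, inner, nc_is_outer)].append((i, j))
    return table


# Models spmd_dim(nc(2), 2): the core annotation is on the outer sub-dimension.
example2 = assignments(outer=2, inner=2, nc_is_outer=True, second_dim=2)
# Models spmd_dim(2, nc(2)): the core annotation is on the inner sub-dimension.
example3 = assignments(outer=2, inner=2, nc_is_outer=False, second_dim=2)

for name, table in (("Example 2", example2), ("Example 3", example3)):
    for core, instances in table.items():
        print(f"{name}, NC[{core}]: {instances}")
```

In this model, annotating the outer sub-dimension gives each core a contiguous block of the first grid dimension, while annotating the inner sub-dimension interleaves the indices across cores.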