This document is relevant for: Inf2, Trn1, Trn2

nki.language.spmd_dim

nki.language.spmd_dim = Ellipsis

Create a dimension in the SPMD launch grid of an NKI kernel with sub-dimension tiling.

A key use case for spmd_dim is to shard an existing NKI kernel over multiple NeuronCores without modifying the kernel's internal implementation. Suppose we have a kernel, nki_spmd_kernel, which is launched with a 2D SPMD grid, (4, 2). We can shard the first dimension of the launch grid (size 4) over two physical NeuronCores by directly manipulating the launch grid as follows:

from neuronxcc import nki  # needed for the @nki.jit decorator below
import neuronxcc.nki.language as nl


@nki.jit
def nki_spmd_kernel(a):
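  # Allocate the output tensor in shared HBM with the same shape and dtype as the input
  b = nl.ndarray(a.shape, dtype=a.dtype, buffer=nl.shared_hbm)
  # Logical position of this kernel instance in the 2D SPMD launch grid
  i = nl.program_id(0)
  j = nl.program_id(1)

  # Each instance copies the element it owns from a to b
  a_tile = nl.load(a[i, j])
  nl.store(b[i, j], a_tile)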

  return b

############################################################################
# Example 1: Let the compiler decide how to distribute the SPMD kernel instances
# (`src` is assumed to be an input tensor defined elsewhere)
############################################################################
dst = nki_spmd_kernel[4, 2](src)

############################################################################
# Example 2: Distribute SPMD kernel instances to physical NeuronCores with
# explicit annotations. Expected physical NeuronCore assignments:
#   Physical NC [0]: kernel[0, 0], kernel[0, 1], kernel[1, 0], kernel[1, 1]
#   Physical NC [1]: kernel[2, 0], kernel[2, 1], kernel[3, 0], kernel[3, 1]
############################################################################
dst = nki_spmd_kernel[nl.spmd_dim(nl.nc(2), 2), 2](src)
dst = nki_spmd_kernel[nl.nc(2) * 2, 2](src)  # syntactic sugar

############################################################################
# Example 3: Distribute SPMD kernel instances to physical NeuronCores with
# explicit annotations. Expected physical NeuronCore assignments:
#   Physical NC [0]: kernel[0, 0], kernel[0, 1], kernel[2, 0], kernel[2, 1]
#   Physical NC [1]: kernel[1, 0], kernel[1, 1], kernel[3, 0], kernel[3, 1]
############################################################################
dst = nki_spmd_kernel[nl.spmd_dim(2, nl.nc(2)), 2](src)
dst = nki_spmd_kernel[2 * nl.nc(2), 2](src)  # syntactic sugar
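
The difference between Examples 2 and 3 is only where nl.nc(2) sits among the sub-dimensions. The following is a minimal pure-Python sketch that reproduces the assignments listed in the comments above. It is an illustrative model inferred from those comments, not the compiler's actual scheduling logic; the helper names assign_nc_outer and assign_nc_inner are hypothetical and are not part of NKI.

```python
from itertools import product


def assign_nc_outer(num_cores, inner, other_dim):
    """Model of nl.spmd_dim(nl.nc(num_cores), inner) (hypothetical helper).

    Assumes the core index is the outer sub-dimension, so each core owns a
    contiguous block of `inner` indices along the first grid dimension.
    """
    cores = {c: [] for c in range(num_cores)}
    for i, j in product(range(num_cores * inner), range(other_dim)):
        cores[i // inner].append((i, j))
    return cores


def assign_nc_inner(outer, num_cores, other_dim):
    """Model of nl.spmd_dim(outer, nl.nc(num_cores)) (hypothetical helper).

    Assumes the core index is the inner sub-dimension, so indices along the
    first grid dimension are interleaved across cores.
    """
    cores = {c: [] for c in range(num_cores)}
    for i, j in product(range(outer * num_cores), range(other_dim)):
        cores[i % num_cores].append((i, j))
    return cores


# Example 2 layout: nl.spmd_dim(nl.nc(2), 2) with a second grid dimension of 2
example_2 = assign_nc_outer(num_cores=2, inner=2, other_dim=2)
# Example 3 layout: nl.spmd_dim(2, nl.nc(2)) with a second grid dimension of 2
example_3 = assign_nc_inner(outer=2, num_cores=2, other_dim=2)

for name, layout in (("Example 2", example_2), ("Example 3", example_3)):
    for core, instances in layout.items():
        print(f"{name}, physical NC [{core}]: {instances}")
```

Under this model, placing nl.nc(2) first gives each core a contiguous block of the first grid dimension, while placing it last interleaves that dimension across the cores.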
