This document is relevant for: Inf2, Trn1, Trn1n
nki.isa.nc_matmul
- nki.isa.nc_matmul(stationary, moving, is_stationary_onezero=False, is_moving_onezero=False, mask=None, **kwargs)
Compute stationary.T @ moving matrix multiplication using Tensor Engine.

The nc_matmul instruction must read inputs from SBUF and write outputs to PSUM. Therefore, the stationary and moving operands must be SBUF tiles, and the result tile is a PSUM tile.

The nc_matmul instruction currently supports float8/bfloat16/float16/tfloat32/float32 input data types as listed in Supported Data Types. The matmul accumulation and results are always in float32.

The Tensor Engine imposes special layout constraints on the input tiles. First, the partition axis sizes of the stationary and moving tiles must be identical and <= 128; the partition axis corresponds to the contraction dimension of the matrix multiplication. Second, the free axis sizes of the stationary and moving tiles must be <= 128 and <= 512, respectively. For example, with stationary.shape = (128, 126) and moving.shape = (128, 512), nc_matmul(stationary, moving) returns a tile of shape = (126, 512). For more information about the matmul layout, see Tensor Engine.

If the contraction dimension of the matrix multiplication exceeds 128, you may accumulate multiple nc_matmul instruction output tiles into the same PSUM tile. See example code snippet below.

Estimated instruction cost:
The Tensor Engine has complex performance characteristics given its data flow and pipeline design. The formula below is the average nc_matmul cost assuming many nc_matmul instructions of the same shapes running back-to-back on the engine:

- If the input data type is one of float8/bfloat16/float16/tfloat32: max(min(64, N_stationary), N_moving) Tensor Engine cycles, where N_stationary is the number of elements per partition in the stationary tile and N_moving is the number of elements per partition in the moving tile.
- If the input data type is float32: 4x higher than the float8/bfloat16/float16/tfloat32 instruction cost.
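As a rough aid for comparing tile shapes, the cost formula above can be written as a small pure-Python helper. This is a sketch only: the function name estimate_nc_matmul_cycles and the use of plain strings for data types are assumptions made for illustration and are not part of the NKI API. It returns the documented average cycle count for back-to-back instructions, not a measured latency.

```python
def estimate_nc_matmul_cycles(n_stationary: int, n_moving: int, dtype: str) -> int:
    """Average Tensor Engine cycles per nc_matmul, following the documented formula.

    n_stationary / n_moving: number of elements per partition in each tile.
    dtype: input data type name as a plain string (hypothetical convention).
    """
    low_precision = {"float8", "bfloat16", "float16", "tfloat32"}
    base = max(min(64, n_stationary), n_moving)
    if dtype in low_precision:
        return base
    if dtype == "float32":
        # float32 inputs cost 4x the low-precision instruction cost.
        return 4 * base
    raise ValueError(f"unsupported input data type: {dtype}")


# A full-size bfloat16 tile pair versus the same shapes in float32.
bf16_cycles = estimate_nc_matmul_cycles(128, 512, "bfloat16")
fp32_cycles = estimate_nc_matmul_cycles(128, 512, "float32")
print(bf16_cycles, fp32_cycles)
```

Under this formula the moving tile's free size dominates whenever it is at least 64, so a narrow moving tile does not reduce the estimate below min(64, N_stationary).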
- Parameters:
stationary – the stationary operand on SBUF; layout: (partition axis <= 128, free axis <= 128)
moving – the moving operand on SBUF; layout: (partition axis <= 128, free axis <= 512)
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
is_stationary_onezero – hints to the compiler whether the stationary operand is a tile with ones/zeros only; setting this field explicitly could lead to 2x better performance if the stationary tile is in float32; the field has no impact for non-float32 stationary.
is_moving_onezero – hints to the compiler whether the moving operand is a tile with ones/zeros only; setting this field explicitly could lead to 2x better performance if the moving tile is in float32; the field has no impact for non-float32 moving.
- Returns:
a tile on PSUM that has the result of matrix multiplication of stationary and moving tiles; layout: partition axis comes from free axis of stationary, while free axis comes from free axis of moving.
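To make the layout rules and the result shape concrete, here is a host-side NumPy model of the documented semantics. It is a sketch under stated assumptions: nc_matmul_reference is a hypothetical helper, not an NKI function, and it only mirrors the shape constraints and float32 result type described above. It does not model SBUF, PSUM, or hardware numerics.

```python
import numpy as np


def nc_matmul_reference(stationary: np.ndarray, moving: np.ndarray) -> np.ndarray:
    """NumPy stand-in for stationary.T @ moving with the documented size limits."""
    k_stationary, m = stationary.shape  # (partition axis, free axis)
    k_moving, n = moving.shape          # (partition axis, free axis)
    if k_stationary != k_moving or k_stationary > 128:
        raise ValueError("partition axis sizes must be identical and <= 128")
    if m > 128:
        raise ValueError("stationary free axis must be <= 128")
    if n > 512:
        raise ValueError("moving free axis must be <= 512")
    # Accumulation and results are float32 regardless of the input type.
    return stationary.astype(np.float32).T @ moving.astype(np.float32)


stationary = np.ones((128, 126), dtype=np.float16)
moving = np.ones((128, 512), dtype=np.float16)
result = nc_matmul_reference(stationary, moving)
print(result.shape, result.dtype)
```

The result's first axis has size 126, taken from the free axis of stationary, and its second axis has size 512, taken from the free axis of moving, matching the layout described in Returns.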
Example:
    import neuronxcc.nki.isa as nisa
    import neuronxcc.nki.language as nl
    ...

    ##################################################################
    # Example 1:
    # multiply matrix a of shape (128, 128) and matrix b of shape (128, 512)
    # to get matrix c in PSUM of shape (128, 512)
    ##################################################################
    i_p_a = nl.arange(128)[:, None]
    i_f_a = nl.arange(128)[None, :]
    i_p_b = nl.arange(128)[:, None]
    i_f_b = nl.arange(512)[None, :]

    a = nl.load(a_tensor[i_p_a, i_f_a])
    b = nl.load(b_tensor[i_p_b, i_f_b])

    c_psum = nisa.nc_matmul(a[i_p_a, i_f_a], b[i_p_b, i_f_b])

    nl.store(c_tensor[i_p_a, i_f_b], c_psum)

    ##################################################################
    # Example 2:
    # multiply matrix d of shape (256, 128) and matrix e of shape (256, 512)
    # to get matrix f in PSUM of shape (128, 512) using psum accumulation
    ##################################################################
    f_psum = nl.zeros((128, 512), nl.float32, buffer=nl.psum)

    i_p_d = nl.arange(128)[:, None]
    i_f_d = nl.arange(128)[None, :]
    i_p_e = nl.arange(128)[:, None]
    i_f_e = nl.arange(512)[None, :]

    for i_contract in nl.affine_range(2):
        d = nl.load(d_tensor[i_contract * 128 + i_p_d, i_f_d])
        e = nl.load(e_tensor[i_contract * 128 + i_p_e, i_f_e])
        f_psum += nisa.nc_matmul(d[i_p_d, i_f_d], e[i_p_e, i_f_e])

    nl.store(f_tensor[i_p_d, i_f_e], f_psum)
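The accumulation pattern in Example 2 can be checked on the host with NumPy: splitting the contraction dimension into chunks of 128 and summing the partial products gives the same result as one full stationary.T @ moving product. The sketch below uses random data and ordinary arrays in place of the SBUF and PSUM tiles; the names TILE_K, f_acc, and f_full are illustrative and do not come from the NKI API.

```python
import numpy as np

rng = np.random.default_rng(0)
d = rng.standard_normal((256, 128)).astype(np.float32)  # contraction dim = 256
e = rng.standard_normal((256, 512)).astype(np.float32)

TILE_K = 128  # maximum partition (contraction) size per nc_matmul
f_acc = np.zeros((128, 512), dtype=np.float32)  # plays the role of f_psum

for i_contract in range(d.shape[0] // TILE_K):
    rows = slice(i_contract * TILE_K, (i_contract + 1) * TILE_K)
    # One partial product per 128-row chunk, accumulated in place.
    f_acc += d[rows].T @ e[rows]

f_full = d.T @ e  # single product over the whole contraction dimension
print(np.allclose(f_acc, f_full, rtol=1e-4, atol=1e-3))
```

The comparison uses a tolerance because float32 summation order differs between the chunked and the single-product paths.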