This document is relevant for: Trn1, Trn2, Trn3
nki.language.matmul
- nki.language.matmul(x, y, transpose_x=False)
Compute the matrix multiplication x @ y of x and y.
Warning
This API is experimental and may change in future releases.
- Parameters:
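x – a tile on SBUF (partition dimension <= 128, free dimension <= 128). The contraction dimension of x must match y’s partition dimension: this is x’s free dimension when transpose_x=False, and x’s partition dimension when transpose_x=True.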
y – a tile on SBUF (partition dimension <= 128, free dimension <= 512).
transpose_x – defaults to False. If True, x is treated as already transposed, so its partition dimension is used directly as the contraction dimension. If False, an additional transpose is inserted so that x’s partition dimension becomes the contraction dimension of the matmul, aligning the operands with the Tensor Engine.
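The parameter limits above can be summarized as a shape check. The sketch below is a hypothetical host-side helper (`matmul_output_shape` and the constant names are illustrative, not part of the NKI API); it applies the documented limits exactly as written, treating every shape as (partition, free), and returns the output shape.

```python
# Hypothetical helper, NOT part of NKI: mirrors the documented shape limits
# of nki.language.matmul so tile shapes can be sanity-checked on the host.
from typing import Tuple

PMAX = 128         # documented partition-dimension limit for x and y
X_FREE_MAX = 128   # documented free-dimension limit for x
Y_FREE_MAX = 512   # documented free-dimension limit for y


def matmul_output_shape(x_shape: Tuple[int, int],
                        y_shape: Tuple[int, int],
                        transpose_x: bool = False) -> Tuple[int, int]:
    """Return the (partition, free) shape of x @ y, or of x.T @ y if transpose_x."""
    x_par, x_free = x_shape
    y_par, y_free = y_shape
    if x_par > PMAX or y_par > PMAX:
        raise ValueError("partition dimension must be <= 128")
    if x_free > X_FREE_MAX:
        raise ValueError("x free dimension must be <= 128")
    if y_free > Y_FREE_MAX:
        raise ValueError("y free dimension must be <= 512")
    # x @ y contracts over x's free dim; x.T @ y contracts over x's partition dim.
    x_contract, out_rows = (x_par, x_free) if transpose_x else (x_free, x_par)
    if x_contract != y_par:
        raise ValueError(
            f"contraction mismatch: x gives {x_contract}, y partition is {y_par}")
    return (out_rows, y_free)
```

For example, `matmul_output_shape((64, 128), (128, 512))` returns `(64, 512)`, and passing an already-transposed x of shape `(128, 64)` with `transpose_x=True` gives the same output shape.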
- Returns:
x @ y, or x.T @ y if transpose_x=True.
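The documented return value can be modeled with NumPy. This is a minimal host-side reference, assuming plain 2-D arrays stand in for SBUF tiles; `matmul_reference` is a hypothetical name, and the sketch ignores hardware details such as where the result is stored. A non-symmetric x makes the effect of `transpose_x` visible.

```python
import numpy as np


def matmul_reference(x: np.ndarray, y: np.ndarray,
                     transpose_x: bool = False) -> np.ndarray:
    """Hypothetical reference for the documented result:
    x @ y, or x.T @ y when transpose_x=True."""
    lhs = x.T if transpose_x else x
    return lhs @ y


# A non-symmetric x shows that the flag changes the result.
x = np.array([[1.0, 2.0],
              [3.0, 4.0]], dtype=np.float32)
y = np.eye(2, dtype=np.float32)

plain = matmul_reference(x, y)                        # x @ I == x
transposed = matmul_reference(x, y, transpose_x=True)  # x.T @ I == x.T
print(plain)
print(transposed)
```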
Examples:
```python
import nki.language as nl
import nki.isa as nisa

# nki.language.matmul -- identity.T @ ones = ones
x = nl.shared_identity_matrix(n=128, dtype=nl.float32)
y = nl.full((128, 128), 1.0, dtype=nl.float32, buffer=nl.sbuf)
result_psum = nl.matmul(x, y, transpose_x=True)

# Copy the matmul result (result_psum) into an SBUF tile before comparing.
result = nl.ndarray((128, 128), dtype=nl.float32, buffer=nl.sbuf)
nisa.tensor_copy(result, result_psum)

expected = nl.full((128, 128), 1.0, dtype=nl.float32, buffer=nl.sbuf)
assert nl.equal(result, expected)
```
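The arithmetic behind the example can be replayed with NumPy as a quick sanity check. This sketch is host-side only and does not exercise the Tensor Engine or the transpose that NKI inserts; `np.eye` merely stands in for `nl.shared_identity_matrix`. Because the identity matrix is symmetric, the example yields an all-ones result whether or not `transpose_x` is set.

```python
import numpy as np

# Host-side replay of the example's arithmetic (not NKI code).
identity = np.eye(128, dtype=np.float32)           # stands in for nl.shared_identity_matrix(n=128)
ones = np.full((128, 128), 1.0, dtype=np.float32)  # stands in for y

result_t = identity.T @ ones  # what transpose_x=True computes
result_n = identity @ ones    # what transpose_x=False computes

# identity.T == identity, so both settings give the all-ones matrix.
assert np.array_equal(result_t, ones)
assert np.array_equal(result_n, ones)
```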