This document is relevant for: Trn1, Trn2, Trn3

nki.language.transpose

nki.language.transpose(x, dtype=None)

Transposes a 2D tile, swapping its partition and free dimensions.

Warning

This API is experimental and may change in future releases.

Parameters:
  • x – 2D input tile.

  • dtype – (optional) the data type to cast the output to (see Supported Data Types for more information). If not specified, the output has the same data type as the input tile.

Returns:

A tile containing the values of the input tile, with its partition and free dimensions swapped.
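To make the index mapping concrete, here is a host-side, pure-Python model of the semantics described above. It is a sketch, not part of NKI: `reference_transpose` is a hypothetical helper, the tile is modelled as a list of rows indexed `x[p][f]` (partition index first, free index second), and the optional `cast` callable is an assumed stand-in for the `dtype` parameter.

```python
from typing import Any, Callable, List, Optional, Sequence


def reference_transpose(
    x: Sequence[Sequence[Any]],
    cast: Optional[Callable[[Any], Any]] = None,
) -> List[List[Any]]:
    """Hypothetical host-side model of transposing a 2D tile.

    x is indexed as x[p][f]: p runs over the partition dimension and
    f over the free dimension. The result is indexed as out[f][p].
    `cast` is an assumed stand-in for the `dtype` argument; when it is
    None, element values are copied unchanged.
    """
    num_partitions = len(x)
    num_free = len(x[0]) if num_partitions else 0
    for row in x:
        if len(row) != num_free:
            raise ValueError("input must be a rectangular 2D tile")
    convert = cast if cast is not None else (lambda v: v)
    # out[f][p] = x[p][f]: the partition and free dimensions swap roles,
    # so a (P, F) input becomes an (F, P) output.
    return [
        [convert(x[p][f]) for p in range(num_partitions)]
        for f in range(num_free)
    ]


# A (2, 3) tile becomes a (3, 2) tile.
tile = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0]]
out = reference_transpose(tile)
print(out)  # [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
```

Transposing the result a second time returns the original tile, which gives a simple reference to compare a kernel's output against.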

Examples:

import nki.isa as nisa
import nki.language as nl

# nki.language.transpose -- transpose of identity is identity
x = nl.shared_identity_matrix(n=128, dtype=nl.float32)
result_psum = nl.transpose(x)
result = nl.ndarray((128, 128), dtype=nl.float32, buffer=nl.sbuf)
nisa.tensor_copy(result, result_psum)
assert nl.equal(result, x)
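Because the identity matrix is symmetric, the example above confirms that values survive the round trip but cannot detect whether the dimensions were actually swapped. The following pure-Python sketch runs on plain nested lists on the host, not as NKI code (`identity` and `transpose` are hypothetical helpers). It shows that the identity is a fixed point of transposition, while a non-symmetric tile changes under one transpose and is restored by a second.

```python
def identity(n):
    # n x n identity matrix as nested lists of floats.
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def transpose(x):
    # Swap the two dimensions: out[j][i] = x[i][j].
    return [list(col) for col in zip(*x)]


eye = identity(128)
assert transpose(eye) == eye  # identity is unchanged by transposition

# A non-symmetric 128 x 128 tile: element (i, j) holds i * 128 + j.
skewed = [[float(i * 128 + j) for j in range(128)] for i in range(128)]
assert transpose(skewed) != skewed               # one transpose changes it
assert transpose(transpose(skewed)) == skewed    # two transposes restore it
```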
