This document is relevant for: Trn2, Trn3

nki.language.shared_identity_matrix

nki.language.shared_identity_matrix(n, dtype='uint8', dst=None)

Create an n x n identity matrix in SBUF with the specified data type.

The compiler reuses identity matrices of the same dtype across the graph to save SBUF space.
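As a rough mental model of this reuse, the plain-Python sketch below memoizes identity matrices so that repeated requests return the same object instead of allocating a new one. It is a host-side illustration only, not how the NKI compiler works internally; the function name `cached_identity` is hypothetical, and keying the cache on both `n` and `dtype` is an assumption (the description above only mentions dtype).

```python
from functools import lru_cache


# Hypothetical model of "reuse identity matrices of the same dtype".
# Not the NKI compiler's implementation; the (n, dtype) cache key is an assumption.
@lru_cache(maxsize=None)
def cached_identity(n: int, dtype: str = "uint8") -> tuple:
    """Return an n x n identity matrix as an immutable tuple of row tuples.

    dtype is only a label here: float dtypes use 1.0/0.0, all others use 1/0.
    """
    one, zero = (1.0, 0.0) if dtype.startswith("float") else (1, 0)
    return tuple(
        tuple(one if i == j else zero for j in range(n)) for i in range(n)
    )


a = cached_identity(4, "float32")
b = cached_identity(4, "float32")
c = cached_identity(4, "uint8")
print(a is b)  # True: an identical request reuses the cached matrix
print(a is c)  # False: a different dtype gets its own matrix
```

Returning immutable tuples matters for the sketch: because every caller shares one object, a mutable matrix could be corrupted by any one of them.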

Parameters:
  • n – the number of rows (and columns) of the returned identity matrix

  • dtype – the data type of the tensor; defaults to nl.uint8 (see Supported Data Types for more information).

Returns:

a new NkiTensor containing the identity matrix

Examples:

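import nki.language as nl

# expected_tensor and actual_tensor are assumed to be HBM tensors passed into
# the enclosing kernel; they are not defined in this snippet.

# Create a 128x128 float32 identity matrix in SBUF.
identity = nl.shared_identity_matrix(n=128, dtype=nl.float32)
# Load the reference values from HBM into SBUF and compare element-wise.
expected = nl.load(expected_tensor[0:128, 0:128])
assert nl.equal(identity, expected)
# Write the identity matrix back to HBM.
nl.store(actual_tensor[0:128, 0:128], identity)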
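The example assumes expected_tensor already holds a 128x128 identity matrix. One way to build such a reference on the host is NumPy's `np.eye`, sketched below; it assumes the kernel's HBM inputs are supplied as NumPy arrays, and the names `expected` and `expected_default` are hypothetical. The second array mirrors the function's default `uint8` dtype.

```python
import numpy as np

N = 128

# Host-side reference for the float32 example above (hypothetical name).
expected = np.eye(N, dtype=np.float32)

# Counterpart for the default dtype, uint8 (hypothetical name).
expected_default = np.eye(N, dtype=np.uint8)

print(expected.shape, expected.dtype)  # (128, 128) float32
print(int(expected.trace()))  # 128: one 1 per diagonal element
```

Both arrays hold the same values; only the element type differs, which is why the dtype passed to the kernel and the dtype of the reference should match before comparing.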
