This document is relevant for: Trn1, Trn2, Trn3

nki.language.shared_identity_matrix

nki.language.shared_identity_matrix(n, dtype='uint8', dst=None)

Create an n-by-n identity matrix with the specified data type.

This function behaves the same as nki.language.shared_constant, but it is preferred when the constant matrix is an identity matrix. The compiler reuses all identity matrices of the same dtype in the graph, which saves space.

Parameters:
  • n – the number of rows (and columns) of the returned identity matrix

  • dtype – the data type of the tensor; defaults to nl.uint8 (see Supported Data Types for more information).

Returns:

an n-by-n tensor containing the identity matrix (ones on the main diagonal, zeros elsewhere)
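When testing a kernel on the host, it can help to have a reference for the values the returned tensor should hold. The following is a minimal NumPy sketch, assuming NumPy is available; the helper name identity_reference is hypothetical and only models the returned values. It does not model on-device placement or the compiler's reuse of identity matrices.

```python
import numpy as np


def identity_reference(n: int, dtype="uint8") -> np.ndarray:
    """Hypothetical host-side reference for nl.shared_identity_matrix(n, dtype).

    Models only the values: an n x n matrix with ones on the main diagonal
    and zeros elsewhere. The default dtype mirrors the documented default.
    """
    if n <= 0:
        raise ValueError("n must be a positive integer")
    # np.eye builds a 2-D array with ones on the main diagonal.
    return np.eye(n, dtype=dtype)


# Usage: a 4x4 uint8 reference (the default dtype).
small = identity_reference(4)
```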

Examples:

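import nki.language as nl

# nki.language.shared_identity_matrix -- 128x128 identity matrix
# expected_tensor and actual_tensor are assumed to be 128x128 tensors
# passed into the enclosing kernel.
identity = nl.shared_identity_matrix(n=128, dtype=nl.float32)
# Load a reference identity matrix and compare it element-wise.
expected = nl.load(expected_tensor[0:128, 0:128])
assert nl.equal(identity, expected)
# Write the identity matrix to the output tensor.
nl.store(actual_tensor[0:128, 0:128], identity)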
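The example assumes expected_tensor and actual_tensor already exist as 128x128 tensors. The following is a minimal NumPy sketch of how one might prepare matching host-side buffers and check the output once the kernel has run. The buffer setup and the helper is_identity are hypothetical, and the kernel launch itself is omitted because it depends on the NKI version and device setup.

```python
import numpy as np

N = 128  # matches the 0:128 slices used in the example

# Hypothetical host-side buffers standing in for the example's
# expected_tensor (a float32 identity) and actual_tensor (an output buffer).
expected_tensor = np.eye(N, dtype=np.float32)
actual_tensor = np.zeros((N, N), dtype=np.float32)


def is_identity(result: np.ndarray, n: int = N) -> bool:
    """Return True if result is an n x n identity matrix (hypothetical helper)."""
    # Check the shape first, then compare element-wise against np.eye.
    return result.shape == (n, n) and bool(
        np.array_equal(result, np.eye(n, dtype=result.dtype))
    )


# After the kernel has written to actual_tensor, is_identity(actual_tensor)
# should return True; before that, the zero-filled buffer fails the check.
```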
