This document is relevant for: Trn1, Trn2, Trn3
nki.language.shared_identity_matrix
- nki.language.shared_identity_matrix(n, dtype='uint8', dst=None)
Create a new n x n identity matrix with the specified data type.
This function behaves the same as nki.language.shared_constant but is preferred when the constant matrix is an identity matrix: the compiler reuses all identity matrices of the same dtype in the graph to save space.
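To make the reuse described above concrete, here is a toy host-side model in Python with NumPy. It is not the Neuron compiler's mechanism. The memoized function `shared_identity` is hypothetical, and it assumes reuse is keyed on both size and dtype, whereas the text above only states that matrices of the same dtype are reused.

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=None)
def shared_identity(n: int, dtype_name: str = "uint8") -> np.ndarray:
    # Hypothetical toy model: one buffer per (n, dtype) key, created on first request
    # and handed back unchanged on every later request with the same key.
    mat = np.eye(n, dtype=np.dtype(dtype_name))
    # A shared constant must not be mutated by any single user, so freeze it.
    mat.setflags(write=False)
    return mat


a = shared_identity(128, "float32")
b = shared_identity(128, "float32")
c = shared_identity(128, "uint8")
assert a is b      # same size and dtype: one shared buffer
assert a is not c  # different dtype: a separate buffer
```

Freezing the cached array is the design choice that makes sharing safe: if any caller could write into the buffer, every other user of the "same" identity matrix would see the change.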
- Parameters:
n – the number of rows (and columns) of the returned identity matrix
dtype – the data type of the tensor; defaults to nl.uint8 (see Supported Data Types for more information).
- Returns:
a tensor containing the n x n identity matrix
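For checking kernel output on the host, a NumPy reference of the values the returned tensor should hold can be built with `np.eye`. This is a sketch assuming the result is the standard identity matrix (ones on the main diagonal, zeros elsewhere); `reference_identity` is a hypothetical helper, not part of NKI.

```python
import numpy as np


def reference_identity(n: int, dtype=np.uint8) -> np.ndarray:
    """Hypothetical host-side reference for shared_identity_matrix(n, dtype)."""
    if n <= 0:
        raise ValueError("n must be a positive integer")
    # np.eye puts 1 on the main diagonal and 0 everywhere else, in the requested dtype.
    return np.eye(n, dtype=dtype)


# A reference that an `expected_tensor` input, as used in the example below,
# could be filled from on the host before launching the kernel.
expected = reference_identity(128, np.float32)
```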
Examples:
import nki.language as nl

# nki.language.shared_identity_matrix -- 128x128 identity matrix
# expected_tensor and actual_tensor are assumed to be 128x128 tensors passed into the enclosing kernel
identity = nl.shared_identity_matrix(n=128, dtype=nl.float32)
expected = nl.load(expected_tensor[0:128, 0:128])
assert nl.equal(identity, expected)
nl.store(actual_tensor[0:128, 0:128], identity)