This document is relevant for: Inf2, Trn1, Trn2
nki.language.shared_identity_matrix
- nki.language.shared_identity_matrix(n, dtype=<class 'numpy.uint8'>, **kwargs)
Create a new identity tensor with the specified data type.
This function behaves the same as nki.language.shared_constant, but is preferred when the constant matrix is an identity matrix. The compiler reuses identity matrices of the same dtype across the graph to save space.
- Parameters:
n – the number of rows (and columns) of the returned identity matrix
dtype – the data type of the tensor; defaults to np.uint8 (see Supported Data Types for more information).
- Returns:
a tensor containing the n × n identity matrix
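As a rough illustration of the values such a tensor should hold, the sketch below builds a host-side NumPy reference. It is not NKI kernel code: `reference_identity` is a hypothetical helper introduced here, and it assumes the returned tensor matches `np.identity(n, dtype=dtype)` element for element.

```python
import numpy as np


def reference_identity(n, dtype=np.uint8):
    """Hypothetical host-side stand-in, not part of NKI.

    Assumes the NKI tensor holds the same values as np.identity(n, dtype).
    """
    return np.identity(n, dtype=dtype)


# Build a 4 x 4 identity with the documented default dtype (np.uint8).
ident = reference_identity(4)

# Multiplying by the identity leaves a matrix unchanged.
x = np.arange(16, dtype=np.uint8).reshape(4, 4)
product = ident @ x
```

A reference like this can be useful for checking kernel output on the host, for example by comparing a kernel result against `product`.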