This document is relevant for: Inf2, Trn1, Trn2

nki.language.shared_identity_matrix

nki.language.shared_identity_matrix(n, dtype=<class 'numpy.uint8'>, **kwargs)

Create a new identity tensor with the specified data type.

This function behaves the same as nki.language.shared_constant, but is preferred when the constant matrix is an identity matrix. The compiler reuses identity matrices of the same dtype across the graph to save space.
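To illustrate the reuse described above, the following is a hypothetical host-side sketch of deduplicating identity constants. It is not the Neuron compiler's implementation: `get_shared_identity` and `_identity_cache` are illustrative names, and keying the cache on `(n, dtype)` is an assumption (the text above only states that identity matrices of the same dtype are reused).

```python
import numpy as np

# Hypothetical cache: one shared, read-only identity matrix per (n, dtype).
# This only illustrates the reuse idea; it is NOT how the compiler does it.
_identity_cache = {}


def get_shared_identity(n, dtype=np.uint8):
    # Normalize the dtype so np.float32 and "float32" map to the same key.
    key = (int(n), np.dtype(dtype))
    if key not in _identity_cache:
        mat = np.identity(key[0], dtype=key[1])
        # A shared constant must not be mutated by any single consumer.
        mat.setflags(write=False)
        _identity_cache[key] = mat
    return _identity_cache[key]


a = get_shared_identity(128, np.float32)
b = get_shared_identity(128, "float32")
print(a is b)  # True: the same object is reused
```

Because every request with an equal key returns the same object, only one copy of each identity matrix needs to be stored, which is the space saving the paragraph above refers to.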

Parameters:
  • n – the number of rows (and columns) of the returned identity matrix

  • dtype – the data type of the tensor; defaults to np.uint8 (see Supported Data Types for more information).

Returns:

a tensor containing an n × n identity matrix of the specified dtype
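When checking kernel results on the host, the documented return value can be modeled with NumPy. This is a minimal sketch, assuming the on-device tensor holds the same values as `np.identity(n, dtype)`; `reference_identity` is a hypothetical helper, not part of NKI, and it does not call the NKI API.

```python
import numpy as np


def reference_identity(n, dtype=np.uint8):
    """Hypothetical host-side model of the documented return value:
    an n x n matrix with ones on the main diagonal and zeros elsewhere."""
    # np.identity builds a square identity matrix with the requested dtype;
    # the default mirrors the documented default of np.uint8.
    return np.identity(n, dtype=dtype)


ident = reference_identity(4)
print(ident.dtype)  # uint8
print(ident.shape)  # (4, 4)
```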
