This document is relevant for: Inf2, Trn1, Trn1n

nki.language.max

nki.language.max(x, axis, dtype=None, mask=None, keepdims=False, **kwargs)

Maximum of elements along the specified axis (or axes) of the input.

(Similar to numpy.max)

Parameters:
  • x – a tile.

  • axis – int or tuple/list of ints. The axis (or axes) along which to operate. These must be free dimensions, not the partition dimension (0), and can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4]

  • dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • keepdims – if set to True, the reduced axes are kept in the result as dimensions of size one. With this option, the result broadcasts correctly against the input tile.

Returns:

a tile with the maximum of elements along the provided axis. The returned tile has the shape of the input tile with the specified axes removed, or kept with size one if keepdims is True.
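
Since this function is described as similar to numpy.max, the shape behavior can be sketched with NumPy arrays standing in for tiles. This is only an illustration of the reduction semantics under that assumption, not NKI kernel code: the arrays x and x3 are hypothetical, axis 0 plays the role of the partition dimension, and NKI-specific arguments such as mask and dtype are not modeled.

```python
import numpy as np

# Hypothetical stand-in for a 2D tile: axis 0 ~ partition dimension,
# axis 1 ~ free dimension.
x = np.array([[1.0, 5.0, 3.0],
              [4.0, 2.0, 6.0]], dtype=np.float32)

# Reduce along the free axis: the reduced axis is removed from the shape.
reduced = np.max(x, axis=1)               # shape (2,)

# keepdims=True keeps the reduced axis with size one.
kept = np.max(x, axis=1, keepdims=True)   # shape (2, 1)

# The keepdims result broadcasts against the input, e.g. to shift
# every row so that its maximum becomes zero.
shifted = x - kept                        # shape (2, 3)

# Several trailing free axes can be reduced at once with a tuple.
x3 = np.arange(24).reshape(2, 3, 4)
reduced3 = np.max(x3, axis=(1, 2))                  # shape (2,)
kept3 = np.max(x3, axis=(1, 2), keepdims=True)      # shape (2, 1, 1)
```

Based on the signature above, the analogous NKI calls would take the form nki.language.max(x, axis=1, keepdims=True) or nki.language.max(x, axis=[1, 2]) on a tile inside a kernel.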
