This document is relevant for: Inf2, Trn1, Trn2

nki.isa.dropout

nki.isa.dropout(data, prob, *, mask=None, dtype=None, **kwargs)

Randomly replaces some elements of the input tile data with zeros, using the Vector Engine. The drop probability, i.e., the probability that an input element is replaced with zero, is specified by the prob field:

  • If prob is 1.0, all elements are replaced with zeros.

  • If prob is 0.0, all elements keep their original values.

The prob field can be a scalar constant or a tile of shape (data.shape[0], 1), where each partition contains one drop probability value. The drop probability value in each partition is applicable to the input data elements from the same partition only.
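The per-partition semantics described above can be modeled in plain Python. This is a reference sketch only, not the Vector Engine implementation: the helper name `dropout_reference`, the use of Python's `random` module, and the representation of a tile as a list of rows (one row per partition) are all assumptions made for illustration.

```python
import random

def dropout_reference(data, prob, seed=None):
    """Hypothetical reference model of the documented dropout semantics.

    data: list of rows, one row per partition.
    prob: a scalar, or a list of single-element rows (shape [len(data), 1]).
    """
    rng = random.Random(seed)
    out = []
    for p_idx, row in enumerate(data):
        # A scalar prob applies to every partition; a prob tile supplies
        # one drop probability per partition.
        p = prob if isinstance(prob, (int, float)) else prob[p_idx][0]
        # rng.random() lies in [0.0, 1.0), so p == 1.0 drops every element
        # and p == 0.0 keeps every element.
        out.append([0 if rng.random() < p else x for x in row])
    return out
```

With a prob tile of `[[1.0], [0.0]]`, the first partition is zeroed entirely and the second is returned unchanged, which matches the rule that each drop probability applies only to its own partition.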

The input data tile can have any valid NKI data type (see Supported Data Types for more information). However, the data type of prob is restricted by the data type of data:

  • If data has an integer type (e.g., int32, int16), prob must be float32.

  • If data has a float type (e.g., float32, bfloat16), prob can be any valid float type.

The output data type of this instruction is specified by the dtype field. If data has an integer type, the output data type must match it. Otherwise, the output can have any valid NKI data type. If dtype is not specified, the output data type defaults to the data type of data.
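The data type rules above can be summarized as a small checker. This is an illustrative sketch, not part of the NKI API: the function name `resolve_dropout_dtypes`, the use of string dtype names, and the two dtype sets are assumptions, and the sets are not exhaustive (see Supported Data Types for the real list). It also assumes prob is a tile with its own data type.

```python
# Illustrative, non-exhaustive dtype names (assumption, not the official list).
INT_TYPES = {"int8", "int16", "int32", "uint8", "uint16", "uint32"}
FLOAT_TYPES = {"float16", "bfloat16", "float32"}

def resolve_dropout_dtypes(data_dtype, prob_dtype, out_dtype=None):
    """Hypothetical helper: validate dtypes and return the output dtype."""
    if data_dtype in INT_TYPES:
        # Integer data: prob must be float32 and the output must match data.
        if prob_dtype != "float32":
            raise ValueError("integer data requires a float32 prob")
        if out_dtype is not None and out_dtype != data_dtype:
            raise ValueError("integer data requires a matching output dtype")
    elif data_dtype in FLOAT_TYPES:
        # Float data: prob may be any float type; the output is unrestricted.
        if prob_dtype not in FLOAT_TYPES:
            raise ValueError("float data requires a float prob")
    else:
        raise ValueError(f"unrecognized data dtype: {data_dtype}")
    # An unspecified output dtype defaults to the input dtype.
    return data_dtype if out_dtype is None else out_dtype
```

For example, `resolve_dropout_dtypes("bfloat16", "float32", "float32")` returns `"float32"`, while `resolve_dropout_dtypes("int16", "bfloat16")` raises `ValueError` because integer data requires a float32 prob.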

Estimated instruction cost:

max(MIN_II, N) Vector Engine cycles, where N is the number of elements per partition in data, and MIN_II is the minimum instruction initiation interval for small input tiles. MIN_II is roughly 64 engine cycles.
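As a quick way to apply the formula above, the estimate can be written as a one-line function. The name `estimate_dropout_cycles` is hypothetical, and `MIN_II = 64` is only the approximate value quoted in the text, so the result should be treated as a rough estimate.

```python
MIN_II = 64  # approximate minimum initiation interval, in engine cycles

def estimate_dropout_cycles(elements_per_partition, min_ii=MIN_II):
    """Hypothetical helper: max(MIN_II, N) Vector Engine cycles."""
    return max(min_ii, elements_per_partition)

# A [128, 512] tile has N = 512 elements per partition.
print(estimate_dropout_cycles(512))  # 512
# A [128, 16] tile is small enough that MIN_II dominates.
print(estimate_dropout_cycles(16))   # 64
```

Under this estimate, the partition count does not enter the formula: only the number of elements per partition matters.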

Parameters:
  • data – the input tile

  • prob – a scalar or a tile of shape (data.shape[0], 1) to indicate the probability of replacing elements with zeros

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.

Returns:

an output tile of the dropout result

Example:

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

###########################################################################
# Example 1: Apply dropout to an input tile a of shape [128, 512], using the
# per-partition probabilities in tile b of shape [128, 1]; store the result in c.
###########################################################################
a: tensor[128, 512] = nl.load(a_tensor)
b: tensor[128, 1] = nl.load(b_tensor)

c: tensor[128, 512] = nisa.dropout(a, prob=b)

nl.store(c_tensor, c)

######################################################
# Example 2: Apply dropout to an input tile a with a
# drop probability of 0.2 and store the result in b.
######################################################
a = nl.load(in_tensor)

b = nisa.dropout(a, prob=0.2)

nl.store(out_tensor, b)
