This document is relevant for: Trn2, Trn3

Linear Scan Kernel API Reference

Computes a first-order linear recurrence along the last dimension.

This kernel uses nisa.tensor_tensor_scan to compute result[t] = decay[t] * result[t-1] + data[t] along the last dimension of the input tensors. Arbitrary leading batch dimensions are supported; they are collapsed internally.

Background

The linear_scan kernel computes a first-order linear recurrence along the last dimension: each output element result[t] is the current input data[t] plus the previous output result[t-1] scaled by the decay coefficient decay[t].
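To make the semantics concrete, here is a minimal NumPy reference sketch of the same recurrence. It is not the kernel's implementation: the name linear_scan_reference is hypothetical, and the sketch assumes that result[0] = decay[0] * initial + data[0] (with initial treated as zero when omitted) and that accumulation happens in float32, as the Notes below describe.

```python
import numpy as np


def linear_scan_reference(decay, data, initial=None):
    """Hypothetical NumPy reference (not part of nkilib) for
    result[t] = decay[t] * result[t-1] + data[t] along the last axis.

    decay, data: arrays of shape (..., P, L); initial: (..., P, 1) or None.
    Returns (result, final_state), both float32 in this sketch.
    """
    decay = np.asarray(decay, dtype=np.float32)
    data = np.asarray(data, dtype=np.float32)
    if decay.shape != data.shape or decay.ndim < 2:
        raise ValueError("decay and data must have identical shapes and rank >= 2")

    # Assumed semantics: the state before t = 0 is `initial`, or zero if None.
    if initial is None:
        state = np.zeros(data.shape[:-1], dtype=np.float32)
    else:
        state = np.asarray(initial, dtype=np.float32)[..., 0]

    result = np.empty_like(data)
    for t in range(data.shape[-1]):
        # One step of the recurrence for every (batch, partition) row at once.
        state = decay[..., t] * state + data[..., t]
        result[..., t] = state
    return result, state[..., np.newaxis]


decay = np.full((1, 2, 4), 0.5, dtype=np.float32)
data = np.ones((1, 2, 4), dtype=np.float32)
result, final_state = linear_scan_reference(decay, data)
print(result[0, 0].tolist())  # [1.0, 1.5, 1.75, 1.875]
print(final_state.shape)      # (1, 2, 1)
```

With a decay of 0.5 and a constant input of 1, each step halves the previous output and adds 1, so the sequence approaches 2.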

API Reference

Source code for this kernel API can be found at: linear_scan.py

linear_scan

nkilib.experimental.scan.linear_scan(decay: nl.ndarray, data: nl.ndarray, initial: nl.ndarray = None) → tuple

Compute a first-order linear recurrence along the last dimension.

Parameters:
  • decay (nl.ndarray) – Input HBM tensor of shape (…, P, L) containing multiplicative decay coefficients. dtype can be any NKI-supported type.

  • data (nl.ndarray) – Input HBM tensor of shape (…, P, L) containing additive input values. Must have the same shape as decay.

  • initial (nl.ndarray) – Initial state tensor of shape (…, P, 1). If None, the initial state is zero. Default: None.
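The shape contract implied by these parameters (and by the Returns and Notes entries) can be written down as a small check. This is a hypothetical helper, check_linear_scan_shapes, not part of nkilib; it works on plain shape tuples and assumes that initial must match the leading dimensions of decay exactly rather than being broadcast.

```python
def check_linear_scan_shapes(decay_shape, data_shape, initial_shape=None):
    """Hypothetical validator (not part of nkilib) for the documented shape rules.

    Returns the expected (result_shape, final_state_shape).
    """
    decay_shape, data_shape = tuple(decay_shape), tuple(data_shape)
    if decay_shape != data_shape:
        raise ValueError(f"decay {decay_shape} and data {data_shape} must have the same shape")
    if len(decay_shape) < 2:
        raise ValueError(f"expected rank >= 2, i.e. (..., P, L); got {decay_shape}")

    state_shape = decay_shape[:-1] + (1,)  # (..., P, 1)
    # Assumption: initial must equal (..., P, 1) exactly; broadcasting is not modeled.
    if initial_shape is not None and tuple(initial_shape) != state_shape:
        raise ValueError(f"initial must have shape {state_shape}; got {tuple(initial_shape)}")
    return decay_shape, state_shape


print(check_linear_scan_shapes((4, 128, 4096), (4, 128, 4096), (4, 128, 1)))
# ((4, 128, 4096), (4, 128, 1))
```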

Returns:

A tuple (result, final_state):

  • result: HBM tensor with the same shape as the inputs, containing the scan output.

  • final_state: HBM tensor of shape (…, P, 1) containing the last state of each sequence, in float32.

Return type:

tuple[nl.ndarray, nl.ndarray]

Notes:

  • Only supports scan along the last dimension

  • Uses float32 accumulation internally for numerical stability

  • decay and data must have identical shapes and rank >= 2

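  • For sequences longer than 2048 elements (L > F_TILE_SIZE), the scan is split into tiles along L, and the final state of each tile is carried into the next tile as its initial state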
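The carry propagation described in the last note can be illustrated with a NumPy sketch on a 2-D (rows, L) array. It scans each free-dimension tile separately and seeds each tile with the final state of the previous one. The helper names are hypothetical and this is not the kernel's code; it only shows why tiling along L does not change the result: the recurrence is applied in the same order either way.

```python
import numpy as np

F_TILE_SIZE = 2048  # free-dimension tile width quoted in this reference


def scan_tile(decay, data, state):
    """Sequential scan of (rows, width) tiles, starting from `state` of shape (rows,)."""
    out = np.empty_like(data)
    for t in range(data.shape[1]):
        state = decay[:, t] * state + data[:, t]
        out[:, t] = state
    return out, state


def tiled_linear_scan(decay, data, tile=F_TILE_SIZE):
    """Hypothetical tiled scan: each tile's last state is the next tile's initial state."""
    decay = np.asarray(decay, dtype=np.float32)
    data = np.asarray(data, dtype=np.float32)
    rows, length = data.shape
    carry = np.zeros(rows, dtype=np.float32)  # zero initial state
    out = np.empty_like(data)
    for start in range(0, length, tile):
        stop = min(start + tile, length)
        out[:, start:stop], carry = scan_tile(decay[:, start:stop], data[:, start:stop], carry)
    return out, carry[:, np.newaxis]


rng = np.random.default_rng(0)
decay = rng.uniform(0.5, 1.0, size=(4, 5000)).astype(np.float32)
data = rng.standard_normal((4, 5000)).astype(np.float32)

tiled, tiled_final = tiled_linear_scan(decay, data)               # 3 tiles of <= 2048
single, single_final = tiled_linear_scan(decay, data, tile=5000)  # 1 tile, no carry
print(np.array_equal(tiled, single), np.array_equal(tiled_final, single_final))  # True True
```

Only the tiles along L are coupled through the carry; tiles at different row ranges are independent of one another.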

Dimensions:

  • P: Partition dimension (second-to-last), processed in tiles of at most P_MAX=128 rows

  • L: Free dimension (last), processed in tiles of at most F_TILE_SIZE=2048 elements
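Taken together, the two tile sizes define a grid of tiles. The sketch below enumerates that grid for an input of shape (…, P, L), assuming (this is not stated above) that the leading batch dimensions are merged with P into a single row axis. Rows never interact in the recurrence, so the way they are grouped into partition tiles does not affect the result. The function tile_grid is hypothetical.

```python
from math import prod

P_MAX = 128          # partition tile size quoted in this reference
F_TILE_SIZE = 2048   # free-dimension tile size quoted in this reference


def tile_grid(shape, p_max=P_MAX, f_tile=F_TILE_SIZE):
    """Hypothetical: list (row_slice, col_slice) tiles for an input of shape (..., P, L).

    Assumes the leading batch dims are collapsed with P, giving prod(shape[:-1]) rows.
    Tiles along L for the same row range must be processed in order (carry).
    """
    rows, length = prod(shape[:-1]), shape[-1]
    return [
        (slice(r, min(r + p_max, rows)), slice(c, min(c + f_tile, length)))
        for r in range(0, rows, p_max)
        for c in range(0, length, f_tile)
    ]


tiles = tile_grid((2, 3, 100, 5000))  # 600 rows x 5000 columns
print(len(tiles))                      # 15 (5 row tiles x 3 column tiles)
print(tiles[-1])                       # (slice(512, 600, None), slice(4096, 5000, None))
```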
