This document is relevant for: Trn2, Trn3
Linear Scan Kernel API Reference
Computes a first-order linear recurrence along the last dimension.
This kernel computes result[t] = decay[t] * result[t-1] + data[t] along the last dimension of the input tensors using nisa.tensor_tensor_scan. The state preceding t = 0 is given by initial (zero if omitted). The kernel supports an arbitrary number of leading batch dimensions, which are collapsed internally.
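To illustrate what collapsing the batch dimensions means for shapes, the plain-Python sketch below folds every leading dimension of a (…, P, L) shape into one leading dimension N. The helper collapse_batch_dims is hypothetical (not part of nkilib), and the (N, P, L) target layout is an assumption: the page only states that batch dimensions are collapsed internally.

```python
import math

def collapse_batch_dims(shape):
    """Hypothetical helper: fold all leading (batch) dimensions of a
    (..., P, L) shape into a single leading dimension N.
    Assumes the collapsed layout is (N, P, L)."""
    if len(shape) < 2:
        raise ValueError("decay and data must have rank >= 2")
    *batch, p, l = shape
    n = math.prod(batch)  # math.prod of an empty list is 1, so rank-2 inputs give N = 1
    return (n, p, l)

print(collapse_batch_dims((4, 8, 128, 4096)))  # (32, 128, 4096)
print(collapse_batch_dims((128, 512)))         # (1, 128, 512)
```

Because the recurrence only runs along L, every one of the N * P rows is an independent scan, which is why any number of batch dimensions can be folded together without changing the result.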
Background
The linear_scan kernel computes a first-order linear recurrence along the last dimension: each output element is the current input plus the product of the current decay coefficient and the previous output.
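Unrolling the recurrence shows why decay acts as a decay factor: each earlier input is weighted by the product of all later decay coefficients. The shorthand below is introduced here for readability only: a_t for decay[t], x_t for data[t], h_t for result[t], and h_{-1} for the initial state.

```latex
h_t = a_t\, h_{t-1} + x_t
\quad\Longrightarrow\quad
h_t = \Big(\prod_{k=0}^{t} a_k\Big) h_{-1}
    + \sum_{s=0}^{t} \Big(\prod_{k=s+1}^{t} a_k\Big) x_s
```

The empty product (s = t) equals 1. For example, if every a_k = 1 the scan reduces to the initial state plus a running (prefix) sum of data, and if every a_k = 0 the output equals data.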
API Reference
Source code for this kernel API can be found at: linear_scan.py
linear_scan
- nkilib.experimental.scan.linear_scan(decay: nl.ndarray, data: nl.ndarray, initial: nl.ndarray = None) -> tuple
Computes a first-order linear recurrence along the last dimension.
- Parameters:
decay (nl.ndarray) – Input HBM tensor of shape (…, P, L) containing multiplicative decay coefficients. dtype can be any NKI-supported type.
data (nl.ndarray) – Input HBM tensor of shape (…, P, L) containing additive input values. Must have the same shape as decay.
initial (nl.ndarray) – Initial state tensor of shape (…, P, 1). If None, the initial state is zero. Default: None.
- Returns:
(result, final_state), where:
result – HBM tensor with the same shape as the inputs, containing the scan output.
final_state – HBM tensor of shape (…, P, 1) containing the last state of each sequence, in float32.
- Return type:
tuple (nl.ndarray, nl.ndarray)
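The parameter and return contract can be mirrored by a small plain-Python reference for a single P x L slice, which is handy for checking kernel output on small inputs. linear_scan_2d is a hypothetical helper, not the nkilib API; it takes nested lists instead of HBM tensors and uses Python floats (double precision) rather than the kernel's float32 accumulation.

```python
def linear_scan_2d(decay, data, initial=None):
    """Hypothetical reference (not NKI) for one P x L slice given as nested lists.
    `initial`, if given, is a P x 1 nested list; None means a zero initial state.
    Returns (result, final_state) with shapes P x L and P x 1."""
    if len(decay) != len(data) or any(len(a) != len(x) for a, x in zip(decay, data)):
        raise ValueError("decay and data must have identical shapes")
    result, final_state = [], []
    for p, (decay_row, data_row) in enumerate(zip(decay, data)):
        # Each partition row is an independent scan along the free dimension.
        state = 0.0 if initial is None else float(initial[p][0])
        row = []
        for a, x in zip(decay_row, data_row):
            state = a * state + x  # result[t] = decay[t] * result[t-1] + data[t]
            row.append(state)
        result.append(row)
        final_state.append([state])  # last state of this row, kept as P x 1
    return result, final_state

decay = [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
data = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
res, fin = linear_scan_2d(decay, data, initial=[[10.0], [10.0]])
print(res)  # [[11.0, 13.0, 16.0], [1.0, 2.0, 3.0]]
print(fin)  # [[16.0], [3.0]]
```

In this sketch final_state is simply the last column of result, reshaped to P x 1, which is also the value a caller would pass as initial to continue the scan on a following chunk of the sequence.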
Notes:
Only supports scan along the last dimension
Uses float32 accumulation internally for numerical stability
decay and data must have identical shapes and rank >= 2
For long sequences (L > 2048), the scan is tiled along L, and the last state of each tile is carried into the next tile (carry propagation)
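Carry propagation for long sequences can be sketched in plain Python, assuming each tile of up to F_TILE_SIZE elements is scanned starting from the previous tile's last state. scan_tile and tiled_linear_scan are hypothetical helpers, not part of nkilib, and use Python floats rather than float32. The example uses a tiny tile size so the carry crosses several tile boundaries, then checks the result against one untiled scan.

```python
def scan_tile(decay, data, initial):
    """Hypothetical helper: scan one tile starting from `initial`.
    Returns (outputs, last_state)."""
    state = initial
    out = []
    for a, x in zip(decay, data):
        state = a * state + x
        out.append(state)
    return out, state

def tiled_linear_scan(decay, data, initial=0.0, tile_size=2048):
    """Hypothetical sketch: scan a long sequence in tiles of `tile_size`;
    the last state of each tile seeds the next tile (carry propagation)."""
    result = []
    carry = float(initial)
    for start in range(0, len(data), tile_size):
        stop = start + tile_size  # slicing clamps at the end of the sequence
        out, carry = scan_tile(decay[start:stop], data[start:stop], carry)
        result.extend(out)
    return result, carry

decay = [0.5, 2.0, 1.0, 0.25, 3.0]
data = [1.0, -1.0, 2.0, 4.0, 0.5]
tiled, last = tiled_linear_scan(decay, data, initial=1.0, tile_size=2)
untiled, untiled_last = scan_tile(decay, data, 1.0)
print(tiled)                                   # [1.5, 2.0, 4.0, 5.0, 15.5]
print(tiled == untiled, last == untiled_last)  # True True
```

Tiles along L must be processed in order, because each one depends on the carry from the tile before it.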
Dimensions:
P: Partition dimension (second-to-last), tiled in chunks of P_MAX=128
L: Free dimension (last), tiled in chunks of F_TILE_SIZE=2048
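The tile arithmetic for these two dimensions can be sketched as follows, assuming a dimension that is not a multiple of its tile size ends with one shorter tile. tile_bounds is a hypothetical helper, and P = 300, L = 5000 is an arbitrary example shape.

```python
def tile_bounds(size, tile):
    """Hypothetical helper: split [0, size) into half-open (start, stop)
    ranges of at most `tile` elements; the last range may be shorter."""
    return [(s, min(s + tile, size)) for s in range(0, size, tile)]

P_MAX = 128         # partition tile size, from the Dimensions notes
F_TILE_SIZE = 2048  # free-dimension tile size, from the Dimensions notes

P, L = 300, 5000    # example (..., P, L) shape
p_tiles = tile_bounds(P, P_MAX)
l_tiles = tile_bounds(L, F_TILE_SIZE)
print(p_tiles)  # [(0, 128), (128, 256), (256, 300)]
print(l_tiles)  # [(0, 2048), (2048, 4096), (4096, 5000)]
```

The two tilings differ in kind: partition tiles hold independent rows and can be processed in any order, whereas free-dimension tiles are linked by the carried state.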