This document is relevant for: Trn2, Trn3
Selective Scan Kernel API Reference#
Selective scan (SSM) as in Mamba models.
Performs fused discretization, linear recurrence, and output projection in a single kernel. For each state dimension n and time step t:
decay[t] = exp(dt[t] * A[:, n])
inp[t] = dt[t] * x[t] * B[:, n, t]
state[t] = decay[t] * state[t-1] + inp[t]
y[t] += C[:, n, t] * state[t]
y += D * x (optional skip connection)
Background#
The selective_scan kernel implements the selective scan state space model (SSM) used in Mamba models, performing fused discretization, linear recurrence, and output projection in a single kernel.
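For checking results on the host, the recurrence can be written out step by step in NumPy. The sketch below is not the kernel's implementation: the function name selective_scan_ref is hypothetical, the loop is unfused and untiled, and the broadcasting of B and C over channels is an assumption drawn from the documented shapes.

```python
import numpy as np

def selective_scan_ref(x, dt, A, B, C, D=None, initial_state=None):
    """Hypothetical NumPy reference for the recurrence; not the NKI kernel."""
    batch, channels, seq_len = x.shape
    state_size = A.shape[1]
    # Accumulate in float32, mirroring the documented internal precision.
    x32, dt32 = x.astype(np.float32), dt.astype(np.float32)
    A32, B32, C32 = (a.astype(np.float32) for a in (A, B, C))
    if initial_state is None:
        state = np.zeros((batch, channels, state_size), dtype=np.float32)
    else:
        state = initial_state.astype(np.float32).copy()
    y = np.zeros((batch, channels, seq_len), dtype=np.float32)
    for t in range(seq_len):
        step = dt32[:, :, t, None]                     # [batch, channels, 1]
        decay = np.exp(step * A32[None, :, :])         # [batch, channels, state_size]
        inp = step * x32[:, :, t, None] * B32[:, None, :, t]
        state = decay * state + inp
        # Output projection: sum over the state dimension n.
        y[:, :, t] = np.sum(C32[:, None, :, t] * state, axis=-1)
    if D is not None:
        y += D.astype(np.float32)[None, :, None] * x32  # optional skip connection
    return y.astype(x.dtype), state
```

With A set to zero the decay is exactly 1, so the state is a running sum of dt * x * B, which makes small cases easy to verify by hand.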
API Reference#
Source code for this kernel API can be found at: selective_scan.py
selective_scan#
- nkilib.experimental.scan.selective_scan(x: nl.ndarray, dt: nl.ndarray, A: nl.ndarray, B: nl.ndarray, C: nl.ndarray, D: nl.ndarray = None, initial_state: nl.ndarray = None) tuple#
Selective scan (SSM) as in Mamba models.
- Parameters:
x (nl.ndarray) – Input tensor of shape [B_dim, channels, L].
dt (nl.ndarray) – Time step tensor of shape [B_dim, channels, L]. Should be positive.
A (nl.ndarray) – State transition matrix of shape [channels, state_size]. Typically negative.
B (nl.ndarray) – Input projection matrix of shape [B_dim, state_size, L].
C (nl.ndarray) – Output projection matrix of shape [B_dim, state_size, L].
D (nl.ndarray) – Skip connection weights of shape [channels]. Default: None.
initial_state (nl.ndarray) – Initial hidden state of shape [B_dim, channels, state_size]. Default: None (zeros).
- Returns:
(y, final_state)
y (nl.ndarray): Output tensor of shape [B_dim, channels, L] with same dtype as x.
final_state (nl.ndarray): Final hidden state of shape [B_dim, channels, state_size] in float32.
- Return type:
tuple[nl.ndarray, nl.ndarray]
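Because seven tensors share four dimensions, a shape check on the host before launching the kernel can catch mismatches early. The helper below is a hypothetical sketch, not part of nkilib; it only encodes the shapes listed under Parameters and works on any array object exposing shape and ndim.

```python
import numpy as np

def check_selective_scan_shapes(x, dt, A, B, C, D=None, initial_state=None):
    """Hypothetical helper: validate the documented input shapes.

    Returns (B_dim, channels, state_size, L) on success, raises ValueError otherwise.
    """
    if x.ndim != 3:
        raise ValueError(f"x must be [B_dim, channels, L], got shape {tuple(x.shape)}")
    b_dim, channels, seq_len = x.shape
    if A.ndim != 2 or A.shape[0] != channels:
        raise ValueError(f"A must be [channels, state_size], got shape {tuple(A.shape)}")
    state_size = A.shape[1]
    expected = {
        "dt": (dt, (b_dim, channels, seq_len)),
        "B": (B, (b_dim, state_size, seq_len)),
        "C": (C, (b_dim, state_size, seq_len)),
    }
    if D is not None:
        expected["D"] = (D, (channels,))
    if initial_state is not None:
        expected["initial_state"] = (initial_state, (b_dim, channels, state_size))
    for name, (arr, shape) in expected.items():
        if tuple(arr.shape) != shape:
            raise ValueError(f"{name} has shape {tuple(arr.shape)}, expected {shape}")
    return b_dim, channels, state_size, seq_len

# Example: batch of 2, 256 channels, 16 states, 1024 time steps.
dims = check_selective_scan_shapes(
    x=np.zeros((2, 256, 1024)), dt=np.ones((2, 256, 1024)),
    A=-np.ones((256, 16)), B=np.zeros((2, 16, 1024)), C=np.zeros((2, 16, 1024)),
    D=np.ones(256),
)
```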
Notes:
Uses float32 accumulation internally for numerical stability
A should be negative for stable recurrence (decay < 1)
dt should be positive; discretization computes exp(dt * A)
The scan is sequential along the L dimension but parallel across channels
Accumulation across state dimensions uses SBUF per free tile to avoid HBM read-modify-write (which requires trn2 shared memory)
Carries between free tiles are stored in the final_state HBM tensor
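The carry between free tiles can be illustrated on the host: scanning the sequence tile by tile while passing the hidden state forward gives the same result as one pass over the full sequence. The sketch below uses NumPy for a single batch element; scan_tile and scan_tiled are hypothetical names, and the in-memory carry array stands in for the final_state HBM tensor.

```python
import numpy as np

def scan_tile(x, dt, A, B, C, state):
    """Run the recurrence over one tile of time steps, updating `state` in place.

    x, dt: [channels, T]; A: [channels, state_size]; B, C: [state_size, T];
    state: [channels, state_size], float32.
    """
    y = np.zeros(x.shape, dtype=np.float32)
    for t in range(x.shape[1]):
        step = dt[:, t, None]                       # [channels, 1]
        state *= np.exp(step * A)                   # decay the carried state
        state += step * x[:, t, None] * B[None, :, t]
        y[:, t] = state @ C[:, t]                   # sum over state dimension
    return y

def scan_tiled(x, dt, A, B, C, tile):
    """Split the sequence into tiles of `tile` steps, carrying state across tiles."""
    channels, seq_len = x.shape
    carry = np.zeros((channels, A.shape[1]), dtype=np.float32)
    outs = []
    for start in range(0, seq_len, tile):
        sl = slice(start, start + tile)             # last tile may be shorter
        outs.append(scan_tile(x[:, sl], dt[:, sl], A, B[:, sl], C[:, sl], carry))
    return np.concatenate(outs, axis=1), carry
```

After the last tile, the carry holds the hidden state at the final time step, which matches the role the final_state output plays in the kernel.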
Dimensions:
B_dim: Batch size
channels: Number of channels (partition dimension, tiled at P_MAX=128)
L: Sequence length (free dimension, tiled at F_TILE_SIZE=512)
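The tile sizes above determine how many tiles each dimension splits into. A minimal sketch of that arithmetic follows; tile_counts is a hypothetical helper, and rounding up for a partial last tile is an assumption, since this page does not state whether channels and L must be exact multiples of the tile sizes.

```python
P_MAX = 128        # partition-dimension tile size, from the Dimensions list
F_TILE_SIZE = 512  # free-dimension tile size, from the Dimensions list

def tile_counts(channels, seq_len):
    """Return (partition tiles, free tiles), rounding up for partial tiles."""
    partition_tiles = -(-channels // P_MAX)    # integer ceiling division
    free_tiles = -(-seq_len // F_TILE_SIZE)
    return partition_tiles, free_tiles

# 300 channels need 3 partition tiles; 2000 time steps need 4 free tiles.
counts = tile_counts(300, 2000)
```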