This document is relevant for: Trn2, Trn3

SSD Kernel API Reference

State Space Duality (SSD) scan for Mamba-2 models.

Performs a chunk-wise parallel computation that combines TensorE matmuls (intra-chunk structured attention) with VectorE cumulative sums (decay computation). For each chunk of size Q, where state is the hidden state carried in from the previous chunk (initial_state for the first chunk, zeros if None):

  1. Cumulative decay: cs = cumsum(dt * A)

  2. Intra-chunk: Y_intra = exp(cs) * ((CB * causal) @ (exp(-cs) * dt * x)), where CB = C @ B^T is the structured attention matrix

  3. State-to-output: Y_off = exp(cs) * (C @ state)

  4. State update: state = exp(cs[-1]) * state + B^T @ (dt * x * exp(cs[-1] - cs))

  5. Output: y = Y_intra + Y_off [+ D * x], where the D * x term applies only when D is provided
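The five steps can be checked against a plain token-by-token recurrence with a small NumPy sketch. It works on a single (batch, head) slice rather than the kernel's full [batch, nheads, ...] layout. The function names ssd_chunked_reference and ssd_sequential_reference are hypothetical, and the code is a host-side reference only, not the NKI kernel. It uses the factored exp(-cs) form exactly as written in step 2. That form can overflow when |cs| grows large within a chunk, so a more robust reference would build exp(cs_t - cs_s) directly under the causal mask.

```python
import numpy as np

def ssd_chunked_reference(x, dt, A, B, C, chunk_size, D=None, initial_state=None):
    """Chunk-wise SSD for ONE (batch, head) slice, following steps 1-5.

    Hypothetical single-head layout (not the kernel's):
      x: [seqlen, headdim], dt: [seqlen], A: scalar,
      B, C: [seqlen, dstate], D: scalar or None, initial_state: [dstate, headdim].
    """
    seqlen, headdim = x.shape
    dstate = B.shape[1]
    assert seqlen % chunk_size == 0, "seqlen must be divisible by chunk_size"
    state = (np.zeros((dstate, headdim)) if initial_state is None
             else initial_state.astype(np.float64))
    causal = np.tril(np.ones((chunk_size, chunk_size)))
    y = np.empty((seqlen, headdim))
    for start in range(0, seqlen, chunk_size):          # sequential over chunks
        sl = slice(start, start + chunk_size)
        xc, dtc, Bc, Cc = x[sl], dt[sl], B[sl], C[sl]
        cs = np.cumsum(dtc * A)                          # 1. cumulative decay, [Q]
        CB = Cc @ Bc.T                                   # structured attention, [Q, Q]
        y_intra = np.exp(cs)[:, None] * (
            (CB * causal) @ (np.exp(-cs)[:, None] * dtc[:, None] * xc))    # 2.
        y_off = np.exp(cs)[:, None] * (Cc @ state)       # 3. uses the incoming state
        state = np.exp(cs[-1]) * state + Bc.T @ (
            dtc[:, None] * xc * np.exp(cs[-1] - cs)[:, None])               # 4.
        y[sl] = y_intra + y_off                          # 5.
        if D is not None:
            y[sl] += D * xc                              # optional skip connection
    return y, state

def ssd_sequential_reference(x, dt, A, B, C, D=None, initial_state=None):
    """Token-by-token recurrence: h_t = exp(dt_t*A) h_{t-1} + dt_t B_t x_t^T."""
    seqlen, headdim = x.shape
    dstate = B.shape[1]
    h = (np.zeros((dstate, headdim)) if initial_state is None
         else initial_state.astype(np.float64))
    y = np.empty((seqlen, headdim))
    for t in range(seqlen):
        h = np.exp(dt[t] * A) * h + np.outer(B[t], dt[t] * x[t])
        y[t] = C[t] @ h
        if D is not None:
            y[t] += D * x[t]
    return y, h

# Small usage example: two chunks of size 4.
rng = np.random.default_rng(0)
L, P, N, Q = 8, 4, 3, 4
x = rng.standard_normal((L, P))
dt = rng.uniform(0.01, 0.1, L)        # positive timesteps
A = -0.5                              # negative decay rate
B = rng.standard_normal((L, N))
C = rng.standard_normal((L, N))
h0 = rng.standard_normal((N, P))
y_chunk, s_chunk = ssd_chunked_reference(x, dt, A, B, C, Q, D=0.3, initial_state=h0)
y_seq, s_seq = ssd_sequential_reference(x, dt, A, B, C, D=0.3, initial_state=h0)
```

Agreement between the two paths is the point of the sketch. The chunked form trades the length-seqlen sequential loop for seqlen / Q sequential steps, each dominated by [Q, Q] and [Q, dstate] matmuls.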

Background

The ssd kernel implements the State Space Duality (SSD) scan for Mamba-2 models, combining chunk-wise parallel TensorE matmuls for intra-chunk structured attention with VectorE cumulative sums for decay computation.
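The chunk formulas follow from unrolling the per-head recurrence inside one chunk. This sketch assumes the standard Mamba-2 recurrence and treats x_t, y_t (length headdim) and B_t, C_t (length dstate) as column vectors. The symbol h_in, introduced here, denotes the state carried into the chunk, and cs_t sums only over positions of the current chunk:

```latex
\begin{aligned}
h_t &= e^{\mathrm{dt}_t A}\, h_{t-1} + \mathrm{dt}_t\, B_t x_t^{\top},
  \qquad y_t^{\top} = C_t^{\top} h_t + D\, x_t^{\top} \\
\mathrm{cs}_t &= \textstyle\sum_{s \le t} \mathrm{dt}_s A \\
h_t &= e^{\mathrm{cs}_t} h_{\mathrm{in}}
  + \textstyle\sum_{s \le t} e^{\mathrm{cs}_t - \mathrm{cs}_s}\, \mathrm{dt}_s\, B_s x_s^{\top} \\
y_t^{\top} &= \underbrace{e^{\mathrm{cs}_t}\, C_t^{\top} h_{\mathrm{in}}}_{Y_{\mathrm{off}}}
  + \underbrace{e^{\mathrm{cs}_t} \textstyle\sum_{s \le t} (C_t^{\top} B_s)\, e^{-\mathrm{cs}_s}\, \mathrm{dt}_s\, x_s^{\top}}_{Y_{\mathrm{intra}}}
  + D\, x_t^{\top}
\end{aligned}
```

The restriction s <= t is the causal mask, and C_t^T B_s is entry (t, s) of CB = C @ B^T. Evaluating the h_t line at the last position of the chunk gives the state update of step 4.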

API Reference

Source code for this kernel API can be found at: ssd.py

ssd

nkilib.experimental.scan.ssd(x: nl.ndarray, dt: nl.ndarray, A: nl.ndarray, B: nl.ndarray, C: nl.ndarray, chunk_size: int = 128, D: nl.ndarray = None, initial_state: nl.ndarray = None, causal_mask: nl.ndarray = None) -> tuple

State Space Duality (SSD) scan for Mamba-2 models.

Parameters:
  • x (nl.ndarray) – [batch, nheads, seqlen, headdim], Input tensor.

  • dt (nl.ndarray) – [batch, nheads, seqlen], Timestep tensor. Should be positive.

  • A (nl.ndarray) – [nheads], State transition scalar per head. Should be negative.

  • B (nl.ndarray) – [batch, seqlen, dstate], Input projection. Shared across heads.

  • C (nl.ndarray) – [batch, seqlen, dstate], Output projection. Shared across heads.

  • chunk_size (int) – Chunk size Q. Must be <= 128 (compile-time constant).

  • D (nl.ndarray) – [nheads], Skip connection weights. Default: None (no skip connection).

  • initial_state (nl.ndarray) – [batch, nheads, dstate, headdim], Initial hidden state. Default: None (zeros).

  • causal_mask (nl.ndarray) – [Q, Q], Lower-triangular mask, where Q = chunk_size. Required despite the None default; pass np.tril(np.ones((Q, Q), dtype=np.float32)).
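A host-side NumPy sketch can build arrays that match the documented parameter shapes and sign requirements. The helper make_ssd_example_inputs is hypothetical. Softplus for dt and -exp(.) for A are assumed parameterizations, chosen only so that dt is positive and A is negative. How the arrays reach the device and are passed to ssd is not covered here.

```python
import numpy as np

def make_ssd_example_inputs(batch=1, nheads=2, seqlen=256, headdim=64,
                            dstate=16, chunk_size=128, seed=0):
    """Hypothetical helper: random float32 host arrays with the documented shapes."""
    rng = np.random.default_rng(seed)
    f32 = np.float32
    x = rng.standard_normal((batch, nheads, seqlen, headdim)).astype(f32)
    # dt must be positive: softplus of a random pre-activation (assumed choice).
    dt = np.log1p(np.exp(rng.standard_normal((batch, nheads, seqlen)))).astype(f32)
    # A must be negative: -exp(A_log), a common Mamba-2 parameterization (assumed).
    A = -np.exp(rng.standard_normal(nheads)).astype(f32)
    # B and C are shared across heads (ngroups=1), so they have no nheads axis.
    B = rng.standard_normal((batch, seqlen, dstate)).astype(f32)
    C = rng.standard_normal((batch, seqlen, dstate)).astype(f32)
    D = rng.standard_normal(nheads).astype(f32)
    initial_state = np.zeros((batch, nheads, dstate, headdim), dtype=f32)
    # causal_mask is required and built exactly as documented.
    causal_mask = np.tril(np.ones((chunk_size, chunk_size), dtype=np.float32))
    return dict(x=x, dt=dt, A=A, B=B, C=C, chunk_size=chunk_size, D=D,
                initial_state=initial_state, causal_mask=causal_mask)

inputs = make_ssd_example_inputs()
```

Returning a dict keyed by the documented parameter names keeps the mapping to the ssd signature explicit.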

Returns:

(y, final_state), where:

  • y (nl.ndarray) – [batch, nheads, seqlen, headdim], Output tensor with the same dtype as x.

  • final_state (nl.ndarray) – [batch, nheads, dstate, headdim], Final hidden state in float32.

Return type:

tuple of (nl.ndarray, nl.ndarray)
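Because final_state has the same layout as initial_state, a long sequence should be processable in pieces by feeding one call's final_state into the next call's initial_state. The sketch below illustrates this with a hypothetical sequential NumPy reference, ssd_reference, that uses the documented tensor layouts. It is not the kernel, and it assumes the kernel's chunked result matches this recurrence.

```python
import numpy as np

def ssd_reference(x, dt, A, B, C, D=None, initial_state=None):
    """Hypothetical sequential reference with the documented layouts.

    x: [b, h, L, p], dt: [b, h, L], A: [h], B/C: [b, L, n], D: [h] or None,
    initial_state: [b, h, n, p] or None. Returns (y, final_state in float32).
    """
    b, h, L, p = x.shape
    n = B.shape[-1]
    state = (np.zeros((b, h, n, p)) if initial_state is None
             else initial_state.astype(np.float64))
    y = np.empty((b, h, L, p))
    for t in range(L):
        decay = np.exp(dt[:, :, t] * A)                          # [b, h]
        # Outer product of B_t with dt_t * x_t for every batch and head: [b, h, n, p]
        update = np.einsum("bn,bhp->bhnp", B[:, t], dt[:, :, t, None] * x[:, :, t])
        state = decay[:, :, None, None] * state + update
        y[:, :, t] = np.einsum("bn,bhnp->bhp", C[:, t], state)   # C_t^T h_t
        if D is not None:
            y[:, :, t] += D[None, :, None] * x[:, :, t]
    return y.astype(x.dtype), state.astype(np.float32)

rng = np.random.default_rng(1)
b, h, L, p, n = 1, 2, 16, 4, 3
x = rng.standard_normal((b, h, L, p))
dt = rng.uniform(0.05, 0.2, (b, h, L))
A = -rng.uniform(0.5, 1.5, h)
B = rng.standard_normal((b, L, n))
C = rng.standard_normal((b, L, n))

# One call over the full sequence versus two calls chained through the state.
y_full, s_full = ssd_reference(x, dt, A, B, C)
half = L // 2
y1, s1 = ssd_reference(x[:, :, :half], dt[:, :, :half], A, B[:, :half], C[:, :half])
y2, s2 = ssd_reference(x[:, :, half:], dt[:, :, half:], A, B[:, half:], C[:, half:],
                       initial_state=s1)
```

Concatenating y1 and y2 along the sequence axis should reproduce y_full up to the float32 rounding of the handed-off state.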

Notes:

  • chunk_size <= 128 (must fit in partition dimension)

  • dstate <= 128 (for nc_transpose and matmul stationary free dim)

  • headdim <= 512 (PSUM free dimension limit on gen2/3)

  • seqlen must be divisible by chunk_size

  • ngroups=1 (B/C shared across all heads)

  • Uses float32 accumulation internally for numerical stability

  • A should be negative for stable dynamics, so that the per-step decay exp(dt * A) is < 1

  • dt should be positive; discretization computes exp(dt * A)

  • Inter-chunk state propagation is sequential; intra-chunk computation uses matmuls
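The shape constraints in these notes can be checked before launching the kernel. The sketch below uses a hypothetical helper, check_ssd_constraints, that mirrors only the documented limits and returns the number of chunks processed sequentially (seqlen / chunk_size). It does not check dtypes or the signs of dt and A.

```python
def check_ssd_constraints(x_shape, B_shape, chunk_size):
    """Hypothetical pre-flight check mirroring the documented constraints.

    x_shape: (batch, nheads, seqlen, headdim); B_shape: (batch, seqlen, dstate).
    Returns the number of chunks the inter-chunk scan walks through sequentially.
    """
    batch, nheads, seqlen, headdim = x_shape
    b_batch, b_seqlen, dstate = B_shape
    if (b_batch, b_seqlen) != (batch, seqlen):
        raise ValueError("B/C must be [batch, seqlen, dstate] (ngroups=1)")
    if chunk_size > 128:
        raise ValueError("chunk_size must be <= 128 (partition dimension)")
    if dstate > 128:
        raise ValueError("dstate must be <= 128 (nc_transpose / matmul)")
    if headdim > 512:
        raise ValueError("headdim must be <= 512 (PSUM free dimension)")
    if seqlen % chunk_size != 0:
        raise ValueError("seqlen must be divisible by chunk_size")
    return seqlen // chunk_size

n_chunks = check_ssd_constraints((2, 8, 1024, 64), (2, 1024, 128), 128)  # 1024 // 128
```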

Dimensions:

  • batch: Batch size

  • nheads: Number of SSM heads

  • seqlen: Sequence length (must be divisible by chunk_size)

  • headdim: Head dimension (<= 512 for gen2/3 PSUM free dim limit)

  • dstate: SSM state dimension (<= 128 for nc_transpose and matmul)
