This document is relevant for: Trn1, Trn2, Trn3
Dynamic Elementwise Add Kernel API Reference
Elementwise addition with dynamic partition dimension tiling.
Computes output = input_a + input_b for 2D bf16 tensors, where the number of M-dimension tiles to process is read at runtime from num_m_tiles. Optimized for M up to 2048 and H up to 8192.
Background
The dynamic_elementwise_add kernel performs elementwise addition over a number of M-dimension tiles that is determined at runtime. It demonstrates NKI’s support for dynamic loop bounds using sequential_range with runtime-variable trip counts.
API Reference
Source code for this kernel API can be found at: dynamic_elementwise_add.py
dynamic_elementwise_add
- nkilib.experimental.dynamic_shapes.dynamic_elementwise_add(input_a: nl.ndarray, input_b: nl.ndarray, num_m_tiles: nl.ndarray) -> nl.ndarray
Elementwise addition with dynamic partition dimension tiling.
- Parameters:
input_a (nl.ndarray) – [M, H], first input tensor, bf16, on HBM.
input_b (nl.ndarray) – [M, H], second input tensor, bf16, on HBM. Must match input_a shape.
num_m_tiles (nl.ndarray) – [1, 1], int32 scalar tensor on HBM. Value = number of M-tiles to process (0 <= num_m_tiles <= M // P_MAX).
- Returns:
[M, H], bf16 output tensor on HBM. Elements in the first (num_m_tiles * P_MAX) rows contain input_a + input_b; remaining rows are unmodified.
- Return type:
nl.ndarray
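The documented behavior can be modeled on the host with a short pure-Python sketch. This is not the kernel source and uses no NKI calls. The function name reference_dynamic_add, the use of nested lists of Python floats in place of bf16 HBM tensors, the plain int in place of the [1, 1] int32 tensor, and the caller-supplied out buffer are all assumptions made for illustration.

```python
P_MAX = 128  # partition tile size stated in the Notes


def reference_dynamic_add(input_a, input_b, num_m_tiles, out):
    """Hypothetical host-side model of the documented semantics.

    input_a, input_b, out: nested lists with M rows of H floats each.
    num_m_tiles: plain int standing in for the [1, 1] int32 tensor.
    """
    m = len(input_a)
    if not 0 <= num_m_tiles <= m // P_MAX:
        raise ValueError("num_m_tiles must satisfy 0 <= num_m_tiles <= M // P_MAX")
    # The trip count is only known at call time, mirroring the runtime loop bound.
    for tile in range(num_m_tiles):
        for row in range(tile * P_MAX, (tile + 1) * P_MAX):
            out[row] = [x + y for x, y in zip(input_a[row], input_b[row])]
    # Rows at index num_m_tiles * P_MAX and beyond are never written.
    return out


M, H = 256, 512
a = [[1.0] * H for _ in range(M)]
b = [[2.0] * H for _ in range(M)]
out = [[-1.0] * H for _ in range(M)]  # sentinel marks rows the model never writes
result = reference_dynamic_add(a, b, 1, out)
```

With num_m_tiles = 1, only the first 128 rows hold sums and the remaining rows keep the sentinel. Python floats do not reproduce bf16 rounding, so a model like this is suited to checking which rows are written rather than exact numerical results.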
Notes:
M must be divisible by P_MAX (128)
H must be divisible by H_TILE_SIZE (512)
input_a and input_b must have identical shapes
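The constraints above could be checked on the host before launching the kernel. The sketch below is one possible pre-flight check, not part of nkilib; the helper name check_shapes is hypothetical, and the constants are taken from the values stated in the Notes.

```python
P_MAX = 128        # M tile size stated in the Notes
H_TILE_SIZE = 512  # H tile size stated in the Notes


def check_shapes(shape_a, shape_b):
    """Hypothetical pre-flight check; returns (max M tiles, H tiles)."""
    if len(shape_a) != 2:
        raise ValueError("inputs must be 2D, got shape %r" % (shape_a,))
    if tuple(shape_a) != tuple(shape_b):
        raise ValueError("input_a and input_b must have identical shapes")
    m, h = shape_a
    if m % P_MAX != 0:
        raise ValueError("M=%d is not divisible by P_MAX=%d" % (m, P_MAX))
    if h % H_TILE_SIZE != 0:
        raise ValueError("H=%d is not divisible by H_TILE_SIZE=%d" % (h, H_TILE_SIZE))
    return m // P_MAX, h // H_TILE_SIZE


max_m_tiles, h_tiles = check_shapes((2048, 8192), (2048, 8192))
```

The first returned value is the upper bound on num_m_tiles for the given M, which matches the range 0 <= num_m_tiles <= M // P_MAX given for that parameter.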
Dimensions:
M: Row dimension, tiled at P_MAX (128). Dynamic at runtime via num_m_tiles.
H: Hidden/column dimension, tiled at H_TILE_SIZE (512). Static.
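To make the tiling concrete, the following sketch enumerates the row and column ranges that a tiled traversal would cover for a given num_m_tiles and H. The generator name iter_tiles and the loop order (M tiles outer, H tiles inner) are assumptions; the actual kernel may order or fuse its loops differently.

```python
P_MAX = 128        # M tile size (dynamic tile count)
H_TILE_SIZE = 512  # H tile size (static tile count)


def iter_tiles(num_m_tiles, h):
    """Yield (row_slice, col_slice) for each tile; the loop order is assumed."""
    for m_tile in range(num_m_tiles):          # runtime-determined trip count
        rows = slice(m_tile * P_MAX, (m_tile + 1) * P_MAX)
        for h_tile in range(h // H_TILE_SIZE):  # fixed by the tensor shape
            cols = slice(h_tile * H_TILE_SIZE, (h_tile + 1) * H_TILE_SIZE)
            yield rows, cols


tiles = list(iter_tiles(num_m_tiles=2, h=1024))
```

Only the outer trip count depends on the runtime value; the number of H tiles follows directly from the static shape.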