Depthwise Conv1D Kernel API Reference#
Implements depthwise 1D convolution using implicit GEMM without full im2col materialization.
The kernel supports:
- Depthwise 1D convolution with stride=1 and zero padding
- Implicit GEMM approach for memory efficiency
- LNC2 sharding on channel dimension
- Optimized for TRN2 platform
Background#
The depthwise_conv1d_implicit_gemm kernel performs depthwise 1D convolution by loading the input as tiles of shape [S_TILE, Q], where row k contains the elements starting at index k (i.e., input[k:k+Q]). This offset-based loading realizes im2col implicitly, so the full im2col matrix is never materialized, saving roughly W*S*C elements of memory. The kernel tiles along the S dimension when S > 128 and is optimized for the TRN2 platform with LNC2 sharding on the channel dimension.
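The offset-based loading pattern can be illustrated with a small NumPy sketch (the kernel itself is written in NKI; the sizes and variable names below are illustrative, not taken from the kernel):

```python
import numpy as np

# Illustrative sizes, not the kernel's tile limits
W, S = 16, 4
Q = W - S + 1  # output width for stride=1, no padding

rng = np.random.default_rng(0)
x = rng.standard_normal(W)   # one channel of input
f = rng.standard_normal(S)   # one channel's depthwise filter

# Implicit im2col: row k holds input[k:k+Q], so the [S, Q] matrix
# is just S offset views into the same buffer -- no large copy.
rows = np.stack([x[k:k + Q] for k in range(S)])  # shape [S, Q]

# The per-channel convolution is then a single GEMV: f @ rows
out = f @ rows

# Reference: direct sliding-window computation
ref = np.array([np.dot(x[q:q + S], f) for q in range(Q)])
assert np.allclose(out, ref)
```

Each output position q is the dot product of the filter with input[q:q+S], which the GEMV recovers column by column.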
API Reference#
Source code for this kernel API can be found at: depthwise_conv1d.py
depthwise_conv1d_implicit_gemm#
- nkilib.experimental.conv.depthwise_conv1d_implicit_gemm(img_ref: nl.ndarray, filter_ref: nl.ndarray, padding: tuple = ((0, 0), (0, 0)), stride: tuple = (1, 1), rhs_dilation: tuple = (1, 1), lhs_dilation: tuple = (1, 1), feature_group_count: int = 1, batch_group_count: int = 1, in_perm: tuple = None, kern_perm: tuple = None, out_perm: tuple = None) → nl.ndarray#
Depthwise Conv1D using implicit GEMM without full im2col materialization.
Performs depthwise 1D convolution by loading input with shape [S_TILE, Q] where row k contains elements starting at index k (i.e., input[k:k+Q]), enabling implicit im2col via offset-based loading. Tiles on S dimension for S > 128. Optimized for TRN2 platform with LNC2 sharding on channel dimension.
- Parameters:
  - img_ref (nl.ndarray) – Input tensor on HBM with shape [N, C, 1, W].
  - filter_ref (nl.ndarray) – Depthwise kernel weights on HBM with shape [C, 1, 1, S].
  - padding (tuple) – Padding as ((H_pad_l, H_pad_r), (W_pad_l, W_pad_r)). Default: ((0, 0), (0, 0)); only zero padding is supported.
  - stride (tuple) – Stride values. Default: (1, 1); only (1, 1) is supported.
  - rhs_dilation (tuple) – RHS dilation. Default: (1, 1).
  - lhs_dilation (tuple) – LHS dilation. Default: (1, 1).
  - feature_group_count (int) – Number of feature groups. Default: 1.
  - batch_group_count (int) – Number of batch groups. Default: 1.
  - in_perm (tuple, optional) – Input permutation. Default: None.
  - kern_perm (tuple, optional) – Kernel permutation. Default: None.
  - out_perm (tuple, optional) – Output permutation. Default: None.
- Returns:
Convolution output on HBM with shape [N, C, 1, Q] where Q = W - S + 1.
- Return type:
nl.ndarray
Notes:
- Only supports stride=1 and zero padding
- Requires C to be divisible by NUM_SHARDS (2)
- Uses LNC2 sharding on channel dimension
- For depthwise convolution, feature_group_count must equal C
Dimensions:
- N: Batch size
- C: Number of channels
- W: Input width (spatial dimension)
- S: Kernel size
- Q: Output width (W - S + 1)
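The constraints and shape relations above can be sketched as a small pre-flight check (a sketch only: the function name is hypothetical, while NUM_SHARDS = 2 and the individual checks come from the notes above):

```python
NUM_SHARDS = 2  # LNC2 sharding on the channel dimension (from the notes)

def check_depthwise_conv1d_args(N, C, W, S, stride=(1, 1),
                                padding=((0, 0), (0, 0)),
                                feature_group_count=1):
    """Hypothetical helper mirroring the kernel's documented limits."""
    if stride != (1, 1):
        raise ValueError("only stride=(1, 1) is supported")
    if padding != ((0, 0), (0, 0)):
        raise ValueError("only zero padding is supported")
    if C % NUM_SHARDS != 0:
        raise ValueError(f"C={C} must be divisible by NUM_SHARDS={NUM_SHARDS}")
    if feature_group_count != C:
        raise ValueError("depthwise conv requires feature_group_count == C")
    Q = W - S + 1  # output width
    if Q < 1:
        raise ValueError("kernel size S must not exceed input width W")
    return (N, C, 1, Q)  # output shape [N, C, 1, Q]

# Valid call: batch 4, 256 channels, width 1024, kernel size 7
out_shape = check_depthwise_conv1d_args(4, 256, 1024, 7,
                                        feature_group_count=256)
```

For these sizes the output shape is [4, 256, 1, 1018], since Q = 1024 - 7 + 1.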
Implementation Details#
The kernel implementation includes several key optimizations:
Implicit GEMM Approach: Avoids materializing full im2col matrix by using offset-based loading patterns, saving W*S*C memory.
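To make the savings concrete: a materialized im2col buffer for depthwise conv1d would hold S shifted copies of each channel's input, roughly W*S*C elements versus W*C for the input itself. A quick arithmetic sketch (illustrative sizes; fp32 elements assumed):

```python
# Illustrative sizes; fp32 (4 bytes per element) assumed
C, W, S = 256, 4096, 7
bytes_per_elem = 4

input_bytes = C * W * bytes_per_elem
# A full im2col buffer holds S shifted copies per channel:
# roughly W * S * C elements (Q = W - S + 1 is close to W for small S).
im2col_bytes = C * W * S * bytes_per_elem

# The implicit approach never allocates this buffer,
# so the blow-up factor it avoids is simply S.
ratio = im2col_bytes / input_bytes
```

Here the avoided buffer is S times the input size, which is why the savings are stated as W*S*C.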
Tiling Strategy:
- Input: [N, C, W] tiled as [N, C_TILES, C_TILE] x [S_TILES, S_TILE, Q]
- Filter: [C, S] tiled as [C_TILES, C_TILE] x [S_TILES, S_TILE]
- Output: [N, C, Q] accumulated in [Q_TILES, Q_TILE] chunks
Tile Size Selection:
- S_TILE = min(S, 128): Matches partition dimension (P_MAX=128)
- Q_TILE = min(Q, 512): Matches free dimension (F_MAX=512)
- C_TILE = min(C_per_shard, 128): Balances parallelism and memory
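The tile-size rules are simple clamps against the hardware limits; a sketch (P_MAX=128, F_MAX=512, and NUM_SHARDS=2 come from the text; the helper name is hypothetical):

```python
P_MAX = 128      # partition dimension limit (from the text)
F_MAX = 512      # free dimension limit (from the text)
NUM_SHARDS = 2   # LNC2 sharding on the channel dimension

def select_tile_sizes(C, W, S):
    """Hypothetical helper mirroring the documented tile-size selection."""
    Q = W - S + 1                  # output width
    C_per_shard = C // NUM_SHARDS  # channels handled by each shard
    s_tile = min(S, P_MAX)         # clamp to partition dimension
    q_tile = min(Q, F_MAX)         # clamp to free dimension
    c_tile = min(C_per_shard, P_MAX)
    return s_tile, q_tile, c_tile

sizes = select_tile_sizes(C=512, W=2048, S=7)
```

For C=512, W=2048, S=7 this yields S_TILE=7, Q_TILE=512 (Q=2042 exceeds F_MAX), and C_TILE=128 (256 channels per shard exceed P_MAX).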
Filter Preloading: Amortizes transpose cost across channels by preloading filter tiles in outer loop.
Sequential S-tile Accumulation: Enables pipelining and reduces PSUM pressure.
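Sequential accumulation can be mimicked in NumPy by summing partial GEMV products one S-tile at a time into a single output buffer, analogous to accumulating into PSUM (sizes illustrative):

```python
import numpy as np

S, Q, S_TILE = 256, 64, 128  # S > 128 forces multiple S tiles
W = Q + S - 1                # input width consistent with Q = W - S + 1

rng = np.random.default_rng(1)
x = rng.standard_normal(W)   # one channel of input
f = rng.standard_normal(S)   # one channel's filter

# Offset rows as in the implicit-im2col loading pattern
rows = np.stack([x[k:k + Q] for k in range(S)])  # [S, Q]

# Accumulate one S-tile at a time instead of one big [S, Q] GEMV,
# mirroring sequential accumulation into a single output tile.
acc = np.zeros(Q)
for s0 in range(0, S, S_TILE):
    acc += f[s0:s0 + S_TILE] @ rows[s0:s0 + S_TILE]

assert np.allclose(acc, f @ rows)
```

Splitting the reduction this way keeps only one [S_TILE, Q] operand live per step, which is what reduces PSUM pressure and lets loads overlap with accumulation.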
LNC2 Sharding: Distributes computation across channel dimension for parallel processing.