This document is relevant for: Trn2, Trn3
Fused Adam Kernel API Reference#
Adam optimizer kernel with L2 regularization.
This kernel implements the Adam optimizer with L2 weight regularization. For decoupled weight decay (AdamW), use adamw_kernel instead.
Background#
The adam_kernel and adamw_kernel implement fused Adam and AdamW optimizers on NeuronCore, processing parameter updates entirely on-device to avoid host-device round trips.
API Reference#
Source code for this kernel API can be found at: fused_adam.py
adam_kernel#
- nkilib.experimental.optimizer.adam_kernel(param_ptr, grad_ptr, exp_avg_ptr, exp_avg_sq_ptr, max_exp_avg_sq_ptr, step_size_ptr, inv_bc2_sqrt_ptr, wd_factor_ptr, numel, beta1, beta2, eps, amsgrad=False)#
Adam optimizer kernel with L2 regularization.
- Parameters:
param_ptr – [N], Parameter tensor on HBM
grad_ptr – [N], Gradient tensor on HBM
exp_avg_ptr – [N], First moment estimate tensor on HBM
exp_avg_sq_ptr – [N], Second moment estimate tensor on HBM
max_exp_avg_sq_ptr – [N], Max second moment tensor on HBM (AMSGrad only)
step_size_ptr – [P_MAX, 1], Bias-corrected step size on HBM
inv_bc2_sqrt_ptr – [P_MAX, 1], Inverse bias correction sqrt on HBM
wd_factor_ptr – [P_MAX, 1], Weight decay coefficient (lambda) on HBM
numel – Number of elements in the parameter tensor
beta1 – First moment decay factor (typically 0.9)
beta2 – Second moment decay factor (typically 0.999)
eps – Epsilon for numerical stability (typically 1e-8)
amsgrad – If True, use AMSGrad variant (default: False)
- Returns:
[N], Updated parameter tensor on HBM
- Return type:
nl.ndarray- Returns:
[N], Updated first moment tensor on HBM
- Return type:
nl.ndarray- Returns:
[N], Updated second moment tensor on HBM
- Return type:
nl.ndarray- Returns:
[N], Updated max second moment (if amsgrad=True)
- Return type:
nl.ndarray
Notes:
Uses L2 regularization: grad = grad + lambda * param
For decoupled weight decay (AdamW), use adamw_kernel instead
Dimensions:
N: Number of elements in the parameter tensor (numel)
adamw_kernel#
- nkilib.experimental.optimizer.adamw_kernel(param_ptr, grad_ptr, exp_avg_ptr, exp_avg_sq_ptr, max_exp_avg_sq_ptr, step_size_ptr, inv_bc2_sqrt_ptr, wd_factor_ptr, numel, beta1, beta2, eps, amsgrad=False)#
AdamW optimizer kernel with decoupled weight decay.
- Parameters:
param_ptr – [N], Parameter tensor on HBM
grad_ptr – [N], Gradient tensor on HBM
exp_avg_ptr – [N], First moment estimate tensor on HBM
exp_avg_sq_ptr – [N], Second moment estimate tensor on HBM
max_exp_avg_sq_ptr – [N], Max second moment tensor on HBM (AMSGrad only)
step_size_ptr – [P_MAX, 1], Bias-corrected step size on HBM
inv_bc2_sqrt_ptr – [P_MAX, 1], Inverse bias correction sqrt on HBM
wd_factor_ptr – [P_MAX, 1], Weight decay factor (1 - lr*lambda) on HBM
numel – Number of elements in the parameter tensor
beta1 – First moment decay factor (typically 0.9)
beta2 – Second moment decay factor (typically 0.999)
eps – Epsilon for numerical stability (typically 1e-8)
amsgrad – If True, use AMSGrad variant (default: False)
- Returns:
[N], Updated parameter tensor on HBM
- Return type:
nl.ndarray- Returns:
[N], Updated first moment tensor on HBM
- Return type:
nl.ndarray- Returns:
[N], Updated second moment tensor on HBM
- Return type:
nl.ndarray- Returns:
[N], Updated max second moment (if amsgrad=True)
- Return type:
nl.ndarray
Notes:
Uses decoupled weight decay: param = param * wd_factor - update
For L2 regularization (Adam), use adam_kernel instead
Dimensions:
N: Number of elements in the parameter tensor (numel)
This document is relevant for: Trn2, Trn3