This document is relevant for: Trn2, Trn3
FP8 Quantize Kernel API Reference
Tensor-wise static FP8 quantization.
The kernel multiplies the input by quant_scale (= 1 / dequant_scale) and clips the result to [-FP8_MAXVAL, FP8_MAXVAL]. The caller can pre-compute quant_scale with nisa.reciprocal to avoid redundant computation when the same dequant_scale is reused (for example, the gate and up projections share gate_up_in_scale).
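A minimal pure-Python sketch of the scale-and-clip arithmetic follows. It operates on plain lists rather than NKI tiles and is not the nkilib kernel. The helper names `static_quantize_ref` and `reciprocal_ref` are hypothetical. The final cast to an FP8 dtype is not modeled. The default bound of 448.0 assumes the OCP E4M3FN maximum, which may differ from the kernel's FP8_MAXVAL constant.

```python
from typing import List, Optional

# Assumed clip bound: 448.0 is the largest finite value of the OCP float8 E4M3FN
# format. The kernel's FP8_MAXVAL constant may use a different value.
DEFAULT_FP8_MAX = 448.0


def reciprocal_ref(dequant_scale: float) -> float:
    """Hypothetical stand-in for nisa.reciprocal: quant_scale = 1 / dequant_scale."""
    return 1.0 / dequant_scale


def static_quantize_ref(
    values: List[float],
    dequant_scale: float,
    quant_scale: Optional[float] = None,
    fp8_max: float = DEFAULT_FP8_MAX,
) -> List[float]:
    """Scale every element by one tensor-wise quant_scale, then clip to [-fp8_max, fp8_max].

    Rounding to the FP8 grid (the cast to the output dtype) is not modeled here.
    """
    if quant_scale is None:
        # Derive the quant scale only when the caller did not pre-compute it.
        quant_scale = reciprocal_ref(dequant_scale)
    return [min(max(v * quant_scale, -fp8_max), fp8_max) for v in values]


# Compute the reciprocal once and reuse it for two projections that share the
# same input dequant scale (mirroring the gate/up example above).
gate_up_in_scale = 0.5
shared_quant_scale = reciprocal_ref(gate_up_in_scale)
gate_q = static_quantize_ref([1.0, -3.0, 300.0], gate_up_in_scale, shared_quant_scale)
up_q = static_quantize_ref([0.25, 250.0, -500.0], gate_up_in_scale, shared_quant_scale)
print(gate_q)  # [2.0, -6.0, 448.0]
print(up_q)    # [0.5, 448.0, -448.0]
```

Passing the pre-computed `shared_quant_scale` skips the reciprocal inside each call, which is the saving the paragraph above describes. Values whose scaled magnitude exceeds the bound (300.0 × 2 = 600.0) saturate to ±448.0.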
Background
The static_quantization kernel performs tensor-wise static FP8 quantization. It multiplies every input value by a single quantization scale and clips the result to the FP8 representable range.
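In equation form, let $x$ be the input, $s_d$ the tensor-wise dequant scale, and $M$ the FP8_MAXVAL bound. The scale-and-clip step and the approximate reconstruction it implies can then be written as follows. This is a sketch: rounding to the FP8 grid happens in the cast to `dtype` and is not shown, and the reconstruction line is inferred from the meaning of a dequant scale rather than stated by this reference.

```latex
s_q = \frac{1}{s_d}, \qquad
q = \operatorname{clip}\!\left(x \cdot s_q,\; -M,\; M\right), \qquad
\hat{x} = q \cdot s_d \approx x \quad \text{when } |x| \le M \, s_d
```

Inputs with $|x| > M \, s_d$ saturate to $\pm M$ before dequantization.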
API Reference
The source code for this kernel API is in fp8_quantize.py.
static_quantization
- nkilib.core.quantization.static_quantization(hidden_state, input_dequant_scale, quant_scale=None, dtype=nl.float8_e4m3fn, sbm: Optional[BufferManager] = None, quantized=None)
Tensor-wise static FP8 quantization.
row_quantization
- nkilib.core.quantization.row_quantization(hidden_state, dtype=nl.float8_e4m3fn, sbm: Optional[BufferManager] = None, output_dtype=None, quantized=None, dequant_scale=None)
Row-wise dynamic FP8 quantization.
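Row-wise dynamic quantization derives a separate scale from each row's own values at run time instead of taking a precomputed one. The sketch below shows one common formulation: each row's dequant scale is its absolute maximum divided by the FP8 bound, and the row is multiplied by the reciprocal. This formulation is an assumption and is not taken from the kernel source. The helper name `row_quantize_ref`, the all-zero-row guard, and the 448.0 bound (OCP E4M3FN) are also assumptions. The real kernel's layout, epsilon handling, and FP8_MAXVAL may differ.

```python
from typing import List, Tuple

FP8_MAX = 448.0  # assumed bound (OCP E4M3FN max); the kernel's FP8_MAXVAL may differ


def row_quantize_ref(
    rows: List[List[float]], fp8_max: float = FP8_MAX
) -> Tuple[List[List[float]], List[float]]:
    """Return (scaled rows, per-row dequant scales). Rows must be non-empty.

    The cast to an FP8 dtype is not modeled; only the per-row scale and clip are.
    """
    quantized, dequant_scales = [], []
    for row in rows:
        amax = max(abs(v) for v in row)
        # Assumed guard for all-zero rows: use scale 1.0 so the row stays zero.
        dequant = amax / fp8_max if amax > 0.0 else 1.0
        quant = 1.0 / dequant
        quantized.append([min(max(v * quant, -fp8_max), fp8_max) for v in row])
        dequant_scales.append(dequant)
    return quantized, dequant_scales


q, scales = row_quantize_ref([[224.0, -112.0, 56.0], [0.0, 0.0, 0.0], [-896.0, 448.0, 0.0]])
print(scales)  # [0.5, 1.0, 2.0]
print(q)       # [[448.0, -224.0, 112.0], [0.0, 0.0, 0.0], [-448.0, 224.0, 0.0]]

# Multiplying each quantized row by its dequant scale recovers the input.
restored = [[v * s for v in row] for row, s in zip(q, scales)]
print(restored[0])  # [224.0, -112.0, 56.0]
```

Because each row's largest magnitude maps exactly to the FP8 bound, rows with very different ranges (here 224 versus 896) both use the full representable range. A single tensor-wise scale cannot do that.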
pre_combine_dequant_scales
- nkilib.core.quantization.pre_combine_dequant_scales(input_dequant_scale, weight_dequant_scale, sbm: Optional[BufferManager] = None)
Pre-combines the input and weight dequant scales into a single scale: combined = w_dequant * in_dequant.
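Pre-combining works because both scales are tensor-wise scalars. In a matmul of quantized operands, the dequant factors factor out of every dot product, so $(s_x q_x)(s_w q_w) = s_x s_w (q_x q_w)$, and one multiply by the combined scale dequantizes the product. The sketch assumes the combined scale is applied to the matmul output, which this reference does not state explicitly. The pure-Python check below confirms the identity on small matrices. The `matmul` and `scale` helpers and all values are illustrative only.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain triple-loop matrix multiply (illustrative helper)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def scale(m: Matrix, s: float) -> Matrix:
    """Multiply every element of m by the scalar s."""
    return [[v * s for v in row] for row in m]


# Quantized operands and their tensor-wise dequant scales (illustrative values).
q_x: Matrix = [[1.0, 2.0], [3.0, 4.0]]
q_w: Matrix = [[5.0, 6.0], [7.0, 8.0]]
in_dequant, w_dequant = 0.5, 0.25

# Path 1: dequantize each operand separately, then multiply.
reference = matmul(scale(q_x, in_dequant), scale(q_w, w_dequant))

# Path 2: multiply the quantized operands, then apply one pre-combined scale.
combined = w_dequant * in_dequant
fused = scale(matmul(q_x, q_w), combined)

print(combined)  # 0.125
print(fused)     # [[2.375, 2.75], [5.375, 6.25]]
```

Path 2 replaces two full-tensor scale passes with a single pass over the smaller output. The identity holds only because each scale is one scalar per tensor. With per-row input scales it still factors out, but as a per-row vector rather than one number.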