This document is relevant for: Trn2, Trn3

FP8 Quantize Kernel API Reference

Tensor-wise static FP8 quantization.

Multiplies the input by quant_scale (= 1 / dequant_scale) and clips the result to [-FP8_MAXVAL, FP8_MAXVAL]. The caller can precompute quant_scale with nisa.reciprocal to avoid recomputing it when the same dequant_scale is reused (for example, the gate and up projections share gate_up_in_scale).
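The arithmetic can be illustrated with a minimal pure-Python sketch. It is not the NKI kernel. The helpers `static_quantize_ref` and `reciprocal` are hypothetical host-side stand-ins, and `reciprocal` only plays the role of `nisa.reciprocal`. The sketch assumes FP8_MAXVAL = 448.0, the largest finite float8_e4m3fn value, and it does not model rounding onto the FP8 grid. It computes quant_scale once and reuses it for two inputs that share a dequant scale.

```python
# Hypothetical pure-Python reference for tensor-wise static FP8 quantization.
# Assumption: FP8_MAXVAL = 448.0 (largest finite float8_e4m3fn value).
# Rounding to the FP8 grid is NOT modeled; only scale-and-clip is shown.

FP8_MAXVAL = 448.0  # assumed, not read from the kernel source


def reciprocal(dequant_scale: float) -> float:
    """Hypothetical host-side stand-in for nisa.reciprocal."""
    return 1.0 / dequant_scale


def static_quantize_ref(values, quant_scale):
    """Multiply every element by one tensor-wide scale, then clip to the FP8 range."""
    return [min(max(v * quant_scale, -FP8_MAXVAL), FP8_MAXVAL) for v in values]


# Gate and up projections share one input dequant scale, so the reciprocal
# is computed once and reused for both quantizations.
gate_up_in_scale = 0.5
quant_scale = reciprocal(gate_up_in_scale)  # 2.0

gate_q = static_quantize_ref([1.0, -3.0, 300.0], quant_scale)
up_q = static_quantize_ref([0.25, -500.0], quant_scale)
print(gate_q)  # [2.0, -6.0, 448.0]
print(up_q)    # [0.5, -448.0]
```

Values whose scaled magnitude exceeds FP8_MAXVAL (300.0 and -500.0 here) saturate at the range boundary rather than overflowing.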

Background

The static_quantization kernel performs tensor-wise static FP8 quantization. It multiplies every input value by a single, precomputed quantization scale and clips the result to the FP8 representable range.

API Reference

Source code for this kernel API can be found at: fp8_quantize.py

static_quantization

nkilib.core.quantization.static_quantization(hidden_state, input_dequant_scale, quant_scale=None, dtype=nl.float8_e4m3fn, sbm: Optional[BufferManager] = None, quantized=None)

Tensor-wise static FP8 quantization.

row_quantization

nkilib.core.quantization.row_quantization(hidden_state, dtype=nl.float8_e4m3fn, sbm: Optional[BufferManager] = None, output_dtype=None, quantized=None, dequant_scale=None)

Row-wise dynamic FP8 quantization.
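In dynamic quantization the scales are computed from the data at runtime instead of being supplied in advance. The following minimal pure-Python sketch shows one common row-wise scheme, absmax scaling. It assumes that each row gets its own dequant scale, equal to the row's absolute maximum divided by FP8_MAXVAL. That choice, FP8_MAXVAL = 448.0, and the helper names `row_quantize_ref` and `row_dequantize_ref` are assumptions for illustration and were not taken from the kernel. FP8 rounding is not modeled.

```python
# Hypothetical pure-Python sketch of row-wise dynamic FP8 quantization.
# Assumptions: per-row absmax scaling; FP8_MAXVAL = 448.0 (float8_e4m3fn);
# rounding to the FP8 grid is not modeled.

FP8_MAXVAL = 448.0


def row_quantize_ref(rows):
    """Return (quantized_rows, dequant_scales) with one dequant scale per row."""
    quantized, dequant_scales = [], []
    for row in rows:
        absmax = max(abs(v) for v in row)
        # Map the row's absmax onto FP8_MAXVAL; use 1.0 for all-zero rows.
        dequant = absmax / FP8_MAXVAL if absmax > 0.0 else 1.0
        quant = 1.0 / dequant
        quantized.append([min(max(v * quant, -FP8_MAXVAL), FP8_MAXVAL) for v in row])
        dequant_scales.append(dequant)
    return quantized, dequant_scales


def row_dequantize_ref(quantized, dequant_scales):
    """Approximately recover the input: x ~= q * dequant_scale, row by row."""
    return [[q * s for q in row] for row, s in zip(quantized, dequant_scales)]


x = [[1.0, -2.0, 4.0],
     [0.0, 0.0, 0.0],
     [896.0, -448.0, 224.0]]
q, scales = row_quantize_ref(x)
x_hat = row_dequantize_ref(q, scales)
print(q[2], scales[2])  # [448.0, -224.0, 112.0] 2.0
```

Because each row is scaled by its own maximum, a row of small values keeps its precision even when another row contains large outliers. A single tensor-wide scale cannot offer that.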

pre_combine_dequant_scales

nkilib.core.quantization.pre_combine_dequant_scales(input_dequant_scale, weight_dequant_scale, sbm: Optional[BufferManager] = None)

Precombines the input and weight dequantization scales into a single scale: combined = w_dequant * in_dequant.
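A single combined scale is valid because scaling is linear. Dequantizing both operands and then multiplying gives the same result as multiplying the quantized values and applying w_dequant * in_dequant once to the accumulated result. The following pure-Python sketch checks this for a small matrix-vector product. It assumes both scales are tensor-wise scalars, and the helpers `pre_combine_ref` and `matvec` are hypothetical. The scales are powers of two, so both paths are exact in floating point.

```python
# Hypothetical sketch: why input and weight dequant scales can be pre-combined.
# Assumption: both dequant scales are tensor-wise scalars.
import math


def pre_combine_ref(input_dequant_scale, weight_dequant_scale):
    """combined = w_dequant * in_dequant, as stated in the API description."""
    return weight_dequant_scale * input_dequant_scale


def matvec(matrix, vec):
    """Plain matrix-vector product; rows of `matrix` are output features."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


x_q = [2.0, -6.0, 448.0]            # quantized activations
w_q = [[1.0, 0.5, -2.0],
       [4.0, 0.0, 1.0]]             # quantized weights
in_dequant, w_dequant = 0.5, 0.25

# Path A: dequantize both operands, then multiply.
x = [v * in_dequant for v in x_q]
w = [[v * w_dequant for v in row] for row in w_q]
y_a = matvec(w, x)

# Path B: multiply quantized values, then apply one combined scale per output.
combined = pre_combine_ref(in_dequant, w_dequant)
y_b = [acc * combined for acc in matvec(w_q, x_q)]

assert all(math.isclose(a, b) for a, b in zip(y_a, y_b))
print(y_b)  # [-112.125, 57.0]
```

Path B needs one multiply per output element instead of dequantizing every input and weight element, which is the saving that precombining the scales provides.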
