This document is relevant for: Trn2, Trn3

RNG Kernel API Reference

Retrieve the current RNG state from the GPSIMD engine.

Reads all 128 lanes of RNG state from the GPSIMD engine into SBUF, then copies only lane 0’s seeds to a new output HBM tensor. The input shape is fixed at [1, NUM_RNG_SEEDS].

Background

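The get_rng_state_gpsimd kernel retrieves the current random number generator state from the GPSIMD engine by reading all 128 lanes of RNG state into SBUF and copying lane 0’s seeds to an output HBM tensor. This page also documents set_rng_state_gpsimd, which broadcasts a set of seeds to all lanes, and generate_random, which fills an HBM tensor with random int32 values.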

API Reference

Source code for this kernel API can be found at: rng.py

get_rng_state_gpsimd

nkilib.experimental.rng.get_rng_state_gpsimd(tensor_state: nl.ndarray)

Retrieve the current RNG state from the GPSIMD engine.

Parameters:

tensor_state (nl.ndarray) – [1, NUM_RNG_SEEDS], dtype uint32, HBM tensor used only for shape/dtype reference.

Returns:

[1, NUM_RNG_SEEDS], dtype uint32, HBM tensor containing the 6 RNG seeds from lane 0.

Return type:

nl.ndarray

Dimensions:

  • L: Number of GPSIMD lanes (128)
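The read-all-lanes, keep-lane-0 behavior described above can be illustrated with a small pure-Python model. This is only a sketch of the data movement, not NKI code: `model_get_rng_state` and the sample `engine_state` are hypothetical names invented for illustration, and plain nested lists stand in for the SBUF and HBM tensors.

```python
# Pure-Python model of the data movement only; this is not NKI code.
NUM_LANES = 128      # L in the Dimensions list
NUM_RNG_SEEDS = 6    # seed count stated in the Returns description


def model_get_rng_state(engine_state):
    """Return a [1, NUM_RNG_SEEDS] copy of lane 0 from a [128, 6] state."""
    assert len(engine_state) == NUM_LANES
    # Step 1: stage every lane (stands in for the GPSIMD -> SBUF read).
    staged = [list(lane) for lane in engine_state]
    # Step 2: keep only lane 0 (stands in for the SBUF -> HBM copy).
    return [staged[0][:NUM_RNG_SEEDS]]


# Hypothetical engine state: lane i holds seeds i*10 .. i*10 + 5.
engine_state = [
    [lane * 10 + s for s in range(NUM_RNG_SEEDS)] for lane in range(NUM_LANES)
]
state = model_get_rng_state(engine_state)
```

In this model the seeds of lanes 1 through 127 are read but never reach the output, which mirrors the description that only lane 0’s seeds are copied.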

set_rng_state_gpsimd

nkilib.experimental.rng.set_rng_state_gpsimd(tensor_state: nl.ndarray)

Set the RNG state for the GPSIMD engine by broadcasting seeds to all lanes.

Parameters:

tensor_state (nl.ndarray) – [1, NUM_RNG_SEEDS], dtype uint32, HBM tensor containing the 6 seeds to broadcast.

Returns:

[1, NUM_RNG_SEEDS], dtype uint32, HBM tensor echoing back the seeds that were set.

Return type:

nl.ndarray

Dimensions:

  • L: Number of GPSIMD lanes (128)
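The broadcast described above can be sketched in the same way. Again this is a pure-Python model under stated assumptions rather than the kernel itself: `model_set_rng_state` is a hypothetical helper, and the seed values are arbitrary.

```python
# Pure-Python model of the broadcast only; this is not NKI code.
NUM_LANES = 128
NUM_RNG_SEEDS = 6


def model_set_rng_state(tensor_state):
    """Broadcast a [1, 6] seed row to all lanes and echo the seeds back."""
    assert len(tensor_state) == 1 and len(tensor_state[0]) == NUM_RNG_SEEDS
    seeds = tensor_state[0]
    # Every lane receives its own copy of the same six seeds.
    engine_state = [list(seeds) for _ in range(NUM_LANES)]
    # The return value echoes the seeds that were set.
    echoed = [list(seeds)]
    return engine_state, echoed


seeds = [[11, 22, 33, 44, 55, 66]]  # arbitrary example values
engine_state, echoed = model_set_rng_state(seeds)
```

Because every lane ends up with the same seeds, reading lane 0 back in this model returns exactly the row that was set, which is what makes a save-and-restore round trip through get_rng_state_gpsimd and set_rng_state_gpsimd plausible.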

generate_random

nkilib.experimental.rng.generate_random(output: nl.ndarray, n_elements: int)

Generate random int32 values, tiling to fit SBUF.

Parameters:

  • output (nl.ndarray) – [1, n_elements], dtype int32, HBM tensor to be filled with random values.

  • n_elements (int) – Number of random int32 values to generate.

Returns:

[1, n_elements], dtype int32, HBM tensor filled with random values.

Return type:

nl.ndarray

Notes:

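  • Uses sequential_range (not affine_range) because the RNG state is carried from one loop iteration to the next, so the tiles must be generated in order.

  • When n_elements is not a multiple of the tile size, the remainder tile is handled separately after the full tiles.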

Dimensions:

  • N: Number of random elements to generate (n_elements)
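The tiling pattern in the notes above (full tiles in order, then a remainder tile, with RNG state carried between them) can be sketched in pure Python. Everything here is an assumption made for illustration: the tile size `TILE_ELEMS` is hypothetical, `model_generate_random` is not the kernel, and a toy xorshift32 step stands in for the hardware RNG, whose algorithm this page does not describe.

```python
# Pure-Python model of the tiling pattern only; this is not NKI code.
TILE_ELEMS = 8  # hypothetical tile capacity; the real SBUF tile size is not given


def xorshift32(state):
    """Toy 32-bit xorshift step, standing in for the unspecified hardware RNG."""
    state ^= (state << 13) & 0xFFFFFFFF
    state ^= state >> 17
    state ^= (state << 5) & 0xFFFFFFFF
    return state


def to_int32(u):
    """Reinterpret an unsigned 32-bit value as a signed int32."""
    return u - (1 << 32) if u >= (1 << 31) else u


def fill_tile(state, count):
    """Generate `count` values and return them with the updated state."""
    values = []
    for _ in range(count):
        state = xorshift32(state)
        values.append(to_int32(state))
    return values, state


def model_generate_random(n_elements, seed=1, tile_elems=TILE_ELEMS):
    """Fill a [1, n_elements] row tile by tile, carrying state across tiles."""
    output = []
    state = seed
    full_tiles, remainder = divmod(n_elements, tile_elems)
    # Sequential loop: each tile starts from the state the previous one left.
    for _ in range(full_tiles):
        tile, state = fill_tile(state, tile_elems)
        output.extend(tile)
    # Remainder tile, handled separately after the full tiles.
    if remainder:
        tile, state = fill_tile(state, remainder)
        output.extend(tile)
    return [output]


tiled = model_generate_random(21)                  # 2 full tiles + 5 leftover
untiled = model_generate_random(21, tile_elems=21) # one tile, no remainder
```

In this model the tiled and untiled runs produce the same sequence only because the state is threaded through the loop in order; reordering the iterations would change which state each tile starts from, which is the loop-carried dependency the note refers to.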
