This document is relevant for: Inf2, Trn1, Trn1n

RMSNorm#

In this tutorial, we implement a kernel to perform RMSNorm of a 2D tensor, as described in Root Mean Square Layer Normalization. In doing so, we learn about:

  • The NKI syntax and programming model

  • Broadcasting tensors along different axes

  • Mapping embarrassingly parallel vector operations efficiently to the NeuronCore

  • Disabling ineffectual data movement or compute within a tile using an execution mask

Before diving into RMSNorm of a 2D input, let’s go over the RMSNorm operator for a 1D vector a, defined below:

\[\bar{a_i} = \frac{a_i}{\text{RMS}(a)}g_i,\text{ where RMS}(a) = \sqrt{\frac{1}{n}\sum_{i=1}^{n}{a_i^2}}\]

Note that g is the RMSNorm weight, which has the same shape as the input vector a. The function RMS(a) produces a single scalar, and we divide every element of the input vector a by the RMS(a) scalar (i.e., a broadcast divide).
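
As a quick point of reference, here is a minimal NumPy sketch of the 1D definition above (NumPy is used only for illustration; it is not part of the NKI kernel):

import numpy as np

def rmsnorm_1d(a, g):
    # RMS(a) = sqrt(mean(a^2)) -- a single scalar for the whole vector
    rms = np.sqrt(np.mean(a * a))
    # Broadcast-divide every element by the scalar RMS, then scale by the weight g
    return (a / rms) * g

a = np.random.rand(512).astype(np.float32)
g = np.random.rand(512).astype(np.float32)
print(rmsnorm_1d(a, g).shape)  # (512,)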

In Transformer models, we typically perform RMSNorm on a 2D input tensor instead (with shape [sequence length, embedding size]). 2D-RMSNorm simply performs the 1D-RMSNorm discussed above on every row of the 2D input tensor. The g RMSNorm weight vector is shared (i.e., broadcast) across the rows for the multiplication. The figure below visualizes the tensor shapes involved in 2D-RMSNorm, where a_tensor is the 2D input tensor and g_tensor is the 1D RMSNorm weight:


Fig. 78 RMSNorm tensor shapes#

We are going to map the rows (a_tensor.shape[0]) to the partition dimension of the SBUF once we load the tensor from HBM. This is a natural layout choice since each SBUF partition has a one-to-one mapping to a parallel vector lane in the compute engines for calculating RMS(a_tensor).

Note that the division by RMS(a_tensor) requires broadcasting one scalar across all elements of a_tensor within each partition, which is a free-axis broadcast supported directly by the flexible memory access patterns of the hardware. On the other hand, the multiplication with g_tensor requires broadcasting a vector across all partitions, which is a partition-axis broadcast and requires a separate instruction (the broadcast_to() API; see the implementation below for details).
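
To spell out what the kernel below computes, here is a short NumPy reference of the 2D computation and the two kinds of broadcast just discussed (per-row scalar vs. shared weight vector); it describes the math only, not the hardware mapping:

import numpy as np

def rmsnorm_2d_reference(a_tensor, g_tensor):
    # a_tensor: [sequence length, embedding size]; g_tensor: [embedding size]
    # Reduce along the embedding (free) axis; keep one scalar per row
    rms = np.sqrt(np.mean(a_tensor * a_tensor, axis=1, keepdims=True))
    # Free-axis broadcast: each row is divided by its own scalar RMS
    normalized = a_tensor / rms
    # Partition-axis broadcast: the same weight vector scales every row
    return normalized * g_tensor[None, :]

a = np.random.rand(250, 512).astype(np.float32)
g = np.random.rand(512).astype(np.float32)
print(rmsnorm_2d_reference(a, g).shape)  # (250, 512)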

Compute kernel#

import math
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


def nki_rmsnorm_kernel(a_tensor, g_tensor, out_tensor):
  # Calculate out_tensor = a_tensor/RMS(a_tensor) * g_tensor
  # Where RMS(a_tensor) = sqrt((1/N) * sum(a_tensor * a_tensor))
  # and N = a_tensor.shape[1]
  # Reduction (mean) is performed in the free (2nd) dimension

  # Make sure shapes match
  assert a_tensor.shape[1] == g_tensor.shape[0]
  assert a_tensor.shape == out_tensor.shape

  # Generate tensor indices to index input tensor
  ix = nl.arange(128)[:, None]
  iw = nl.arange(1)[:, None]
  iy = nl.arange(a_tensor.shape[1])[None, :]

  num_rows = a_tensor.shape[0]

  # Load RMSNorm weight once, reused by rows/tiles of a_tensor
  g_tile = nl.load(g_tensor.reshape((1, g_tensor.shape[0]))[iw, iy])

  # Process 128 rows at a time due to 128-partition tile size limitation.
  # Since we're not reducing across the first dimension,
  # tiles can be processed independently.
  for i in nl.affine_range(math.ceil(a_tensor.shape[0]/128)):

    # Load input data from external memory to on-chip memory
    a_tile = nl.load(a_tensor[i * 128 + ix, iy],
                     mask=(i * 128 + ix < num_rows))

    # Compute element-wise square of a_tensor
    in_square = nl.square(a_tile)

    # Calculate sum of squared elements, along last dimension
    square_sum = nl.sum(in_square, axis=[1])

    # Divide by the row length to get the mean of the squares
    mean = square_sum / a_tensor.shape[1]

    # Take square root of mean and then reciprocal with
    # rsqrt API (one ISA instruction)
    rms_reciprocal = nl.rsqrt(mean)

    # Scale the input tensor: free-axis broadcast of the per-row scalar
    out_tile = nl.multiply(a_tile, rms_reciprocal)

    # Broadcast weight along first (partition) axis to match the tile shape;
    # rows beyond num_rows are masked off in the multiply below
    g_bcast = g_tile.broadcast_to((128, g_tensor.shape[0]))

    # Multiply with the RMSNorm weight
    out_tile[...] = nl.multiply(out_tile, g_bcast,
                                mask=(i * 128 + ix < num_rows))

    # Store the normalized results back to external memory (out_tensor)
    nl.store(out_tensor[i * 128 + ix, iy], value=out_tile,
             mask=(i * 128 + ix < num_rows))

In this example, we implement RMSNorm for a 2D input tensor in nki_rmsnorm_kernel:

  • We assume each SBUF partition is large enough to fit at least one row of a_tensor and one copy of g_tensor simultaneously.

  • We load g_tensor into SBUF once, outside the main loop that iterates over tiles of a_tensor, to maximize reuse. g_tensor is reshaped into a 2D tensor because SBUF is a two-dimensional memory and hence expects at least two dimensions for any SBUF tensor. Reshaping an HBM tensor without changing the underlying storage format is in fact a no-op with no performance cost in the final compiled executable.

  • To adhere to NKI’s tile-size considerations (Tile Size Considerations), we limit the partition axis size of each a_tensor tile to 128.

  • The trip count of the compute loop is math.ceil(a_tensor.shape[0]/128). In cases where a_tensor.shape[0] is not a multiple of 128, we disable ineffectual data movement or compute in the last iteration using the mask field (discussed below; see also the short arithmetic sketch after this list).

  • Within the compute loop:

    • We load one tile of a_tensor with shape (128, a_tensor.shape[1]) using the nl.load API. We guard the loading boundary by specifying mask=(i * 128 + ix < num_rows), which ensures we don’t access out-of-bounds memory when the number of rows in a_tensor is not a multiple of 128.

    • We perform the free-axis broadcast multiply (the division by RMS(a)) using nl.multiply(a_tile, rms_reciprocal), which is lowered into the nki.isa.tensor_scalar instruction under the hood.

    • To broadcast multiply with the RMSNorm weight g_tensor, we need to perform a partition-axis broadcast of g_tensor. The number of active rows being normalized in the current loop iteration is min(num_rows - i * 128, 128); here we broadcast g_tensor to the full 128 partitions with broadcast_to() and rely on the mask field of the multiply to skip the inactive rows. Next, we do an element-wise multiplication of the broadcasted g_tensor and the intermediate normalized tile out_tile, which is lowered into the nki.isa.tensor_tensor instruction under the hood.

    • Finally, we store the normalized tile back into HBM using the nl.store API. We guard the store boundary with the mask field, just as we did for the load.
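
To make the trip count and masking concrete, here is a small worked example in plain Python (outside the kernel), using the 250-row input from the test scripts below; the row count is just the test shape, not a requirement:

import math

num_rows = 250                          # a_tensor.shape[0] in the test scripts below
trip_count = math.ceil(num_rows / 128)  # ceil(250/128) = 2 loop iterations

for i in range(trip_count):
    first_row = i * 128
    # Rows where (i * 128 + ix) >= num_rows fail the mask and are skipped
    active_rows = min(num_rows - first_row, 128)
    print(f"iteration {i}: rows {first_row}..{first_row + active_rows - 1} active, "
          f"{128 - active_rows} rows masked off")
# iteration 0: rows 0..127 active, 0 rows masked off
# iteration 1: rows 128..249 active, 6 rows masked off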

Launching kernel and testing correctness#

PyTorch#

Below we write a reference PyTorch implementation of RMSNorm and verify our NKI kernel output against the reference in the same script as the kernel.

# Imports needed to run the reference and the NKI kernel from PyTorch
import torch
from torch_xla.core import xla_model as xm
from torch_neuronx import nki_jit  # JIT wrapper for calling NKI kernels from PyTorch

# Reference torch implementation
def torch_rmsnorm_kernel(a_tensor, g_tensor):
  # Square the tensor (element-wise)
  in_square = a_tensor.pow(2)
  # Calculate means in the free dimension
  mean = in_square.mean(dim=1, keepdim=True)
  # Scale by reciprocal of sqrt(mean)
  tensor = a_tensor * torch.rsqrt(mean)

  # Scale the output by the weight
  return tensor * g_tensor

device = xm.xla_device()

nki_rmsnorm_kernel = nki_jit(nki_rmsnorm_kernel)

a_tensor = torch.rand((250, 512), dtype=torch.float32).to(device=device)
g_tensor = torch.rand((512), dtype=torch.float32).to(device=device)
output_nki = torch.zeros((250, 512), dtype=torch.float32).to(device=device)

nki_rmsnorm_kernel(a_tensor, g_tensor, output_nki)
print(f"output_nki={output_nki}")

output_torch = torch_rmsnorm_kernel(a_tensor, g_tensor)
print(f"output_torch={output_torch}")

if torch.allclose(output_torch, output_nki, atol=1e-5, rtol=1e-3):
  print("NKI and Torch match")
else:
  print("NKI and Torch differ")

Output:

2024-07-27 15:22:50.000670:  7592  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-07-27 15:22:50.000672:  7592  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/54c8e689-108c-433e-832a-f9282acdf114/model.MODULE_7170924315921358669+d41d8cd9.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir/54c8e689-108c-433e-832a-f9282acdf114/model.MODULE_7170924315921358669+d41d8cd9.neff --verbose=35
DGE ON Levels: {'scalar_dynamic_offset', 'io'}
.
Compiler status PASS
output_nki=tensor([[0.8418, 1.3092, 0.7372,  ..., 0.1458, 0.8831, 0.2339],
        [0.1745, 0.3416, 0.1519,  ..., 0.3358, 0.1832, 0.4795],
        [0.0111, 1.1799, 0.8628,  ..., 0.3107, 0.8328, 0.5663],
        ...,
        [1.1213, 0.5449, 0.3020,  ..., 0.4050, 0.4838, 0.0834],
        [0.8246, 0.5027, 0.2745,  ..., 0.4069, 1.0456, 1.0978],
        [0.6415, 0.3637, 0.1462,  ..., 0.2441, 1.0535, 0.4138]],
       device='xla:0')
2024-07-27 15:22:51.000907:  7592  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-07-27 15:22:51.000908:  7592  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/6d2046fc-c02d-4d3d-8746-50399ad50832/model.MODULE_18272098496972694952+d41d8cd9.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir/6d2046fc-c02d-4d3d-8746-50399ad50832/model.MODULE_18272098496972694952+d41d8cd9.neff --verbose=35
DGE ON Levels: {'scalar_dynamic_offset', 'io'}
.
Compiler status PASS
output_torch=tensor([[0.8418, 1.3092, 0.7372,  ..., 0.1458, 0.8831, 0.2339],
        [0.1745, 0.3416, 0.1519,  ..., 0.3358, 0.1832, 0.4795],
        [0.0111, 1.1799, 0.8628,  ..., 0.3107, 0.8328, 0.5663],
        ...,
        [1.1213, 0.5449, 0.3020,  ..., 0.4050, 0.4838, 0.0834],
        [0.8246, 0.5027, 0.2745,  ..., 0.4069, 1.0456, 1.0978],
        [0.6415, 0.3637, 0.1462,  ..., 0.2441, 1.0535, 0.4138]],
       device='xla:0')
2024-07-27 15:22:53.000466:  7592  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-07-27 15:22:53.000467:  7592  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/32c983cd-2c40-4723-8342-d4422107708c/model.MODULE_968738949480579147+d41d8cd9.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir/32c983cd-2c40-4723-8342-d4422107708c/model.MODULE_968738949480579147+d41d8cd9.neff --verbose=35
DGE ON Levels: {'io', 'scalar_dynamic_offset'}
.
Compiler status PASS
NKI and Torch match

JAX#

Below we write a reference JAX implementation of RMSNorm and verify our NKI kernel output against the reference in the same script as the kernel.

# Imports needed to run the reference and the NKI kernel from JAX
import jax
import jax.numpy as jnp
from jax_neuronx import nki_call  # helper for calling NKI kernels from JAX (jax-neuronx package)

# Reference JAX implementation
def jax_rms_norm(a_tensor, g_tensor):
  # Square the tensor (element-wise)
  in_square = jnp.square(a_tensor)
  # Calculate means in the free dimension
  mean = in_square.mean(axis=1, keepdims=True)
  # Scale by reciprocal of sqrt(mean)
  tensor = a_tensor * jnp.reciprocal(jnp.sqrt(mean))

  # Scale the output by the weight
  return tensor * g_tensor

a_key, g_key = jax.random.split(jax.random.PRNGKey(42))
a_tensor = jax.random.uniform(a_key, (250, 512))
g_tensor = jax.random.uniform(g_key, (512,))

output_nki = nki_call(
  nki_rmsnorm_kernel,
  a_tensor, g_tensor,
  out_shape=jax.ShapeDtypeStruct(a_tensor.shape, dtype=a_tensor.dtype),
)

print(a_tensor)

print(f"output_nki={output_nki}")

output_jax = jax_rms_norm(a_tensor, g_tensor)
print(f"output_jax={output_jax}")

if jnp.allclose(output_jax, output_nki, atol=1e-5, rtol=1e-3):
  print("NKI and JAX match")
else:
  print("NKI and JAX differ")

Download All Source Code#

Click the links to download source code of the kernels and the testing code discussed in this tutorial.

You can also view the source code in the GitHub repository nki_samples.

Example usage of the scripts:#

Run NKI baremetal implementation:

python3 rmsnorm_nki_kernels.py

Run PyTorch implementation:

python3 rmsnorm_torch.py

Run JAX implementation:

python3 rmsnorm_jax.py

This document is relevant for: Inf2, Trn1, Trn1n