This document is relevant for: Inf2, Trn1, Trn2
LayerNorm#
In this tutorial, we implement a kernel to perform LayerNorm of a 2D tensor, as described in Layer Normalization. LayerNorm is a common normalization mechanism used in Transformer models, similar to RMSNorm. However, LayerNorm requires more vector operations, which makes it more challenging to achieve good compute efficiency on the Vector Engine. Along the way, we will revisit the key concepts from the RMSNorm tutorial and additionally learn about:
Using nki.isa APIs to efficiently compute mean and variance, and minimize the number of traversals over input data by combining multiple vector instructions into one
Taking surrounding compute into consideration when deciding tensor layouts
Before diving into LayerNorm for a 2D tensor, let's go over the LayerNorm operator for a 1D vector \(y\), defined as below:

\[
y = \frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{var}[x] + \epsilon}} * \gamma + \beta
\]
The parameters are:
\(x\): Input 1D vector
\(y\): Output 1D vector, same shape as \(x\)
\(\mathbb{E}[x]\): Mean of \(x\)
\(\mathrm{var}[x]\): Variance of \(x\)
\(\epsilon\): A small constant scalar for numerical stability
\(\gamma\), \(\beta\): LayerNorm affine transform parameters, each with the same shape as \(x\)
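As a quick sanity check of the formula above, here is a minimal NumPy sketch of 1D LayerNorm (an illustration for this tutorial only; the vector length and random values are arbitrary assumptions, not part of the NKI kernels):

import numpy as np

def layernorm_1d(x, gamma, beta, epsilon=1e-5):
    # E[x] and var[x] over the single dimension (population variance)
    mean = x.mean()
    var = x.var()
    # Normalize, then apply the affine transform
    return (x - mean) / np.sqrt(var + epsilon) * gamma + beta

x = np.random.rand(1024)
gamma = np.random.rand(1024)
beta = np.random.rand(1024)
y = layernorm_1d(x, gamma, beta)   # output has the same shape as x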
In Transformer models, we typically need to perform LayerNorm on a 2D input tensor (with shape [sequence_length, hidden_size]), where the first dimension corresponds to the number of tokens currently being processed and the second dimension is the embedding dimension of each token. Different tokens (i.e., rows of the [sequence_length, hidden_size] 2D tensor) undergo 1D LayerNorm independently. Therefore, we need to calculate a different mean and variance for each row and broadcast (i.e., share) the same \(\gamma\), \(\beta\) parameters across the rows.
The figure below visualizes the tensor shapes involved in 2D LayerNorm, where input_tensor is the 2D input tensor and gamma_vector and beta_vector are the affine transform parameters:
Compared to RMSNorm, LayerNorm requires calculating mean and variance, instead of a simple square and summation. Also, LayerNorm performs two instances of free-axis broadcast and two instances of partition-axis broadcast, while RMSNorm requires one instance of each. Therefore, LayerNorm involves substantially more computation (vector operations in particular) than RMSNorm.
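To make the row-wise independence and the two kinds of broadcast concrete, here is a minimal NumPy sketch of 2D LayerNorm (a reference illustration with arbitrary shapes, not the NKI kernel itself):

import numpy as np

def layernorm_2d(input_tensor, gamma_vector, beta_vector, epsilon=1e-5):
    # One mean/variance per row (token); keepdims=True lets them broadcast along the free axis
    mean = input_tensor.mean(axis=1, keepdims=True)   # shape [sequence_length, 1]
    var = input_tensor.var(axis=1, keepdims=True)     # shape [sequence_length, 1]
    normalized = (input_tensor - mean) / np.sqrt(var + epsilon)
    # gamma_vector/beta_vector have shape [hidden_size] and are shared across all rows
    return normalized * gamma_vector + beta_vector

x = np.random.rand(300, 1000)
gamma = np.random.rand(1000)
beta = np.random.rand(1000)
out = layernorm_2d(x, gamma, beta)   # same shape as x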
Implement NKI kernel#
Next, we will present two versions of LayerNorm implementation, starting from a naive version using nki.language APIs and ending with an optimized version using nki.isa APIs.
Version 1: nki.language APIs only#
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import numpy as np
import math

@nki.jit
def nki_layernorm_kernel_v1(input_tensor, epsilon, gamma_vector, beta_vector):
  """Computes LayerNorm.
  Used nki.language APIs only.
  """
  output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype,
                             buffer=nl.shared_hbm)

  # Ensure that the shapes of tensors match
  assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0]

  # Generate tile indices for loading/storing data
  i_p_io = nl.arange(nl.tile_size.pmax)[:, None]
  i_f_io = nl.arange(input_tensor.shape[1])[None, :]
  i_p_param = nl.arange(1)[:, None]

  # Number of rows in the input tensor
  num_rows = input_tensor.shape[0]

  # Load gamma and beta, which will be reused across rows/tiles of input_tensor
  gamma_sb = nl.load(gamma_vector.reshape((1, gamma_vector.shape[0]))[i_p_param, i_f_io])
  beta_sb = nl.load(beta_vector.reshape((1, beta_vector.shape[0]))[i_p_param, i_f_io])

  # Broadcast the gamma and beta to match the dimensions of the tiles
  gamma_sb_bcast = gamma_sb.broadcast_to((nl.tile_size.pmax, gamma_vector.shape[0]))
  beta_sb_bcast = beta_sb.broadcast_to((nl.tile_size.pmax, beta_vector.shape[0]))

  # Tile partition dimension of the input tensor by nl.tile_size.pmax
  for i in nl.affine_range(math.ceil(input_tensor.shape[0]/nl.tile_size.pmax)):
    # Load input tile
    input_sb = nl.load(input_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io],
                       mask=(i * nl.tile_size.pmax + i_p_io < num_rows))

    # Compute mean and variance
    mean = nl.mean(input_sb, axis=1)
    # Trick to calculate var with mean: mean(x^2) - mean(x)^2
    var = nl.mean(nl.square(input_sb), axis=1) - mean * mean

    # Normalize the input by shifting with the mean
    # and scaling with rsqrt of variance and epsilon
    shift_scale_tensor = (input_sb - mean) * nl.rsqrt(var + epsilon)

    # Scale the normalized tile using gamma and add beta
    output_sb = shift_scale_tensor * gamma_sb_bcast + beta_sb_bcast

    nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb,
             mask=(i * nl.tile_size.pmax + i_p_io < num_rows))

  return output_tensor
- To adhere to NKI's tile-size considerations (Tile Size Considerations), we limit the partition axis size of the input_tensor tile to 128 (nl.tile_size.pmax).
- Load gamma and beta, and perform the partition-axis broadcast: the multiplication with shift_scale_tensor requires broadcasting gamma and beta across all partitions (broadcast_to() API).
- The trip count of the compute loop is math.ceil(input_tensor.shape[0]/nl.tile_size.pmax). In cases where input_tensor.shape[0] is not a multiple of nl.tile_size.pmax, we can disable ineffectual data movement or compute in the last iteration using the mask field.
- Within the compute loop:
  - We load one tile of input_tensor with shape (nl.tile_size.pmax, input_tensor.shape[1]) using the nl.load API. We guard the loading boundary by specifying mask=(i * nl.tile_size.pmax + i_p_io < input_tensor.shape[0]), which ensures we don't access out-of-bound memory when the number of rows in input_tensor is not a multiple of nl.tile_size.pmax.
  - Compute the mean and variance using nki.language.mean (a short numerical check of the variance identity used here follows this list).
  - Normalize one tile of input_tensor using mean and variance. The variance is preprocessed using nki.language.rsqrt.
  - Scale the normalized tile using gamma and add beta.
  - Finally, we store the normalized tile back into HBM using the nl.store API. We guard the store boundary similarly to the load boundary using the mask field.
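The variance computation in version 1 relies on the identity var(x) = E[x^2] - (E[x])^2. The short NumPy check below verifies it (illustration only; as a general numerical caveat, this formulation can lose precision in low-precision dtypes when the mean is large relative to the spread of the data):

import numpy as np

# Verify the identity used in version 1: var(x) = mean(x^2) - mean(x)^2 (population variance)
x = np.random.rand(128, 1000)                       # 128 rows, as in one SBUF tile
var_direct = x.var(axis=1)
var_trick = (x ** 2).mean(axis=1) - x.mean(axis=1) ** 2
assert np.allclose(var_direct, var_trick)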
Next, we will optimize the above implementation using nki.isa APIs in version 2.
Version 2: nki.isa APIs to calculate mean/variance and perform shift/scale#
@nki.jit
def nki_layernorm_kernel_v2(input_tensor, epsilon, gamma_vector, beta_vector):
  """Computes LayerNorm.
  Used nki.isa APIs to calculate mean/variance and perform shift/scale.
  """
  output_tensor = nl.ndarray(input_tensor.shape, dtype=input_tensor.dtype,
                             buffer=nl.shared_hbm)

  # Ensure that the shapes of tensors match
  assert input_tensor.shape[1] == gamma_vector.shape[0] == beta_vector.shape[0]

  # Generate tile indices for loading/storing data
  i_p_io = nl.arange(nl.tile_size.pmax)[:, None]
  i_f_io = nl.arange(input_tensor.shape[1])[None, :]
  i_p_param = nl.arange(1)[:, None]

  # Number of rows in the input tensor
  num_rows = input_tensor.shape[0]

  # Load gamma and beta, which will be reused across rows/tiles of input_tensor
  gamma_sb = nl.load(gamma_vector.reshape((1, gamma_vector.shape[0]))[i_p_param, i_f_io])
  beta_sb = nl.load(beta_vector.reshape((1, beta_vector.shape[0]))[i_p_param, i_f_io])

  # Broadcast the gamma and beta to match the dimensions of the tiles
  gamma_sb_bcast = gamma_sb.broadcast_to((nl.tile_size.pmax, gamma_vector.shape[0]))
  beta_sb_bcast = beta_sb.broadcast_to((nl.tile_size.pmax, beta_vector.shape[0]))

  # Tile partition dimension of the input tensor by nl.tile_size.pmax
  for i in nl.affine_range(math.ceil(input_tensor.shape[0]/nl.tile_size.pmax)):
    # Load input tile
    input_sb = nl.load(input_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io],
                       mask=(i * nl.tile_size.pmax + i_p_io < num_rows))

    # Tile free dimension of the input tensor by nl.tile_size.bn_stats_fmax,
    # as bn_stats has a free dimension size limit
    i_f_bn = nl.arange(nl.tile_size.bn_stats_fmax)[None, :]
    i_f_stats = nl.arange(6)[None, :]
    num_bn_stats = math.ceil(input_tensor.shape[1]/nl.tile_size.bn_stats_fmax)
    stats_results = nl.ndarray((nl.tile_size.pmax, 6*num_bn_stats), dtype=np.float32)
    for j in nl.affine_range(num_bn_stats):
      stats_results[i_p_io, j * 6 + i_f_stats] = nisa.bn_stats(
          input_sb[i_p_io, j * nl.tile_size.bn_stats_fmax + i_f_bn],
          mask=(j * nl.tile_size.bn_stats_fmax + i_f_bn < input_tensor.shape[1]),
          dtype=np.float32)

    # Aggregate bn_stats results to compute mean and var
    i_f_aggr = nl.arange(6*num_bn_stats)[None, :]
    mean_var = nisa.bn_aggr(stats_results[i_p_io, i_f_aggr])
    mean = mean_var[i_p_io, 0]
    var = mean_var[i_p_io, 1]

    # Get reciprocal of sqrt(var + epsilon)
    scale_var = nl.rsqrt(var + epsilon)

    # Put the shift and scale together so they map to a single tensor_scalar
    # instruction with two ALU ops: shift_scale_tensor = (input_sb - mean) * scale_var
    shift_scale_tensor = nisa.tensor_scalar(data=input_sb, op0=np.subtract,
                                            operand0=mean,
                                            op1=np.multiply,
                                            operand1=scale_var)

    # Scale the normalized tile using gamma and add beta
    output_sb = shift_scale_tensor * gamma_sb_bcast + beta_sb_bcast

    nl.store(output_tensor[i * nl.tile_size.pmax + i_p_io, i_f_io], value=output_sb,
             mask=(i * nl.tile_size.pmax + i_p_io < num_rows))

  return output_tensor
- Considering the free dimension size limit of nki.isa.bn_stats, which is 512 (nl.tile_size.bn_stats_fmax), the trip count of the bn_stats compute loop is math.ceil(input_tensor.shape[1]/nl.tile_size.bn_stats_fmax).
- We use nki.isa.bn_stats and nki.isa.bn_aggr to calculate the mean and variance (the sketch after this list illustrates the idea of merging per-chunk statistics).
- We use nki.isa.tensor_scalar to perform the shift (subtract the mean) and scale (multiply by the reciprocal square root of variance plus epsilon) in a single instruction.
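Conceptually, nki.isa.bn_stats produces partial statistics for each chunk of up to nl.tile_size.bn_stats_fmax elements, and nki.isa.bn_aggr merges them into one mean/variance per partition. The exact six-element output layout of bn_stats is a hardware detail that we do not reproduce here; the NumPy sketch below only illustrates the general idea of merging per-chunk counts, means, and variances (a standard parallel-variance merge, shown purely as an illustration):

import numpy as np

def merge_chunk_stats(chunks):
    """Merge per-chunk (count, mean, M2) statistics into a global mean/variance."""
    count, mean, m2 = 0, 0.0, 0.0              # M2 = sum of squared deviations from the mean
    for c in chunks:
        n_c = c.size
        mean_c = c.mean()
        m2_c = c.var() * n_c
        delta = mean_c - mean
        total = count + n_c
        mean = mean + delta * n_c / total      # running mean update
        m2 = m2 + m2_c + delta ** 2 * count * n_c / total
        count = total
    return mean, m2 / count

row = np.random.rand(4096)
chunks = np.split(row, 8)                      # e.g., 8 chunks of 512 elements each
mean, var = merge_chunk_stats(chunks)
assert np.allclose([mean, var], [row.mean(), row.var()])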
Performance of Version 1 and Version 2#
Let’s assume the data type for the kernel is float32 and that the SBUF partition is sufficiently large to hold the intermediate data simultaneously without significant spilling.
Define the variable N = input_tensor.shape[1].
- Compute mean and variance:
  - Version 1: The performance cost of the mean calculation is N Vector Engine cycles, and the variance calculation costs N Scalar Engine + 2N Vector Engine cycles.
  - Version 2: By replacing these calculations with the bn_stats and bn_aggr APIs, the cost is roughly reduced to N Vector Engine cycles (ignoring the cost of nki.isa.bn_aggr, assuming N is sufficiently large).
- Perform shift and scale in a single instruction:
  - Version 1: The shift/scale calculation requires two small instructions (nl.rsqrt(var + epsilon)) and two instructions that each iterate over N elements per partition (shift and scale, 2N Vector Engine cycles).
  - Version 2: By replacing these calculations with the tensor_scalar API, the cost is reduced to N Vector Engine cycles (a back-of-envelope tally of these modeled costs follows this list).
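Putting these modeled costs together for just the statistics and shift/scale portions (a back-of-envelope sketch under the cost model above; it deliberately ignores the DMA loads/stores and the gamma/beta affine step, which are identical in both versions, so the end-to-end gain is smaller than this ratio alone would suggest):

# Back-of-envelope tally of the modeled Vector Engine cycles per partition
N = 1000                         # free dimension size, matching the measurement below

v1_cycles = N + 2 * N + 2 * N    # mean (N) + variance (2N) + shift/scale (2N)
v2_cycles = N + N                # bn_stats/bn_aggr (~N) + fused tensor_scalar (N)

print(f"modeled cycles per partition: v1 = {v1_cycles}, v2 = {v2_cycles}")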
The latency measured on Trn1 using an input tensor of shape (300, 1000) shows a 14.9% improvement of version 2 over version 1:
>>>> Running version v1.
Latency results are:
NCLatency:
p0 = 2306us
p1 = 2306us
p10 = 2308us
p25 = 2309us
p50 = 2311us
p90 = 2313us
p99 = 2314us
p100 = 2314us
>>>> Running version v2.
Latency results are:
NCLatency:
p0 = 1963us
p1 = 1963us
p10 = 1965us
p25 = 1966us
p50 = 1969us
p90 = 1972us
p99 = 1974us
p100 = 1975us
Launching kernel and testing correctness#
Below is a reference PyTorch implementation of LayerNorm, which we use to verify the NKI kernel output:
import torch
from torch_xla.core import xla_model as xm
import argparse
import os

os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"

# Reference torch implementation
def layernorm_layer(input_tensor, epsilon, gamma_vector, beta_vector):
  # Compute the mean and variance of the input tensor along the last dimension
  mean = input_tensor.mean(dim=-1, keepdim=True)
  variance = input_tensor.var(dim=-1, keepdim=True, unbiased=False)
  # Subtract the mean from the input and divide by the square root of the variance plus epsilon
  normalized_input = (input_tensor - mean) / torch.sqrt(variance + epsilon)
  # Apply the affine transformation
  normalized_input = normalized_input * gamma_vector + beta_vector
  return normalized_input

def parse_args():
  parser = argparse.ArgumentParser(
    """Run LayerNorm pytorch implementation.
    """)
  parser.add_argument("--nrows",
                      default=4*1024,
                      type=int,
                      help="""The number of input rows""")
  parser.add_argument("--ncols",
                      default=8*1024,
                      type=int,
                      help="""The number of input columns""")
  parser.add_argument("--version",
                      default="v1",
                      choices=["v1", "v2"],
                      help="Test versions")
  args = parser.parse_args()
  return args


from neuronxcc.nki.docs.examples.layernorm.layernorm_nki_kernel import nki_layernorm_kernel_v1, \
  nki_layernorm_kernel_v2

if __name__ == "__main__":
  args = parse_args()
  func_dict = {"v1": nki_layernorm_kernel_v1,
               "v2": nki_layernorm_kernel_v2,
               }

  device = xm.xla_device()
  num_rows = args.nrows
  num_cols = args.ncols

  # Generate toy example
  input_tensor = torch.rand((num_rows, num_cols), dtype=torch.float32)
  gamma_vector = torch.rand((num_cols), dtype=torch.float32)
  beta_vector = torch.rand((num_cols), dtype=torch.float32)
  epsilon = 1e-5

  # Compute torch layernorm layer in cpu
  output_torch = layernorm_layer(input_tensor, epsilon, gamma_vector, beta_vector)

  # Copy tensors to NeuronDevice
  input_tensor = input_tensor.to(device=device)
  gamma_vector = gamma_vector.to(device=device)
  beta_vector = beta_vector.to(device=device)

  print(f">>>> Running version {args.version}.")
  func = func_dict[args.version]

  # Compute NKI layernorm kernel in NeuronDevice
  xm.mark_step()
  output_nki = func(input_tensor, epsilon, gamma_vector, beta_vector)
  xm.mark_step()
  output_nki = output_nki.to(device='cpu')

  # Accuracy check : Compare the output tensors
  allclose = torch.allclose(output_torch, output_nki, atol=1e-3, rtol=1e-2)
  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")
Download All Source Code#
Click the links to download source code of the kernels and the testing code discussed in this tutorial.
PyTorch reference implementation:
layernorm_torch.py
Two versions of NKI kernels:
layernorm_nki_kernel.py
You can also view the source code in the GitHub repository nki_samples.
Example usage of the scripts#
Performance mode
Check the performance numbers for nki_layernorm_kernel_v1 and nki_layernorm_kernel_v2, and generate NEFF files for profiling:
python3 layernorm_nki_kernel.py --mode perfs
Accuracy mode
Check NKI kernel accuracy against PyTorch implementation:
python3 layernorm_torch.py --version v1
python3 layernorm_torch.py --version v2
Check the optimized LayerNorm kernel (nki_layernorm_kernel_v2) accuracy against nki_layernorm_kernel_v1:
python3 layernorm_nki_kernel.py --mode accuracy
Input tensor size
python3 layernorm_torch.py --nrows 4096 --ncols 8192
python3 layernorm_nki_kernel.py --nrows 4096 --ncols 8192
This document is relevant for: Inf2, Trn1, Trn2