T5 inference with Tensor Parallelism

This tutorial extends the T5 inference tutorial. Here we use NeuronxDistributed to improve inference performance with tensor parallelism.

This tutorial has the following main sections:

  1. Install dependencies

  2. Plug in NeuronxDistributed layers into T5

  3. Compile the T5 model

  4. Run distributed inference with beam search

This Jupyter notebook should be run on an Inf2 instance (inf2.24xlarge) or a Trn1 instance (trn1.32xlarge).

The tutorial works for t5 and flan-t5 models. In this notebook we will run distributed inference with flan-t5-xl.

Install dependencies

The code in this tutorial is written for Jupyter Notebooks. To use Jupyter Notebook on the Neuron instance, you can use this guide.

Run the notebook by cloning aws-neuron-sdk:

git clone https://github.com/aws-neuron/aws-neuron-sdk.git
cd aws-neuron-sdk/src/examples/pytorch/neuronx_distributed/t5-inference/

Once done, execute t5-inference-tutorial.ipynb.

It is recommended to go through the t5 inference tutorial before you start this tutorial. In addition to the dependencies in the t5 inference tutorial, we need to install neuronx-distributed.

This tutorial requires the following pip packages:

  • torch-neuronx

  • neuronx-cc

  • transformers

  • optimum-neuron

  • neuronx-distributed

Most of these packages will be installed when configuring your environment using the Trn1/Inf2 setup guide. The additional dependencies must be installed here:

[ ]:
! pip install --upgrade transformers==4.33.1 optimum-neuron neuronx_distributed
[ ]:
# Pull the latest version of the compiler
# (the specifier is quoted so the shell does not treat ">" as an output redirect)
! pip install --upgrade "neuronx-cc>=2.11" --no-deps
[ ]:
# Let's update numpy to a newer version
! pip install --upgrade "numpy>=1.22.2" --no-deps

Plug in NeuronxDistributed layers into T5

We extend Hugging Face's T5 model to use the NeuronxDistributed parallel layers. To do so, we simply swap the linear layers in the T5LayerSelfAttention, T5LayerCrossAttention, and T5LayerFF definitions with ColumnParallelLinear and RowParallelLinear. We also need to swap the Embedding layer with ParallelEmbedding.

Let us take the example of T5Attention. The attention block has q, k, v, and o linear layers. The multi-head attention block uses q, k, and v to compute the attention scores, which are then passed through o to compute the attention block output. So let us swap the q, k, and v layers with ColumnParallelLinear and o with RowParallelLinear. Having a RowParallelLinear follow a ColumnParallelLinear is a performance optimization: the attention scores computed with q, k, and v are already split across Neuron devices, so the row parallel layer can use this sharded output directly. The embedding layer is simply swapped with ParallelEmbedding.

class ParallelAttention(T5Attention):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__(config, has_relative_attention_bias)
        # Per attention head and per partition values
        world_size = parallel_state.get_tensor_model_parallel_size()
        self.num_attention_heads_per_partition = divide(self.n_heads, world_size)
        self.hidden_size_per_partition = self.num_attention_heads_per_partition * self.key_value_proj_dim

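        # Mesh TensorFlow initialization to avoid scaling before softmax
        # q, k, and v keep their outputs sharded (gather_output=False) so that
        # o can consume them directly (input_is_parallel=True)
        self.q = ColumnParallelLinear(self.d_model,
                                      self.inner_dim,
                                      bias=False,
                                      gather_output=False)
        self.k = ColumnParallelLinear(self.d_model,
                                      self.inner_dim,
                                      bias=False,
                                      gather_output=False)
        self.v = ColumnParallelLinear(self.d_model,
                                      self.inner_dim,
                                      bias=False,
                                      gather_output=False)
        self.o = RowParallelLinear(self.inner_dim,
                                   self.d_model,
                                   bias=False,
                                   input_is_parallel=True)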

        if self.has_relative_attention_bias:
            self.relative_attention_bias = ParallelEmbedding(self.relative_attention_num_buckets, self.n_heads)
        self.n_heads = self.num_attention_heads_per_partition

You can find all the modified T5 layers defined in t5_model_layers.py.
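
To see why a column-parallel layer followed by a row-parallel layer needs no communication in between, here is a minimal sketch in plain Python. It does not use any Neuron API; the helper names and the tiny integer matrices are hypothetical and chosen only for illustration. Each simulated rank multiplies by its own column block of the first weight and its own row block of the second weight, and the partial results are summed at the end, which on real hardware would correspond to an all-reduce.

```python
# Toy illustration (no Neuron APIs): column-parallel followed by row-parallel
# produces the same result as the unsharded pair of linear layers.

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def split_columns(w, parts):
    """Split a matrix into `parts` column blocks (column-parallel sharding)."""
    step = len(w[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(parts)]

def split_rows(w, parts):
    """Split a matrix into `parts` row blocks (row-parallel sharding)."""
    step = len(w) // parts
    return [w[i * step:(i + 1) * step] for i in range(parts)]

def add(a, b):
    """Element-wise sum of two matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

tp_degree = 2
x = [[1, 2], [3, 4]]                        # (batch, d_model)
w_in = [[1, 0, 2, 1], [0, 1, 1, 3]]         # (d_model, inner_dim), plays the role of q/k/v
w_out = [[1, 2], [0, 1], [2, 0], [1, 1]]    # (inner_dim, d_model), plays the role of o

# Reference: unsharded computation.
expected = matmul(matmul(x, w_in), w_out)

# Sharded: each simulated rank holds one column block of w_in and one row block of w_out.
partials = [
    matmul(matmul(x, w_in_shard), w_out_shard)
    for w_in_shard, w_out_shard in zip(split_columns(w_in, tp_degree),
                                       split_rows(w_out, tp_degree))
]

# The only cross-rank step is the final sum of the partial outputs.
result = partials[0]
for partial in partials[1:]:
    result = add(result, partial)

assert result == expected
```

In the real attention block the softmax sits between the two projections, but it is applied per head, and each rank owns whole heads, so the intermediate values can stay local in the same way.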

Once we have the modified T5 layers, we can plug the parallel attention and feed-forward layers into the pretrained model. Here is how you do that.

import neuronx_distributed
from transformers import T5ForConditionalGeneration

def load_pretrained_with_parallel_attn(model_name):

    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")

    # Parallel implementation of Attention modules.
    from t5_model_layers import ParallelSelfAttention, ParallelFF, ParallelCrossAttention

    for index, block in enumerate(model.decoder.block):
        # Only the first block carries the relative attention bias
        if index == 0:
            block.layer[0] = ParallelSelfAttention(model.config,
                                                   has_relative_attention_bias=True)
        else:
            block.layer[0] = ParallelSelfAttention(model.config)
        block.layer[1] = ParallelCrossAttention(model.config)
        block.layer[2] = ParallelFF(model.config)
    # Load the weights into the parallel layers
    neuronx_distributed.parallel_layers.load(model_name + ".pt", model, sharded=False)

    return model

Compile the parallel T5 model

Let us set some model parameters.

[ ]:
model_name = "google/flan-t5-xl"
max_length = 128
num_beams = 4
tp_degree = 8 # tensor parallelism degree
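
The tensor parallelism degree has to divide the number of attention heads evenly, because ParallelAttention assigns whole heads to each rank. The following sketch reproduces that arithmetic in plain Python. The `divide` function here is a hypothetical stand-in for the helper used in ParallelAttention, and the head count and per-head dimension are the values we assume for flan-t5-xl; check `model.config.num_heads` and `model.config.d_kv` for your own model.

```python
def divide(numerator, denominator):
    """Exact integer division; raises if the split would be uneven (illustrative stand-in)."""
    if numerator % denominator != 0:
        raise ValueError(f"{numerator} is not divisible by {denominator}")
    return numerator // denominator

# Assumed attention configuration for flan-t5-xl (verify against model.config).
n_heads = 32
key_value_proj_dim = 64
tp_degree = 8

# Same quantities that ParallelAttention computes per partition.
heads_per_partition = divide(n_heads, tp_degree)
hidden_size_per_partition = heads_per_partition * key_value_proj_dim

print(f"heads per rank: {heads_per_partition}, hidden size per rank: {hidden_size_per_partition}")
```

Under these assumptions each rank holds 4 heads and a hidden slice of width 256. A degree that does not divide the head count, such as 5, would be rejected by the check.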

Download and save the model that we want to trace.

[ ]:
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
torch.save({"model":model.state_dict()}, model_name.split("/")[-1] + ".pt")
model.config.use_cache = True

To run Hugging Face T5 models on Neuron, we need to make a couple of changes. Let us reuse the code from the T5 inference tutorial, which makes T5 compatible with Neuron. For your convenience, the code is copied into wrapper.py and t5_models.py. This notebook will import these files.

The only change made to this code is that we use neuronx_distributed.trace instead of torch_neuronx.trace.

Let us trace the encoder and decoder.

[ ]:
import t5_models
import neuronx_distributed
import time

# This can take up to 20 minutes
encoder_compile_start_time = time.time()
traced_encoder = t5_models.parallel_trace_encoder(model_name, max_length, num_beams, tp_degree)
print("Encoder compilation time {}".format(time.time() - encoder_compile_start_time))

neuronx_distributed.trace.parallel_model_save(traced_encoder, "TracedParallelEncoder.pt")
[ ]:
# This can take up to 15 minutes
decoder_compile_start_time = time.time()
traced_decoder = t5_models.parallel_trace_decoder(model, model_name, num_beams, max_length, tp_degree)
print("Decoder compilation time {}".format(time.time() - decoder_compile_start_time))

neuronx_distributed.trace.parallel_model_save(traced_decoder, "TracedParallelDecoder.pt")

Inference with the traced parallel T5 model

With the traced model, let us try using beam search for inference.

[ ]:
import neuronx_distributed
from wrapper import T5Wrapper
from transformers import T5Tokenizer

num_return_sequences = 4

traced_encoder = neuronx_distributed.trace.parallel_model_load("TracedParallelEncoder.pt")
traced_decoder = neuronx_distributed.trace.parallel_model_load("TracedParallelDecoder.pt")

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Wrapper.from_pretrained(model_name)

model.encoder = traced_encoder
model.decoder = traced_decoder
setattr(model.encoder, 'main_input_name', 'input_ids')  # Attribute required by beam search

output = model.parallel_infer(tokenizer=tokenizer,
                              prompt="translate English to German: Lets eat good food.",

results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

for i, summary in enumerate(results):
    print(i + 1, summary)

1 Lassen Sie uns gutes Essen essen.
2 Lassen Sie uns gut essen.
3 Lassen Sie uns gutes Essen zu essen.
4 Lassen Sie uns gutes Essen zu sich nehmen.
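
For readers unfamiliar with beam search, here is a simplified sketch of how `num_beams` and `num_return_sequences` interact. It is not the implementation used by the wrapper or by Hugging Face, which also handles length penalties, padding, and batching. The toy probability table stands in for the decoder and is entirely hypothetical.

```python
import math

def toy_next_token_log_probs(sequence):
    """Hypothetical stand-in for the decoder: log-probabilities of the next token."""
    # A fixed table keyed on the last token keeps the example deterministic.
    table = {
        "<s>": {"a": 0.6, "b": 0.4},
        "a": {"a": 0.1, "b": 0.5, "</s>": 0.4},
        "b": {"a": 0.3, "b": 0.1, "</s>": 0.6},
    }
    return {tok: math.log(p) for tok, p in table[sequence[-1]].items()}

def beam_search(num_beams, num_return_sequences, max_length):
    """Keep the `num_beams` best partial sequences at every step."""
    beams = [(0.0, ["<s>"])]   # (cumulative log-probability, tokens)
    finished = []
    for _ in range(max_length):
        # Expand every live beam by every possible next token.
        candidates = []
        for score, seq in beams:
            for tok, logp in toy_next_token_log_probs(seq).items():
                candidates.append((score + logp, seq + [tok]))
        candidates.sort(key=lambda c: c[0], reverse=True)
        # Keep the best candidates; completed ones leave the beam.
        beams = []
        for score, seq in candidates[:num_beams]:
            if seq[-1] == "</s>":
                finished.append((score, seq))
            else:
                beams.append((score, seq))
        if not beams:
            break
    finished.sort(key=lambda c: c[0], reverse=True)
    return finished[:num_return_sequences]

for score, tokens in beam_search(num_beams=2, num_return_sequences=2, max_length=4):
    print(round(math.exp(score), 3), " ".join(tokens))
```

`num_return_sequences` can be at most `num_beams` in practice, since the search never tracks more than `num_beams` hypotheses at a time.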


Let us benchmark the per-token decoder latency.

[ ]:
# Let us install NeuronPerf. We will use it to measure the performance.
! pip install neuronperf --extra-index-url=https://pip.repos.neuron.amazonaws.com
[ ]:
import os
import neuronperf as npf

d_model = model.config.d_model
model_dir = "TracedParallelDecoder.pt"
decoder_run_count = 128

def load_fn(model_path, **kwargs):
    return neuronx_distributed.trace.parallel_model_load(model_path)

# NeuronPerf can't see tp_degree at the moment, so just expose all cores
def env_setup_fn(*_):
    # pop() avoids a KeyError if the variable is not set
    os.environ.pop("NEURON_RT_VISIBLE_CORES", None)

def benchmark():

    # Create some sample inputs for the decoder
    decoder_input_ids = torch.ones((num_beams, 1), dtype=torch.int64)
    decoder_attention_mask = torch.ones((num_beams, max_length), dtype=torch.int32)
    encoder_attention_mask = torch.ones((num_beams, max_length), dtype=torch.int64)
    encoder_hidden_states = torch.ones((num_beams, max_length, d_model), dtype=torch.float32)
    beam_idx = torch.arange(0, num_beams, dtype=torch.int64)
    beam_scores = torch.zeros((num_beams,), dtype=torch.float)

    inputs = (decoder_input_ids,

    reports = npf.benchmark(
        workers_per_model=1,  # no bottleneck on model inputs, so 1 is fine

    report = reports[0]

    # let's update throughput to be tokens / second and add a new record
    latency_in_s = report["latency_ms_avg"] / 1000
    tokens_per_s = decoder_run_count / latency_in_s
    report["throughput_avg"] = tokens_per_s

    # display and save results
    npf.print_reports(reports, cols=["throughput_avg", "latency_ms_p50", "latency_ms_p99"])
    print(f"Results saved to: {npf.write_json(reports[0])}")


Now let's benchmark inference as a whole, including sampling.

[ ]:
import os
import torch
import neuronx_distributed
import neuronperf as npf

from transformers import T5Tokenizer
from wrapper import T5Wrapper

tokenizer = T5Tokenizer.from_pretrained(model_name)

generated_token_count = 0

class Wrapper(torch.nn.Module):
    def __init__(self, traced_encoder, traced_decoder):
        super().__init__()
        self.model = T5Wrapper.from_pretrained(model_name)
        self.model.encoder = traced_encoder
        self.model.decoder = traced_decoder
        setattr(self.model.encoder, 'main_input_name', 'input_ids')  # Attribute required by beam search

    def forward(self, *inputs):
        input_ids = inputs[0]['input_ids']
        attention_mask = inputs[0]['attention_mask']
        return self.model.parallel_infer(input_ids=input_ids,

def load_fn(filename, **kwargs):
    traced_encoder = neuronx_distributed.trace.parallel_model_load(filename + "TracedParallelEncoder.pt")
    traced_decoder = neuronx_distributed.trace.parallel_model_load(filename + "TracedParallelDecoder.pt")
    return Wrapper(traced_encoder, traced_decoder)

# NeuronPerf can't see tp_degree at the moment, so just expose all cores
def env_setup_fn(*_):
    # pop() avoids a KeyError if the variable is not set
    os.environ.pop("NEURON_RT_VISIBLE_CORES", None)

def preprocess_fn(inputs):

    encoding = []
    for text in inputs:
        batch_encoding = tokenizer(text,
        input_ids = batch_encoding['input_ids']
        attention_mask = batch_encoding['attention_mask']
        encoding.append({"input_ids": input_ids,
                         "attention_mask": attention_mask})
    return encoding

def postprocess_fn(outputs):
    output = [tokenizer.decode(seq) for seq in outputs]
    global generated_token_count
    generated_token_count = len(outputs[0])
    return output

def benchmark():
    inputs = ["summarize: The Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history, which will lift up American workers and create good-paying, union jobs across the country. It'll lower the deficit and ask the ultra-wealthy and corporations to pay their fair share. And no one making under $400,000 per year will pay a penny more in taxes."]
    reports = npf.benchmark(
        "",   # Model dir
        max_duration=0,       # sampling can take a while, so let's not timeout

    report = reports[0]

    report["throughput_avg"] = round(generated_token_count / (report["latency_ms_avg"] / 1000), 2)
    report["latency_per_token_ms_p50"] = round((report["latency_ms_p50"])/generated_token_count, 2)
    report["latency_per_token_ms_p99"] = round((report["latency_ms_p99"])/generated_token_count, 2)

    # display and save results
    npf.print_reports(reports, cols=["throughput_avg", "latency_per_token_ms_p50", "latency_per_token_ms_p99"])
    print(f"Results saved to: {npf.write_json(report)}")
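
The per-token metrics above are plain arithmetic on the end-to-end latencies. The sketch below works through that arithmetic on made-up numbers; the report values, the token count, and the `add_token_metrics` helper are hypothetical and are not NeuronPerf output or API.

```python
# Hypothetical numbers for illustration only; real values come from npf.benchmark.
report = {"latency_ms_avg": 800.0, "latency_ms_p50": 780.0, "latency_ms_p99": 950.0}
generated_token_count = 40

def add_token_metrics(report, token_count):
    """Derive tokens/second and per-token latency (ms) from end-to-end latencies in ms."""
    out = dict(report)  # leave the input report untouched
    out["throughput_avg"] = round(token_count / (report["latency_ms_avg"] / 1000), 2)
    out["latency_per_token_ms_p50"] = round(report["latency_ms_p50"] / token_count, 2)
    out["latency_per_token_ms_p99"] = round(report["latency_ms_p99"] / token_count, 2)
    return out

metrics = add_token_metrics(report, generated_token_count)
print(metrics["throughput_avg"], metrics["latency_per_token_ms_p50"], metrics["latency_per_token_ms_p99"])
```

With these inputs, 40 tokens in 0.8 seconds gives 50 tokens per second, and the p50 and p99 per-token latencies are 19.5 ms and 23.75 ms. Note that the token count is taken from a single generated sequence, so the derived numbers are only as representative as that sequence.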