T5 model inference on Trn1 or Inf2

Introduction

In this tutorial we will compile and deploy a pretrained T5 model for accelerated inference on Neuron.

This tutorial will use the t5-large model. The T5 model can be used for machine translation, document summarization, question answering, and classification tasks.

This tutorial has the following main sections:

  1. Install dependencies

  2. Compile the T5 model

  3. Run inference with greedy decoding on Neuron

  4. Run inference with beam search on Neuron

This Jupyter notebook should be run on a Trn1 instance (trn1.2xlarge or larger) or an Inf2 instance (inf2.xlarge or larger).

Install dependencies

The code in this tutorial is written for Jupyter Notebooks. To use Jupyter Notebook on the Neuron instance, you can use this guide.

This tutorial requires the following pip packages:

  • torch-neuronx

  • neuronx-cc

  • transformers

  • optimum-neuron

Most of these packages will be installed when configuring your environment using the Trn1/Inf2 setup guide. The additional dependencies must be installed here:

[ ]:
!pip install --upgrade transformers==4.31.0 optimum-neuron==0.0.8

🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS Accelerators, including AWS Trainium and AWS Inferentia. It provides a set of tools for model loading, training, and inference in single- and multi-accelerator settings for different downstream tasks. In this tutorial we use 🤗 Optimum Neuron’s generate() method instead of the 🤗 Transformers generate() to perform greedy decoding. Optimum Neuron takes care of padding the inputs, which is necessary for inference on Neuron.
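
Padding matters because a compiled Neuron graph expects inputs of one fixed shape. The sketch below illustrates the idea in plain Python; it is not Optimum Neuron's actual code. The helper name `pad_to_length`, the default `pad_id=0`, and the token ids are assumptions made for illustration.

```python
from typing import List, Tuple

def pad_to_length(token_ids: List[int], max_length: int, pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Right-pad (or truncate) token ids to exactly max_length.

    Returns the padded ids and an attention mask holding 1 for real
    tokens and 0 for padding. Hypothetical helper, for illustration only.
    """
    ids = token_ids[:max_length]
    n_pad = max_length - len(ids)
    mask = [1] * len(ids) + [0] * n_pad
    return ids + [pad_id] * n_pad, mask

ids, mask = pad_to_length([37, 1782, 19, 1], max_length=8)
print(ids)   # [37, 1782, 19, 1, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

Every call then produces lists of the same length, regardless of how long the prompt was.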

Compile the model into an AWS Neuron optimized TorchScript

In the following section, we load the T5 model, compile the model’s encoder and decoder for Neuron using torch_neuronx.trace(), and save the optimized encoder and decoder as TorchScript.

Before we trace the model, we need to make a couple of changes.

  1. We need to write encoder and decoder wrappers - torch_neuronx can only trace functions with positional arguments, but the T5 encoder and decoder both use keyword arguments. So, in order to trace them, we have to write wrappers that accept positional arguments and pass them on as keyword arguments.

  2. We modify the T5 code to maximize the computation on the Neuron device - Having sections of code running on the CPU reduces performance. Moreover, we do not want to move data between the Neuron device and the CPU during inference. The code we trace with torch_neuronx is the code that runs on the Neuron device, so we refactor the T5 code to run computationally heavy operations within the wrapper.
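
The first change can be pictured with a plain-Python stand-in. This is only a sketch of the wrapper pattern; `KeywordOnlyModel` and `PositionalWrapper` are hypothetical classes, not part of transformers or torch_neuronx.

```python
class KeywordOnlyModel:
    """Stand-in for a module whose forward() must be called with keyword arguments."""

    def forward(self, *, input_ids, attention_mask):
        # Toy computation: zero out the masked positions.
        return [i * m for i, m in zip(input_ids, attention_mask)]


class PositionalWrapper:
    """Exposes a positional-only call signature and forwards the values as keywords."""

    def __init__(self, model):
        self.model = model

    def forward(self, input_ids, attention_mask):
        return self.model.forward(input_ids=input_ids, attention_mask=attention_mask)


wrapped = PositionalWrapper(KeywordOnlyModel())
print(wrapped.forward([5, 6, 7], [1, 1, 0]))  # [5, 6, 0]
```

The wrappers in this tutorial follow the same shape, with the extra cache logic added inside forward().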

Let us start with the EncoderWrapper.

In the HuggingFace T5 implementation, the encoder block takes in the input ids and returns the encoder hidden states. These hidden states are then used to initialize the KV cache in the decoder blocks during the first decoder invocation. We could trace the encoder and the cache initialization step separately, but there is a better way: we can compute the initial KV cache state within the encoder wrapper. This removes the overhead of moving the hidden states from the Neuron device to the CPU and back. It also allows the Neuron compiler to optimize execution across both the encoder and the cache initialization.

Why don’t we just initialize the cache on the first decoder run?

This is harder to do on Neuron. Similar to torch.jit.trace(), torch_neuronx.trace() produces a function that has a fixed control flow, i.e. there are no conditional executions. So we cannot choose to conditionally initialize the cache in the first decoder iteration. Instead, we can compute the initial cache state outside the generation flow and pass the cache to it.
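
The fixed control flow is easy to observe with torch.jit.trace() on the CPU, which the paragraph above compares torch_neuronx.trace() to. The sketch below assumes only that PyTorch is installed; `maybe_init` is a hypothetical toy function, not part of the tutorial.

```python
import warnings
import torch

def maybe_init(x):
    # Data-dependent branch: which side runs depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return x - 1

example = torch.ones(3)  # positive sum, so tracing records the `x * 2` branch
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning about the branch
    traced = torch.jit.trace(maybe_init, example)

negative = torch.full((3,), -3.0)
print(maybe_init(negative))  # tensor([-4., -4., -4.])  eager Python takes the else branch
print(traced(negative))      # tensor([-6., -6., -6.])  the trace replays `x * 2`
```

The traced function keeps the branch taken for the example input, which is why a "first iteration only" cache initialization cannot live inside the decoder trace.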

[ ]:
import torch

from transformers.models.t5.modeling_t5 import T5Stack, T5LayerCrossAttention

class EncoderWrapper(torch.nn.Module):
    '''
        We will trace an instance of the EncoderWrapper.
        This wrapper accepts positional args, passes them to the encoder as kwargs,
        and also computes the initial KV cache state for the decoder.
    '''

    def __init__(self,
                 encoder,
                 decoder,
                 model_config,
                 batch_size,
                 max_length,
                 device,
                 num_beams,
                 tp_degree=None):

        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.batch_size = batch_size
        self.max_length = max_length
        self.model_config = model_config
        self.device = device
        self.num_beams = num_beams
        self.num_attention_heads_per_partition = model_config.num_heads
        self.tp_degree = tp_degree

    def forward(self, input_ids, attention_mask):
        '''
            This is the core functionality we want to trace.
        '''
        encoder_output =  self.encoder(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       output_attentions=False,
                                       output_hidden_states=False)

        last_hidden_state = encoder_output["last_hidden_state"]
        encoder_hidden_states = torch.concat([tensor.unsqueeze(0).repeat(self.num_beams, 1, 1) for tensor in last_hidden_state])

        decoder_blocks = self.decoder.block
        present_key_value_states_sa = []
        present_key_value_states_ca = []

        for i, block in enumerate(decoder_blocks):

            # Cross attention has to be initialized with the encoder hidden state
            cross_attention: T5LayerCrossAttention = block.layer[1]
            attention = cross_attention.EncDecAttention

            def shape(states):
                """projection"""
                return states.view(self.batch_size, -1, self.num_attention_heads_per_partition, attention.key_value_proj_dim).transpose(1, 2)

            key_states = shape(attention.k(encoder_hidden_states))
            value_states = shape(attention.v(encoder_hidden_states))

            # cross_attn_kv_state
            present_key_value_states_ca.append(key_states)
            present_key_value_states_ca.append(value_states)

            # Self attention kv states are initialized to zeros. This is done to keep the size of the kv cache tensor constant.
            # The kv cache will be an input to the decoder trace. Any traced function will have a fixed control flow. What this means
            # is that the trace performs the exact same computations on inputs of the same shape in each invocation. So the attention
            # kv cache is padded here to keep a fixed shape.
            present_key_value_states_sa.append(torch.zeros((self.batch_size,                                                     # key states
                                                            self.model_config.num_heads,
                                                            self.max_length-1,
                                                            self.model_config.d_kv), dtype=torch.float32, device=self.device))
            present_key_value_states_sa.append(torch.zeros((self.batch_size,                                                     # value states
                                                            self.model_config.num_heads,
                                                            self.max_length-1,
                                                            self.model_config.d_kv), dtype=torch.float32, device=self.device))

        return present_key_value_states_sa + present_key_value_states_ca
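
The list returned by EncoderWrapper.forward() is flat: all self-attention tensors first, then all cross-attention tensors, with a key and a value entry per decoder layer. The sketch below only computes the shapes, so it runs without torch. `initial_cache_shapes` is a hypothetical helper, `rows` stands for whatever was passed as batch_size, and the numbers in the usage are toy values rather than the t5-large configuration.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int, int]

def initial_cache_shapes(num_layers: int, rows: int, num_heads: int, d_kv: int,
                         max_length: int, encoder_len: int) -> Tuple[List[Shape], List[Shape]]:
    """Shapes of the flat self-attention and cross-attention cache lists."""
    # Self-attention cache: zero-filled, one slot short of max_length,
    # because the decoder appends the current token's key/value each step.
    sa = [(rows, num_heads, max_length - 1, d_kv)] * (2 * num_layers)
    # Cross-attention cache: projections of the encoder hidden states,
    # one slot per encoder position.
    ca = [(rows, num_heads, encoder_len, d_kv)] * (2 * num_layers)
    return sa, ca

sa, ca = initial_cache_shapes(num_layers=2, rows=1, num_heads=4, d_kv=8,
                              max_length=16, encoder_len=16)
flat = sa + ca
print(len(flat))  # 8
print(flat[0])    # (1, 4, 15, 8)
print(flat[4])    # (1, 4, 16, 8)
```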

In the decoder wrapper, in addition to converting keyword arguments to positional arguments, we add support for attention caching. Generating text from encoder-decoder models is an autoregressive process. Without caching, each invocation would have to recompute the key and value states of the attention heads for all previous tokens. To improve performance, we cache the key and value states. This cache is what the HuggingFace transformers code refers to as past_key_values.

In HuggingFace transformers, the past_key_values are updated outside the decoder. This works for training and evaluation, but for inference we want the decoder execution and the cache update to happen within a single trace. This way, the compiler can optimize across both. So, we move the cache update into the decoder wrapper.
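
Because a trace needs tensors of constant shape, the self-attention cache behaves like a fixed-length window: the decoder appends the newest entry, and the wrapper then drops the oldest slot. The following plain-Python sketch shows that bookkeeping on lists; `decode_step` and the `PAD` placeholder are hypothetical names used only for illustration.

```python
from typing import Any, List, Tuple

PAD = None  # placeholder for an unused (padded) cache slot

def decode_step(cache: List[Any], new_entry: Any) -> Tuple[List[Any], List[Any]]:
    """Return what attention sees this step and the cache carried to the next step."""
    extended = cache + [new_entry]  # max_length slots, newest entry last
    return extended, extended[1:]   # drop the oldest slot so the length stays fixed

max_length = 4
cache = [PAD] * (max_length - 1)
for entry in ["k0", "k1", "k2"]:
    visible, cache = decode_step(cache, entry)
    print(visible, "->", cache)
# [None, None, None, 'k0'] -> [None, None, 'k0']
# [None, None, 'k0', 'k1'] -> [None, 'k0', 'k1']
# [None, 'k0', 'k1', 'k2'] -> ['k0', 'k1', 'k2']
```

The cache length never changes, so the same compiled graph can be reused for every decoding step.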

[3]:
class DecoderWrapper(torch.nn.Module):

    def __init__(self,
                 decoder: T5Stack,
                 lm_head: torch.nn.Linear,
                 model_config,
                 num_beams: int,
                 max_length: int,
                 device: str,
                 tp_degree=None):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.model_dim=model_config.d_model
        self.device = device
        self.num_beams = num_beams
        self.batch_size = 1
        self.config = model_config

        num_heads=model_config.num_heads
        num_decoder_layers=model_config.num_decoder_layers

        self.num_attention_heads_per_partition = num_heads

        # (num_beams, n_heads, seq_length, dim_per_head)
        if device == "cpu":
            self.past_key_values_sa = [torch.ones((num_beams,num_heads,max_length-1,model_config.d_kv), dtype=torch.float32) for _ in range(num_decoder_layers * 2)]
            self.past_key_values_ca = [torch.ones((num_beams,num_heads,max_length,model_config.d_kv), dtype=torch.float32) for _ in range(num_decoder_layers * 2)]
        elif device == "xla":
            self.past_key_values_sa = torch.nn.ParameterList([torch.nn.Parameter(torch.ones((num_beams,self.num_attention_heads_per_partition,max_length-1,model_config.d_kv), dtype=torch.float32), requires_grad=False) for _ in range(num_decoder_layers * 2)])
            self.past_key_values_ca = torch.nn.ParameterList([torch.nn.Parameter(torch.ones((num_beams,self.num_attention_heads_per_partition,max_length,model_config.d_kv), dtype=torch.float32), requires_grad=False) for _ in range(num_decoder_layers * 2)])

    def update_past(self, past_key_values):
        new_past_sa = []
        new_past_ca = []
        for past_layer in past_key_values:
            new_past_layer = list(past_layer)
            for i in range(len(new_past_layer[:2])):
                new_past_layer[i] = past_layer[i][:, :, 1:]
            new_past_sa += [new_past_layer[:2],]
            new_past_ca += [new_past_layer[2:],]
        return new_past_sa, new_past_ca

    def reorder_cache(self, past_key_values, beam_idx):
        for i in range(len(past_key_values)):
            gather_index = beam_idx.view([beam_idx.shape[0],1,1,1]).expand_as(past_key_values[i])
            past_key_values[i] = torch.gather(past_key_values[i], dim = 0, index=gather_index)
        return past_key_values

    def forward(self,
                input_ids,
                decoder_attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                beam_idx,
                beam_scores,
                **kwargs):

        if self.num_beams > 1:
            # We reorder the cache based on the beams selected in each iteration. Required step for beam search.
            past_key_values_sa = self.reorder_cache(self.past_key_values_sa, beam_idx)
            past_key_values_ca = self.reorder_cache(self.past_key_values_ca, beam_idx)
        else:
            # We do not need to reorder for greedy sampling
            past_key_values_sa = self.past_key_values_sa
            past_key_values_ca = self.past_key_values_ca

        # The cache is stored in a flattened form. We order the cache per layer before passing it to the decoder.
        # Each layer has 4 tensors, so we group by 4.
        past_key_values = [[*past_key_values_sa[i*2:i*2+2], *past_key_values_ca[i*2:i*2+2]] for i in range(0, int(len(past_key_values_ca)/2))]

        decoder_output = self.decoder(
            input_ids=input_ids,
            attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False)

        last_hidden_state = decoder_output['last_hidden_state']
        past_key_values = decoder_output['past_key_values']

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            last_hidden_state = last_hidden_state * (self.model_dim**-0.5)

        lm_logits = self.lm_head(last_hidden_state)

        past_key_values_sa, past_key_values_ca = self.update_past(past_key_values)

        # We flatten the cache to a single array. This is required for the input output aliasing to work
        past_key_values_sa = [vec for kv_per_layer in past_key_values_sa for vec in kv_per_layer]
        past_key_values_ca = [vec for kv_per_layer in past_key_values_ca for vec in kv_per_layer]

        if self.device == "cpu":
            self.past_key_values_sa = past_key_values_sa
            self.past_key_values_ca = past_key_values_ca

        # We calculate topk inside the wrapper
        next_token_logits = lm_logits[:, -1, :]

        if self.num_beams > 1:
            # This section of beam search is run outside the decoder in the huggingface t5 implementation.
            # To maximize the computation within the neuron device, we move this within the wrapper
            logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
            logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))
            next_token_scores = next_token_logits - logit_max - logsumexp
            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

            # reshape for beam search
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(self.batch_size, self.num_beams * vocab_size)
            next_token_scores = next_token_scores * 1

            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * self.num_beams, dim=1, largest=True, sorted=True
            )

            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            return [next_token_scores, next_tokens, next_indices] + past_key_values_sa + past_key_values_ca
        else:
            # Greedy
            next_tokens = torch.argmax(next_token_logits, dim=-1)
            return [next_tokens] + past_key_values_sa + past_key_values_ca
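
The beam-search branch of DecoderWrapper.forward() computes a numerically stable log-softmax, adds the running beam scores, flattens the (beam, vocab) grid, and then recovers the beam index and token id from each flat index. The following is a minimal plain-Python sketch of that arithmetic on toy numbers; `log_softmax` and `top_candidates` are hypothetical helpers, not part of the tutorial code.

```python
import math
from typing import List, Tuple

def log_softmax(logits: List[float]) -> List[float]:
    """Stable log-softmax: subtract the max before exponentiating."""
    m = max(logits)
    lse = math.log(sum(math.exp(x - m) for x in logits))
    return [x - m - lse for x in logits]

def top_candidates(logits_per_beam: List[List[float]],
                   beam_scores: List[float], k: int) -> List[Tuple[float, int, int]]:
    """Return the k best (score, beam_index, token_id) triples over all beams."""
    vocab_size = len(logits_per_beam[0])
    flat: List[float] = []
    for beam, logits in enumerate(logits_per_beam):
        flat.extend(beam_scores[beam] + s for s in log_softmax(logits))
    order = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)[:k]
    # Floor division recovers the beam, the remainder recovers the token.
    return [(flat[i], i // vocab_size, i % vocab_size) for i in order]

logits = [[2.0, 1.0, 0.0], [0.0, 3.0, 0.0]]  # 2 beams, vocabulary of 3 tokens
beam_scores = [0.0, -1e9]                    # only the first beam is live at step one
candidates = top_candidates(logits, beam_scores, k=2 * len(logits))
print([(beam, token) for _, beam, token in candidates])
# [(0, 0), (0, 1), (0, 2), (1, 1)]
```

Keeping 2 * num_beams candidates leaves spare entries for the beam scorer, matching the comment in the wrapper.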


Now let’s create a T5 model wrapper to make it compatible with our traced encoder and decoder.

There are three reasons for having this wrapper:

  1. The encoder and decoder traces can only be invoked with positional arguments. But the HuggingFace transformers code is written with keyword arguments. So we override the functions that invoke encoder and decoder to call with positional arguments.

  2. The generate() function in the NeuronGenerationMixin performs cache update within the CPU. As we are handling the cache within the DecoderWrapper, we disable the cache update on CPU.

  3. The topK computation to determine the next tokens for beam search was moved into the decoder wrapper. So, we need to override the huggingface’s beam search implementation to accept the next tokens and the beam scores from the decoder.

Let’s also override the generate() function so that it initializes the cache using the cache initializer (the traced encoder) before decoding starts.
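
With the cache handled inside DecoderWrapper, the CPU-side update only has to keep the decoder attention mask in step with the fixed-length cache. The tutorial code starts from a mask of zeros with a single trailing one, then shifts it left and appends a one after each step. The plain-Python sketch below mirrors that on lists; `initial_mask` and `advance_mask` are hypothetical names.

```python
from typing import List

def initial_mask(max_length: int) -> List[int]:
    """Only the last slot (the current token) is visible at the first step."""
    return [0] * (max_length - 1) + [1]

def advance_mask(mask: List[int]) -> List[int]:
    """Shift left by one and mark the newest slot as visible."""
    return mask[1:] + [1]

mask = initial_mask(5)
print(mask)  # [0, 0, 0, 0, 1]
mask = advance_mask(mask)
print(mask)  # [0, 0, 0, 1, 1]
mask = advance_mask(mask)
print(mask)  # [0, 0, 1, 1, 1]
```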

[4]:
import torch
import torch_xla.core.xla_model as xm

from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.t5.modeling_t5 import T5Stack, T5LayerCrossAttention
from transformers.generation.utils import ModelOutput
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.generation.beam_search import BeamScorer, BeamSearchScorer

from optimum.neuron.generation import NeuronGenerationMixin

from transformers.generation.logits_process import (
    LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

from transformers.generation.utils import (
    BeamSearchOutput,
    GreedySearchOutput,
)

class T5Wrapper(T5ForConditionalGeneration, NeuronGenerationMixin):

    def _prepare_encoder_decoder_kwargs_for_generation(
        self,
        inputs_tensor: torch.Tensor,
        model_kwargs,
        model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        encoder = self.get_encoder()
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(inputs_tensor, model_kwargs["attention_mask"])
        return model_kwargs

    # Override to cut the input_ids to just last token
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids as past is cached
        input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    '''
        We update the cache in the decoder trace, so let's override the _update_model_kwargs_for_xla_generation in NeuronGenerationMixin
    '''
    def _update_model_kwargs_for_xla_generation(
        self,
        model_kwargs: Dict[str, Any],
        batch_size: int,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
        max_length: Optional[int] = None,
        seq_length: Optional[int] = None,
        use_cache: bool = True,
    ) -> Dict[str, Any]:

        def _update_attention(model_kwargs, is_encoder_decoder):
            """Updates the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""

            attention_mask_name = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
            attention_mask = model_kwargs.pop(attention_mask_name)
            attention_mask_update_slice = torch.ones(
                (batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device
            )
            attention_mask = torch.cat([attention_mask[:, 1:], attention_mask_update_slice], dim=-1)
            mask = {attention_mask_name: attention_mask}
            return mask

        mask = _update_attention(model_kwargs, is_encoder_decoder)
        # sets the updated variables (mask and past_key_values)
        model_kwargs.update(mask)

        # Set a mock cache tensor
        model_kwargs["past_key_values"] = torch.tensor([])

        return model_kwargs

    def _reorder_cache(self, past_key_values, beam_idx):
        '''
            This is needed for beam search and not greedy sampling
            We reorder the cache within the trace so we can skip it in modeling_t5.py. So we override the _reorder_cache
        '''
        self.beam_idx = beam_idx
        return past_key_values

    def generate(self,
                tokenizer: T5Tokenizer,
                prompt: str,
                max_length: int,
                num_beams: int,
                num_return_sequences: int,
                device: str):

        batch_encoding = tokenizer(prompt, max_length=max_length, truncation=True, padding='max_length',
                                return_tensors="pt")

        past_key_values = self.encoder(batch_encoding['input_ids'],batch_encoding['attention_mask'])

        decoder_attention_mask = torch.cat([torch.zeros((1, max_length-1), dtype=torch.int32),
                                            torch.ones((1, 1), dtype=torch.int32)], axis=1)

        # copy the new cache state to the decoder
        if device == "xla":
            for state, tensor in zip(self.decoder.parameters(), past_key_values):
                state.copy_(tensor)
        else:
            # First half of the cache is self attention and the rest is cross attention
            self.decoder.past_key_values_sa = past_key_values[:len(past_key_values)//2]
            self.decoder.past_key_values_ca = past_key_values[len(past_key_values)//2:]

        output = super().generate(**batch_encoding,
                                max_length=max_length,
                                num_beams=num_beams,
                                num_return_sequences=num_return_sequences,
                                do_sample=False,
                                use_cache=True,
                                decoder_attention_mask=decoder_attention_mask,
                                encoder_outputs={"last_hidden_state": torch.ones((1,128,1))}) # Pass fake encoder_outputs so the transformers code will not invoke the encoder
        return output

    def forward(
        self,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        beam_scores = None,
        **kwargs
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:

        hidden_states = encoder_outputs["last_hidden_state"]

        if not hasattr(self, 'beam_idx'):
            # Inferring the number of beams from the attention mask
            num_beams = attention_mask.shape[0]
            self.beam_idx = torch.arange(0, num_beams, dtype=torch.int64)

        decoder_outputs = self.decoder(
            decoder_input_ids,
            decoder_attention_mask,
            hidden_states,
            attention_mask,
            self.beam_idx,
            beam_scores
        )

        # lm_logits = decoder_outputs[0]
        next_token_scores = decoder_outputs[0]
        next_tokens = decoder_outputs[1]
        next_indices = decoder_outputs[2]

        return next_token_scores, next_tokens, next_indices

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        seq_length: Optional[int] = None,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

        # Overwrite cur_len
        cur_len = seq_length

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )

        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
        # of the first beam are considered to avoid sampling the exact same tokens across all beams.
        # beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores_device = "cpu"
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=beam_scores_device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:
            # prepare model inputs
            # From max_length-sized input_ids, select first
            # cur_len - 1 values.
            update_indices = torch.stack(
                [torch.arange(input_ids.size(0)), torch.tensor(cur_len - 1).repeat(input_ids.size(0))], dim=-1
            )
            input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]
            model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)

            next_token_scores, next_tokens, next_indices = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                beam_scores=beam_scores
            )

            # stateless
            beam_outputs = beam_scorer.process(
                input_ids.to("cpu")[:, :cur_len],
                next_token_scores.to("cpu"),
                next_tokens.to("cpu"),
                next_indices.to("cpu"),
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )

            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            update_indices = torch.stack(
                [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1
            )
            update_indices_2 = torch.stack(
                [torch.arange(batch_beam_size), torch.tensor(cur_len).repeat(batch_beam_size)], dim=-1
            )
            # First select beam_indices
            device = input_ids.device
            beam_idx_device = beam_idx.to(device=input_ids.device)
            input_ids[:, :] = input_ids[beam_idx_device.long(), :]

            # Then append new tokens
            input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = beam_next_tokens.unsqueeze(-1).to(device).to(torch.long)
            input_ids = input_ids * 1  # Hack to materialize tensor

            # update generated ids, model inputs, and length for next step
            model_kwargs = self._update_model_kwargs_for_xla_generation(
                model_kwargs,
                batch_size=batch_beam_size,
                is_encoder_decoder=self.config.is_encoder_decoder,
                max_length=stopping_criteria.max_length,
                seq_length=cur_len,
                use_cache=model_kwargs["use_cache"],
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx.to(torch.int64))

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            # stop when each sentence is finished, or if we exceed the maximum length
            stop_criterion_1 = beam_scorer.is_done
            if isinstance(stopping_criteria, list):
                if len(stopping_criteria) == 1:
                    stopping_criteria = stopping_criteria[0]

            # Cases that can be handled in XLA without requiring
            # non-padded input_ids
            if isinstance(stopping_criteria, MaxLengthCriteria):
                stop_criterion_2 = cur_len >= stopping_criteria.max_length
            elif isinstance(stopping_criteria, MaxTimeCriteria):
                stop_criterion_2 = stopping_criteria(input_ids, scores)
            else:
                # Other cases will be handled on CPU
                batch_size, _ = input_ids.shape
                input_ids_cpu = input_ids.to("cpu")
                mask = torch.cat(
                    [torch.ones(batch_size, cur_len), torch.zeros(batch_size, input_ids.shape[1] - cur_len)], dim=1
                ).bool()
                input_ids_cpu = torch.masked_select(input_ids_cpu, mask).reshape((batch_size, cur_len))
                scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores
                stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu)

            if stop_criterion_1 or stop_criterion_2:
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids.to("cpu"),
            beam_scores.to("cpu"),
            next_tokens.to("cpu"),
            next_indices.to("cpu"),
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        for k, v in sequence_outputs.items():
            if isinstance(v, torch.Tensor):
                sequence_outputs[k] = sequence_outputs[k].to(input_ids.device)

        return sequence_outputs["sequences"]


    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        seq_length: Optional[int] = None,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        """
            Overriding greedy sampling to use next tokens returned from neuron device instead of logits.
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        use_cache = model_kwargs["use_cache"] if "use_cache" in model_kwargs else False
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None


        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        while True:

            # prepare model inputs
            # input_ids is padded to max_length. Without a cache, pass the first
            # seq_length tokens; with a cache, pass only the token at index seq_length - 1.

            if model_kwargs.get("past_key_values") is None:
                input_ids_ = input_ids[:, :seq_length]
            else:
                update_indices = torch.stack(
                    [torch.arange(input_ids.size(0)), torch.tensor(seq_length - 1).repeat(input_ids.size(0))],
                    dim=-1,
                )
                input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]

            model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)

            # forward pass to get next token
            output = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            next_tokens = output[0]

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step

            batch_size, _ = input_ids.shape
            update_indices = torch.stack(
                [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1
            )
            input_ids[update_indices[:, 0], update_indices[:, 1]] = next_tokens[:]
            model_kwargs = self._update_model_kwargs_for_xla_generation(
                model_kwargs,
                batch_size=batch_size,
                is_encoder_decoder=self.config.is_encoder_decoder,
                max_length=stopping_criteria.max_length,
                seq_length=seq_length,
                use_cache=use_cache,
            )

            seq_length += 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

            # stop when each sentence is finished, or if we exceed the maximum length
            stop_criterion_1 = unfinished_sequences.max() == 0

            if isinstance(stopping_criteria, list):
                if len(stopping_criteria) == 1:
                    stopping_criteria = stopping_criteria[0]

            # Cases that can be handled in XLA without requiring
            # non-padded input_ids
            if isinstance(stopping_criteria, MaxLengthCriteria):
                stop_criterion_2 = seq_length >= stopping_criteria.max_length
            elif isinstance(stopping_criteria, MaxTimeCriteria):
                stop_criterion_2 = stopping_criteria(input_ids, scores)
            else:
                # Other cases will be handled on CPU
                batch_size, _ = input_ids.shape
                mask = torch.cat(
                    [torch.ones(batch_size, seq_length), torch.zeros(batch_size, input_ids.shape[1] - seq_length)],
                    dim=1,
                ).bool()
                input_ids_cpu = torch.masked_select(input_ids, mask).reshape((batch_size, seq_length)).to("cpu")
                scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores
                stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu)

            if stop_criterion_1 or stop_criterion_2:
                this_peer_finished = True

            if this_peer_finished:
                break

        if streamer is not None:
            streamer.end()

        return input_ids
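
The bookkeeping in `greedy_search` above can be illustrated without a model. The following is a minimal sketch, not the tutorial's implementation: plain Python lists stand in for tensors, and `greedy_step` and `fake_next_tokens` are hypothetical names introduced only for this illustration. It shows the three ideas the loop relies on: `input_ids` is a fixed-size buffer padded to `max_length`, each new token is written in place at position `seq_length`, and finished sequences receive `pad_token_id` through the same multiply-and-add masking.

```python
def greedy_step(input_ids, next_tokens, unfinished, seq_length, pad_token_id, eos_token_id):
    """Write one greedy token per sequence into a fixed-size, padded buffer.

    input_ids:   list of rows, each already padded to max_length
    next_tokens: one candidate token id per row (the device output in the tutorial)
    unfinished:  1 for rows still generating, 0 for rows that already emitted EOS
    """
    for row, (token, alive) in enumerate(zip(next_tokens, unfinished)):
        # Same arithmetic as the tensor version: finished rows get the pad token.
        token = token * alive + pad_token_id * (1 - alive)
        # In-place update keeps the buffer shape constant.
        input_ids[row][seq_length] = token
        # A row becomes finished once it emits EOS.
        unfinished[row] = alive * int(token != eos_token_id)
    return seq_length + 1


# Illustrative values: pad id 0, EOS id 1, and a very short buffer.
PAD, EOS, MAX_LENGTH = 0, 1, 6
input_ids = [[PAD] * MAX_LENGTH for _ in range(2)]  # position 0 holds the start token
unfinished = [1, 1]
seq_length = 1

# Hypothetical per-step model outputs; row 0 emits EOS at step 2, row 1 at step 3.
fake_next_tokens = [[11, 21], [EOS, 22], [12, EOS]]

for step_tokens in fake_next_tokens:
    seq_length = greedy_step(input_ids, step_tokens, unfinished, seq_length, PAD, EOS)
    if max(unfinished) == 0 or seq_length >= MAX_LENGTH:
        break

# Strip the padding before handing the ids to CPU-side stopping criteria.
input_ids_cpu = [row[:seq_length] for row in input_ids]
```

Note that row 0 ignores the token `12` proposed after it has finished: the mask replaces it with the pad id, so the buffer never changes shape and no per-row branching is needed.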

Now let’s test inference on CPU with all the wrappers before tracing.

[5]:
# Let's set some run parameters

model_name = "t5-large"
num_beams = 1
num_return_sequences = 1
max_length = 128
[6]:
from transformers import T5Tokenizer


prompt = "translate English to German: Lets eat good food."

tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_length)
model = T5Wrapper.from_pretrained(model_name)

model.encoder = EncoderWrapper(model.encoder, model.decoder, model.config, num_beams, max_length, "cpu", num_beams)
setattr(model.encoder, 'main_input_name', 'input_ids')  # Attribute required by beam search

model.decoder = DecoderWrapper(decoder=model.decoder,
                                lm_head=model.lm_head,
                                model_config=model.config,
                                num_beams=num_beams,
                                max_length=max_length,
                                device="cpu")

output = model.generate(tokenizer=tokenizer,
                        prompt=prompt,
                        max_length=max_length,
                        num_beams=num_beams,
                        num_return_sequences=num_return_sequences,
                        device="cpu")

results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

print('Results:')
for i, summary in enumerate(results):
    print(i + 1, summary)

Results:
1 Lassen Sie uns gutes Essen essen.

Now that the wrappers are running as expected, let’s trace the encoder and the decoder. To trace each module, we pass it and a sample input to torch_neuronx.trace(). The result of tracing is a static executable in which the operations to run at inference time are fixed during compilation. Consequently, the resulting Neuron model must be executed with tensors whose shapes exactly match those provided at compilation time. Passing a tensor with a different shape at inference time raises an error.
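
Because the compiled model only accepts the shapes it was traced with, every input has to be padded (or truncated) to the compile-time length before it reaches the device. In this tutorial the tokenizer does that with `padding='max_length'`. The sketch below shows the same idea on plain Python lists so that it runs without any dependencies. `pad_to_fixed_length` and `check_shape` are hypothetical helper names, not part of torch_neuronx or transformers, and the token ids are arbitrary.

```python
def pad_to_fixed_length(token_ids, max_length, pad_token_id=0):
    """Truncate or right-pad token ids to max_length and build the matching attention mask."""
    ids = list(token_ids[:max_length])
    mask = [1] * len(ids)
    padding = max_length - len(ids)
    return ids + [pad_token_id] * padding, mask + [0] * padding


def check_shape(batch, compiled_shape):
    """Raise early, on the CPU side, if a batch does not match the compile-time shape."""
    shape = (len(batch), len(batch[0]))
    if shape != compiled_shape:
        raise ValueError(f"expected shape {compiled_shape}, got {shape}")


# (batch_size, max_length) chosen at trace time; illustrative values only.
COMPILED_SHAPE = (1, 8)

ids, mask = pad_to_fixed_length([5, 6, 7, 8, 9], max_length=COMPILED_SHAPE[1])
check_shape([ids], COMPILED_SHAPE)  # passes: the padded input matches

try:
    check_shape([[5, 6, 7]], COMPILED_SHAPE)  # unpadded input has the wrong length
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

The attention mask marks which positions are real tokens, so the padding changes the tensor shape but not what the model attends to.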

The decoder wrapper returns the new state of the cache as an output, which would normally be copied back to the CPU. Because the cache is a large tensor, copying it to and from the XLA device on every decoder invocation would slow inference significantly. Instead, we can use input-output aliasing, a torch_neuronx feature that keeps these tensors on the device rather than copying them back to the CPU. To use input-output aliasing, we map the outputs to input parameters while tracing.
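
The mapping itself is index bookkeeping: each cache tensor held by the wrapper is paired with the position of the traced output that carries its updated value. The sketch below mirrors that bookkeeping with strings in place of tensors. `build_alias_map` is a hypothetical helper, and the two-layer key/value layout is only an illustrative assumption. As in this tutorial's tracing code, it assumes the cache outputs follow the regular outputs, with self-attention entries first and cross-attention entries after them.

```python
def build_alias_map(self_attn_cache, cross_attn_cache, num_regular_outputs):
    """Map each cache entry to the index of the traced output that updates it."""
    aliases = {}
    for i, name in enumerate(self_attn_cache):
        aliases[name] = num_regular_outputs + i
    for i, name in enumerate(cross_attn_cache):
        aliases[name] = num_regular_outputs + len(self_attn_cache) + i
    return aliases


# Strings stand in for the cache tensors (assumed layout: two layers, key and value each).
sa = ["sa_k0", "sa_v0", "sa_k1", "sa_v1"]
ca = ["ca_k0", "ca_v0", "ca_k1", "ca_v1"]

# The tutorial's decoder returns 1 regular output for greedy search and 3 for beam search.
greedy_aliases = build_alias_map(sa, ca, num_regular_outputs=1)
beam_aliases = build_alias_map(sa, ca, num_regular_outputs=3)
```

In the actual tracing call the dictionary keys are the cache tensors themselves, and the dictionary is passed as `input_output_aliases`.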

[ ]:
import torch
import torch_neuronx

from transformers import T5Tokenizer, T5ForConditionalGeneration

def trace_encoder(model: T5ForConditionalGeneration,
                  tokenizer: T5Tokenizer,
                  max_length: int,
                  num_beams: int):

    # Trace encoder
    batch_encoding = tokenizer("translate English to German: Lets go home now",
                               max_length=max_length, truncation=True, padding='max_length', return_tensors="pt")
    input_ids = batch_encoding['input_ids']
    attention_mask = batch_encoding['attention_mask']

    encoder = EncoderWrapper(model.encoder, model.decoder, model.config, num_beams, max_length, "xla", num_beams)
    traced_encoder = torch_neuronx.trace(encoder, (input_ids, attention_mask), compiler_workdir="/tmp/encoder/")
    setattr(traced_encoder, 'main_input_name', 'input_ids')  # Attribute required by beam search

    return traced_encoder

def trace_decoder(model: T5ForConditionalGeneration,
                  num_beams: int,
                  max_length: int):

    decoder = DecoderWrapper(decoder=model.decoder,
                             lm_head=model.lm_head,
                             model_config=model.config,
                             num_beams=num_beams,
                             max_length=max_length,
                             device="xla")

    # We create mock inputs so we can trace the decoder
    decoder_input_ids = torch.ones((num_beams, 1), dtype=torch.int64)
    decoder_attention_mask = torch.ones((num_beams, max_length), dtype=torch.int32)
    encoder_attention_mask = torch.ones((num_beams, max_length), dtype=torch.int64)
    encoder_hidden_states = torch.ones((num_beams, max_length, model.config.d_model), dtype=torch.float32)

    beam_idx = torch.arange(0, num_beams, dtype=torch.int64)
    beam_scores = torch.zeros((num_beams,), dtype=torch.float)

    num_outputs_from_trace = 3 if num_beams > 1 else 1

    aliases = {}
    for i in range(len(decoder.past_key_values_sa)):
        aliases[decoder.past_key_values_sa[i]] = i + num_outputs_from_trace
    for i in range(len(decoder.past_key_values_ca)):
        aliases[decoder.past_key_values_ca[i]] = len(decoder.past_key_values_sa) + i + num_outputs_from_trace

    traced_decoder = torch_neuronx.trace(decoder, (
        decoder_input_ids,
        decoder_attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        beam_idx,
        beam_scores,
    ), input_output_aliases=aliases, compiler_workdir="/tmp/decoder/")

    return traced_decoder


tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_length)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Enable this flag so that the model uses attention key-value caching
model.config.use_cache = True

traced_encoder = trace_encoder(model, tokenizer, max_length, num_beams)
traced_decoder = trace_decoder(model, num_beams, max_length)

torch.jit.save(traced_encoder, "TracedEncoder.pt")
torch.jit.save(traced_decoder, "TracedDecoder.pt")

Run inference with greedy decoding#

Now that we have the traced model, let’s use it for inference.

[8]:
runtime = torch.classes.neuron.Runtime()
runtime.initialize()
runtime.set_default_neuron_cores(0, 1)

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Wrapper.from_pretrained(model_name)

model.encoder = torch.jit.load("TracedEncoder.pt")
# Attribute required by beam search
setattr(model.encoder, 'main_input_name', 'input_ids')

model.decoder = torch.jit.load("TracedDecoder.pt")
torch_neuronx.move_trace_to_device(model.decoder, 0)


output = model.generate(tokenizer=tokenizer,
                        prompt="translate English to German: Lets eat good food.",
                        max_length=max_length,
                        num_beams=num_beams,
                        num_return_sequences=num_return_sequences,
                        device="xla")

results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

print('Results:')
for i, summary in enumerate(results):
    print(i + 1, summary)


Results:
1 Lassen Sie uns gutes Essen essen.