{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# T5 model inference on Trn1 or Inf2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "In this tutorial we will compile and deploy a pretrained T5 model for accelerated inference on Neuron. \n", "\n", "This tutorial will use the [t5-large](https://huggingface.co/t5-large) model. The T5 model can be used for machine translation, document summarization, question answering, and classification tasks. \n", "\n", "This tutorial has the following main sections:\n", "\n", "1. Install dependencies\n", "1. Compile the T5 model\n", "1. Run inference with greedy decoding on Neuron\n", "1. Run inference with beam search on Neuron\n", "\n", "This Jupyter notebook should be run on a Trn1 instance (`trn1.2xlarge` or larger) or an Inf2 instance (`inf2.xlarge` or larger)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies\n", "\n", "The code in this tutorial is written for Jupyter Notebooks. To set up Jupyter Notebook on the Neuron instance, follow\n", "this [guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/notebook/setup-jupyter-notebook-steps-troubleshooting.html).\n", "\n", "This tutorial requires the following pip packages:\n", "\n", "- `torch-neuronx`\n", "- `neuronx-cc`\n", "- `transformers`\n", "- `optimum-neuron`\n", "\n", "Most of these packages are installed when you configure your environment using the Trn1/Inf2 setup guide. Install the remaining dependencies here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install --upgrade transformers==4.31.0 optimum-neuron==0.0.8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS accelerators, including AWS Trainium and AWS Inferentia. 
It provides a set of tools for easy model loading, training, and inference in single- and multi-accelerator settings for different downstream tasks. In this tutorial, we use the `generate()` method from 🤗 Optimum Neuron instead of 🤗 [Transformers' `generate()`](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate) to run generation. Optimum Neuron takes care of padding the inputs, which is necessary because the models compiled for Neuron expect fixed input shapes.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compile the model into an AWS Neuron optimized TorchScript\n", "\n", "In the following section, we load the T5 model, compile the model's encoder and decoder for Neuron using `torch_neuronx.trace()`, and save the optimized encoder and decoder as `TorchScript`. \n", "\n", "Before we trace the model, we need to make a couple of changes. \n", "\n", "1. We need to write encoder and decoder wrappers: `torch_neuronx` can only trace functions with positional arguments, but the T5 encoder and decoder both use keyword arguments. To trace them, we write wrappers that accept positional arguments and pass them on to the encoder and decoder as keyword arguments. \n", "2. We modify the T5 code to maximize the computation on the Neuron device: code sections running on the CPU reduce performance, and we do not want to move data between the Neuron device and the CPU during inference. The code we trace with `torch_neuronx` is the code that runs on the Neuron device, so we refactor the T5 code to run computationally heavy operations within the wrappers. \n", "\n", "Let us start with the EncoderWrapper. \n", "\n", "In the HuggingFace T5 implementation, the encoder block takes in the input IDs and returns the encoder hidden states. These hidden states are then used to initialize the KV cache in the decoder blocks during the first decoder invocation. We could trace the encoder and the cache initialization step separately. 
But there is a better way: we can compute the initial KV cache state within the encoder wrapper. This removes the overhead of moving the hidden states from the Neuron device to the CPU and back, and it allows the Neuron compiler to optimize execution across both the encoder and the cache initialization. \n", "\n", "*Why don't we just initialize the cache on the first decoder run?* \n", "\n", "This is harder to do on Neuron. Like `torch.jit.trace()`, `torch_neuronx.trace()` produces a function with a fixed control flow, i.e. there is no conditional execution. So we cannot conditionally initialize the cache in the first decoder iteration. Instead, we compute the initial cache state outside the generation loop and pass the cache to it. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from transformers.models.t5.modeling_t5 import T5Stack, T5LayerCrossAttention\n", "\n", "class EncoderWrapper(torch.nn.Module):\n", " '''\n", " We will trace an instance of the EncoderWrapper. \n", " This wrapper converts positional args to kwargs and computes the initial KV cache. \n", " '''\n", "\n", " def __init__(self, \n", " encoder,\n", " decoder, \n", " model_config, \n", " batch_size, \n", " max_length, \n", " device, \n", " num_beams,\n", " tp_degree=None):\n", " \n", " super().__init__()\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " self.batch_size = batch_size\n", " self.max_length = max_length\n", " self.model_config = model_config\n", " self.device = device\n", " self.num_beams = num_beams\n", " self.num_attention_heads_per_partition = model_config.num_heads\n", " self.tp_degree = tp_degree\n", "\n", " def forward(self, input_ids, attention_mask):\n", " '''\n", " This is the core functionality we want to trace. 
\n", " '''\n", " encoder_output = self.encoder(input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " output_attentions=False,\n", " output_hidden_states=False)\n", "\n", " last_hidden_state = encoder_output[\"last_hidden_state\"]\n", " encoder_hidden_states = torch.concat([tensor.unsqueeze(0).repeat(self.num_beams, 1, 1) for tensor in last_hidden_state])\n", "\n", " decoder_blocks = self.decoder.block\n", " present_key_value_states_sa = []\n", " present_key_value_states_ca = []\n", "\n", " for i, block in enumerate(decoder_blocks):\n", "\n", " # Cross attention has to be initialized with the encoder hidden state\n", " cross_attention: T5LayerCrossAttention = block.layer[1]\n", " attention = cross_attention.EncDecAttention\n", "\n", " def shape(states):\n", " \"\"\"projection\"\"\"\n", " return states.view(self.batch_size, -1, self.num_attention_heads_per_partition, attention.key_value_proj_dim).transpose(1, 2)\n", "\n", " key_states = shape(attention.k(encoder_hidden_states))\n", " value_states = shape(attention.v(encoder_hidden_states))\n", "\n", " # cross_attn_kv_state\n", " present_key_value_states_ca.append(key_states) \n", " present_key_value_states_ca.append(value_states) \n", " \n", " # Self attention kv states are initialized to zeros. This is done to keep the size of the kv cache tensor constant. \n", " # The kv cache will be an input to the decoder trace. Any traced function will have a fixed control flow. What this means \n", " # is that the trace performs the exact same computations on inputs of the same shape in each invocation. So the attention \n", " # kv cache is padded here to keep a fixed shape. 
\n", " present_key_value_states_sa.append(torch.zeros((self.batch_size, # key states\n", " self.model_config.num_heads, \n", " self.max_length-1, \n", " self.model_config.d_kv), dtype=torch.float32, device=self.device)) \n", " present_key_value_states_sa.append(torch.zeros((self.batch_size, # value states\n", " self.model_config.num_heads, \n", " self.max_length-1, \n", " self.model_config.d_kv), dtype=torch.float32, device=self.device))\n", "\n", " return present_key_value_states_sa + present_key_value_states_ca\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "In the decoder wrapper, in addition to handling the argument conversion, we add support for attention caching. Generating text with an encoder-decoder model is an autoregressive process: without a cache, every decoding step recomputes the key and value states of the attention heads for all previous tokens. To improve performance, we cache the key and value states. HuggingFace Transformers refers to this cache as `past_key_values`.\n", "\n", "In HuggingFace Transformers, `past_key_values` is updated outside the decoder. This works for training and evaluation, but for inference we want to perform the update within a single trace. This way, the compiler can optimize across both the decoder execution and the cache update. So, we move the cache update into the decoder wrapper."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class DecoderWrapper(torch.nn.Module):\n", "\n", " def __init__(self, \n", " decoder: T5Stack, \n", " lm_head: torch.nn.Linear,\n", " model_config,\n", " num_beams: int, \n", " max_length: int,\n", " device: str,\n", " tp_degree=None):\n", " super().__init__() \n", " self.decoder = decoder\n", " self.lm_head = lm_head\n", " self.model_dim=model_config.d_model\n", " self.device = device\n", " self.num_beams = num_beams\n", " self.batch_size = 1\n", " self.config = model_config\n", " \n", " num_heads=model_config.num_heads\n", " num_decoder_layers=model_config.num_decoder_layers\n", "\n", " self.num_attention_heads_per_partition = num_heads\n", "\n", " # (num_beams, n_heads, seq_length, dim_per_head)\n", " if device == \"cpu\":\n", " self.past_key_values_sa = [torch.ones((num_beams,num_heads,max_length-1,model_config.d_kv), dtype=torch.float32) for _ in range(num_decoder_layers * 2)]\n", " self.past_key_values_ca = [torch.ones((num_beams,num_heads,max_length,model_config.d_kv), dtype=torch.float32) for _ in range(num_decoder_layers * 2)]\n", " elif device == \"xla\":\n", " self.past_key_values_sa = torch.nn.ParameterList([torch.nn.Parameter(torch.ones((num_beams,self.num_attention_heads_per_partition,max_length-1,model_config.d_kv), dtype=torch.float32), requires_grad=False) for _ in range(num_decoder_layers * 2)])\n", " self.past_key_values_ca = torch.nn.ParameterList([torch.nn.Parameter(torch.ones((num_beams,self.num_attention_heads_per_partition,max_length,model_config.d_kv), dtype=torch.float32), requires_grad=False) for _ in range(num_decoder_layers * 2)])\n", "\n", " def update_past(self, past_key_values):\n", " new_past_sa = []\n", " new_past_ca = []\n", " for past_layer in past_key_values:\n", " new_past_layer = list(past_layer)\n", " for i in range(len(new_past_layer[:2])):\n", " new_past_layer[i] = past_layer[i][:, :, 1:]\n", " new_past_sa += 
[new_past_layer[:2],]\n", " new_past_ca += [new_past_layer[2:],]\n", " return new_past_sa, new_past_ca\n", " \n", " def reorder_cache(self, past_key_values, beam_idx):\n", " for i in range(len(past_key_values)):\n", " gather_index = beam_idx.view([beam_idx.shape[0],1,1,1]).expand_as(past_key_values[i])\n", " past_key_values[i] = torch.gather(past_key_values[i], dim = 0, index=gather_index)\n", " return past_key_values\n", "\n", " def forward(self,\n", " input_ids,\n", " decoder_attention_mask,\n", " encoder_hidden_states,\n", " encoder_attention_mask,\n", " beam_idx,\n", " beam_scores,\n", " **kwargs):\n", "\n", " if self.num_beams > 1:\n", " # We reorder the cache based on the beams selected in each iteration. Required step for beam search.\n", " past_key_values_sa = self.reorder_cache(self.past_key_values_sa, beam_idx)\n", " past_key_values_ca = self.reorder_cache(self.past_key_values_ca, beam_idx)\n", " else:\n", " # We do not need to reorder for greedy sampling\n", " past_key_values_sa = self.past_key_values_sa\n", " past_key_values_ca = self.past_key_values_ca\n", "\n", " # The cache is stored in a flatten form. We order the cache per layer before passing it to the decoder. \n", " # Each layer has 4 tensors, so we group by 4. 
\n", " past_key_values = [[*past_key_values_sa[i*2:i*2+2], *past_key_values_ca[i*2:i*2+2]] for i in range(0, int(len(past_key_values_ca)/2))]\n", "\n", " decoder_output = self.decoder(\n", " input_ids=input_ids,\n", " attention_mask=decoder_attention_mask,\n", " past_key_values=past_key_values,\n", " encoder_hidden_states=encoder_hidden_states,\n", " encoder_attention_mask=encoder_attention_mask,\n", " use_cache=True,\n", " output_attentions=False,\n", " output_hidden_states=False)\n", "\n", " last_hidden_state = decoder_output['last_hidden_state']\n", " past_key_values = decoder_output['past_key_values']\n", "\n", " if self.config.tie_word_embeddings:\n", " # Rescale output before projecting on vocab\n", " # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586\n", " last_hidden_state = last_hidden_state * (self.model_dim**-0.5)\n", " \n", " lm_logits = self.lm_head(last_hidden_state)\n", "\n", " past_key_values_sa, past_key_values_ca = self.update_past(past_key_values)\n", "\n", " # We flatten the cache to a single array. This is required for the input output aliasing to work\n", " past_key_values_sa = [vec for kv_per_layer in past_key_values_sa for vec in kv_per_layer]\n", " past_key_values_ca = [vec for kv_per_layer in past_key_values_ca for vec in kv_per_layer]\n", "\n", " if self.device == \"cpu\":\n", " self.past_key_values_sa = past_key_values_sa\n", " self.past_key_values_ca = past_key_values_ca\n", "\n", " # We calculate topk inside the wrapper\n", " next_token_logits = lm_logits[:, -1, :]\n", "\n", " if self.num_beams > 1:\n", " # This section of beam search is run outside the decoder in the huggingface t5 implementation. 
\n", " # To maximize the computation on the Neuron device, we move this into the wrapper\n", " logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)\n", " logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))\n", " next_token_scores = next_token_logits - logit_max - logsumexp\n", " next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)\n", "\n", " # reshape for beam search\n", " vocab_size = next_token_scores.shape[-1]\n", " next_token_scores = next_token_scores.view(self.batch_size, self.num_beams * vocab_size)\n", " next_token_scores = next_token_scores * 1\n", "\n", " # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)\n", " next_token_scores, next_tokens = torch.topk(\n", " next_token_scores, 2 * self.num_beams, dim=1, largest=True, sorted=True\n", " ) \n", "\n", " next_indices = torch.div(next_tokens, vocab_size, rounding_mode=\"floor\")\n", " next_tokens = next_tokens % vocab_size\n", "\n", " return [next_token_scores, next_tokens, next_indices] + past_key_values_sa + past_key_values_ca\n", " else:\n", " # Greedy \n", " next_tokens = torch.argmax(next_token_logits, dim=-1)\n", " return [next_tokens] + past_key_values_sa + past_key_values_ca\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's create a T5 model wrapper to make it compatible with our traced encoder and decoder. \n", "\n", "There are three reasons for having this wrapper:\n", "\n", "1. The encoder and decoder traces can only be invoked with positional arguments, but the HuggingFace Transformers code is written with keyword arguments. So we override the functions that invoke the encoder and decoder to call them with positional arguments. \n", "1. The `generate()` function in the `NeuronGenerationMixin` performs the cache update on the CPU. Because we handle the cache within the `DecoderWrapper`, we disable the cache update on the CPU. \n", "1. 
The top-k computation that determines the next tokens for beam search was moved into the decoder wrapper. So we need to override HuggingFace's beam search implementation to accept the next tokens and the beam scores from the decoder. \n", "\n", "Let's also override the `generate()` function so that it initializes the cache using the cache initializer (the encoder wrapper) before decoding starts." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch_xla.core.xla_model as xm\n", "\n", "from transformers import T5Tokenizer, T5ForConditionalGeneration\n", "from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput\n", "from transformers.models.t5.modeling_t5 import T5Stack, T5LayerCrossAttention\n", "from transformers.generation.utils import ModelOutput\n", "from typing import Any, Dict, List, Optional, Tuple, Union\n", "from transformers.generation.beam_search import BeamScorer, BeamSearchScorer\n", "\n", "from optimum.neuron.generation import NeuronGenerationMixin\n", "\n", "from transformers.generation.logits_process import (\n", " LogitsProcessorList,\n", ")\n", "from transformers.generation.stopping_criteria import (\n", " MaxLengthCriteria,\n", " MaxTimeCriteria,\n", " StoppingCriteriaList,\n", " validate_stopping_criteria,\n", ")\n", "\n", "from transformers.generation.utils import (\n", " BeamSearchOutput,\n", " GreedySearchOutput,\n", ")\n", "\n", "class T5Wrapper(T5ForConditionalGeneration, NeuronGenerationMixin):\n", "\n", " def _prepare_encoder_decoder_kwargs_for_generation(\n", " self, \n", " inputs_tensor: torch.Tensor, \n", " model_kwargs, \n", " model_input_name: Optional[str] = None\n", " ) -> Dict[str, Any]:\n", " encoder = self.get_encoder()\n", " model_kwargs[\"encoder_outputs\"]: ModelOutput = encoder(inputs_tensor, model_kwargs[\"attention_mask\"])\n", " return model_kwargs\n", "\n", " # Override to cut the input_ids to just the last token\n", " def 
prepare_inputs_for_generation(\n", " self,\n", " input_ids,\n", " past_key_values=None,\n", " attention_mask=None,\n", " head_mask=None,\n", " decoder_head_mask=None,\n", " decoder_attention_mask=None,\n", " cross_attn_head_mask=None,\n", " use_cache=None,\n", " encoder_outputs=None,\n", " **kwargs,\n", " ):\n", " # cut decoder_input_ids as past is cached\n", " input_ids = input_ids[:, -1:]\n", "\n", " return {\n", " \"decoder_input_ids\": input_ids,\n", " \"past_key_values\": past_key_values,\n", " \"encoder_outputs\": encoder_outputs,\n", " \"attention_mask\": attention_mask,\n", " \"head_mask\": head_mask,\n", " \"decoder_head_mask\": decoder_head_mask,\n", " \"decoder_attention_mask\": decoder_attention_mask,\n", " \"cross_attn_head_mask\": cross_attn_head_mask,\n", " \"use_cache\": use_cache,\n", " }\n", " \n", " '''\n", " We update the cache in the decoder trace, so lets override the _update_model_kwargs_for_xla_generation in NeuronGenerationMixin\n", " '''\n", " def _update_model_kwargs_for_xla_generation(\n", " self,\n", " model_kwargs: Dict[str, Any],\n", " batch_size: int,\n", " is_encoder_decoder: bool = False,\n", " standardize_cache_format: bool = False,\n", " max_length: Optional[int] = None,\n", " seq_length: Optional[int] = None,\n", " use_cache: bool = True,\n", " ) -> Dict[str, Any]:\n", "\n", " def _update_attention(model_kwargs, is_encoder_decoder):\n", " \"\"\"Updates the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`\"\"\"\n", "\n", " attention_mask_name = \"decoder_attention_mask\" if is_encoder_decoder else \"attention_mask\"\n", " attention_mask = model_kwargs.pop(attention_mask_name)\n", " attention_mask_update_slice = torch.ones(\n", " (batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device\n", " )\n", " attention_mask = torch.cat([attention_mask[:, 1:], attention_mask_update_slice], dim=-1)\n", " mask = {attention_mask_name: attention_mask}\n", " return mask\n", "\n", " mask = 
_update_attention(model_kwargs, is_encoder_decoder)\n", " # sets the updated variables (mask and past_key_values)\n", " model_kwargs.update(mask)\n", "\n", " # Set a mock cache tensor\n", " model_kwargs[\"past_key_values\"] = torch.tensor([])\n", "\n", " return model_kwargs\n", " \n", " def _reorder_cache(self, past_key_values, beam_idx):\n", " '''\n", " This is needed for beam search and not greedy sampling\n", " We reorder the cache within the trace so we can skip it in modeling_t5.py. So we override the _reorder_cache\n", " '''\n", " self.beam_idx = beam_idx\n", " return past_key_values\n", "\n", " def generate(self,\n", " tokenizer: T5Tokenizer,\n", " prompt: str,\n", " max_length: int,\n", " num_beams: int,\n", " num_return_sequences: int,\n", " device: str):\n", "\n", " batch_encoding = tokenizer(prompt, max_length=max_length, truncation=True, padding='max_length',\n", " return_tensors=\"pt\")\n", "\n", " past_key_values = self.encoder(batch_encoding['input_ids'],batch_encoding['attention_mask'])\n", " \n", " decoder_attention_mask = torch.cat([torch.zeros((1, max_length-1), dtype=torch.int32),\n", " torch.ones((1, 1), dtype=torch.int32)], axis=1)\n", "\n", " # copy the new cache state to the decoder\n", " if device == \"xla\":\n", " for state, tensor in zip(self.decoder.parameters(), past_key_values):\n", " state.copy_(tensor)\n", " else:\n", " # First half of the cache is self attention and the rest is cross attention\n", " self.decoder.past_key_values_sa = past_key_values[:len(past_key_values)//2]\n", " self.decoder.past_key_values_ca = past_key_values[len(past_key_values)//2:]\n", " \n", " output = super().generate(**batch_encoding,\n", " max_length=max_length,\n", " num_beams=num_beams,\n", " num_return_sequences=num_return_sequences,\n", " do_sample=False,\n", " use_cache=True,\n", " decoder_attention_mask=decoder_attention_mask, \n", " encoder_outputs={\"last_hidden_state\": torch.ones((1,128,1))}) # Pass fake encoder_outputs so the transformers code 
will not invoke the encoder\n", " return output\n", "\n", " def forward(\n", " self,\n", " attention_mask: Optional[torch.FloatTensor] = None,\n", " decoder_input_ids: Optional[torch.LongTensor] = None,\n", " decoder_attention_mask: Optional[torch.BoolTensor] = None,\n", " encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n", " beam_scores = None,\n", " **kwargs\n", " ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:\n", "\n", " hidden_states = encoder_outputs[\"last_hidden_state\"]\n", "\n", " if not hasattr(self, 'beam_idx'):\n", " # Infering the number of beams from the attention mask\n", " num_beams = attention_mask.shape[0]\n", " self.beam_idx = torch.arange(0, num_beams, dtype=torch.int64)\n", "\n", " decoder_outputs = self.decoder(\n", " decoder_input_ids,\n", " decoder_attention_mask,\n", " hidden_states,\n", " attention_mask,\n", " self.beam_idx,\n", " beam_scores\n", " )\n", "\n", " # lm_logits = decoder_outputs[0]\n", " next_token_scores = decoder_outputs[0]\n", " next_tokens = decoder_outputs[1]\n", " next_indices = decoder_outputs[2]\n", "\n", " return next_token_scores, next_tokens, next_indices\n", "\n", " def beam_search(\n", " self,\n", " input_ids: torch.LongTensor,\n", " beam_scorer: BeamScorer,\n", " logits_processor: Optional[LogitsProcessorList] = None,\n", " stopping_criteria: Optional[StoppingCriteriaList] = None,\n", " max_length: Optional[int] = None,\n", " pad_token_id: Optional[int] = None,\n", " eos_token_id: Optional[Union[int, List[int]]] = None,\n", " output_attentions: Optional[bool] = None,\n", " output_hidden_states: Optional[bool] = None,\n", " output_scores: Optional[bool] = None,\n", " return_dict_in_generate: Optional[bool] = None,\n", " synced_gpus: Optional[bool] = False,\n", " seq_length: Optional[int] = None,\n", " **model_kwargs,\n", " ) -> Union[BeamSearchOutput, torch.LongTensor]:\n", "\n", " logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()\n", " 
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()\n", " pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id\n", " eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id\n", " if isinstance(eos_token_id, int):\n", " eos_token_id = [eos_token_id]\n", " output_scores = output_scores if output_scores is not None else self.generation_config.output_scores\n", " output_attentions = (\n", " output_attentions if output_attentions is not None else self.generation_config.output_attentions\n", " )\n", " output_hidden_states = (\n", " output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states\n", " )\n", "\n", " batch_size = len(beam_scorer._beam_hyps)\n", " num_beams = beam_scorer.num_beams\n", "\n", " batch_beam_size, cur_len = input_ids.shape\n", "\n", " # Overwrite cur_len\n", " cur_len = seq_length\n", "\n", " if num_beams * batch_size != batch_beam_size:\n", " raise ValueError(\n", " f\"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}.\"\n", " )\n", "\n", " # init attention / hidden states / scores tuples\n", " scores = () if (return_dict_in_generate and output_scores) else None\n", " beam_indices = (\n", " tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None\n", " )\n", "\n", " # initialise score of first beam with 0 and the rest with -1e9. 
This makes sure that only tokens\n", " # of the first beam are considered to avoid sampling the exact same tokens across all beams.\n", " # beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)\n", " beam_scores_device = \"cpu\"\n", " beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=beam_scores_device)\n", " beam_scores[:, 1:] = -1e9\n", " beam_scores = beam_scores.view((batch_size * num_beams,))\n", "\n", " while True:\n", " # prepare model inputs\n", " # From max_length-sized input_ids, select first\n", " # cur_len - 1 values.\n", " update_indices = torch.stack(\n", " [torch.arange(input_ids.size(0)), torch.tensor(cur_len - 1).repeat(input_ids.size(0))], dim=-1\n", " )\n", " input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]\n", " model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)\n", "\n", " next_token_scores, next_tokens, next_indices = self(\n", " **model_inputs,\n", " return_dict=True,\n", " output_attentions=output_attentions,\n", " output_hidden_states=output_hidden_states,\n", " beam_scores=beam_scores\n", " )\n", "\n", " # stateless\n", " beam_outputs = beam_scorer.process(\n", " input_ids.to(\"cpu\")[:, :cur_len],\n", " next_token_scores.to(\"cpu\"),\n", " next_tokens.to(\"cpu\"),\n", " next_indices.to(\"cpu\"),\n", " pad_token_id=pad_token_id,\n", " eos_token_id=eos_token_id,\n", " beam_indices=beam_indices,\n", " )\n", "\n", " beam_scores = beam_outputs[\"next_beam_scores\"]\n", " beam_next_tokens = beam_outputs[\"next_beam_tokens\"]\n", " beam_idx = beam_outputs[\"next_beam_indices\"]\n", "\n", " update_indices = torch.stack(\n", " [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1\n", " )\n", " update_indices_2 = torch.stack(\n", " [torch.arange(batch_beam_size), torch.tensor(cur_len).repeat(batch_beam_size)], dim=-1\n", " )\n", " # First select beam_indices\n", " device = input_ids.device\n", " 
beam_idx_device = beam_idx.to(device=input_ids.device)\n", " input_ids[:, :] = input_ids[beam_idx_device.long(), :]\n", "\n", " # Then append new tokens\n", " input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = beam_next_tokens.unsqueeze(-1).to(device).to(torch.long)\n", " input_ids = input_ids * 1 # Hack to materialize tensor\n", "\n", " # update generated ids, model inputs, and length for next step\n", " model_kwargs = self._update_model_kwargs_for_xla_generation(\n", " model_kwargs,\n", " batch_size=batch_beam_size,\n", " is_encoder_decoder=self.config.is_encoder_decoder,\n", " max_length=stopping_criteria.max_length,\n", " seq_length=cur_len,\n", " use_cache=model_kwargs[\"use_cache\"],\n", " )\n", " if model_kwargs[\"past_key_values\"] is not None:\n", " model_kwargs[\"past_key_values\"] = self._reorder_cache(model_kwargs[\"past_key_values\"], beam_idx.to(torch.int64))\n", "\n", " if return_dict_in_generate and output_scores:\n", " beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))\n", "\n", " # increase cur_len\n", " cur_len = cur_len + 1\n", "\n", " # stop when each sentence is finished, or if we exceed the maximum length\n", " stop_criterion_1 = beam_scorer.is_done\n", " if isinstance(stopping_criteria, list):\n", " if len(stopping_criteria) == 1:\n", " stopping_criteria = stopping_criteria[0]\n", "\n", " # Cases that can be handled in XLA without requiring\n", " # non-padded input_ids\n", " if isinstance(stopping_criteria, MaxLengthCriteria):\n", " stop_criterion_2 = cur_len >= stopping_criteria.max_length\n", " elif isinstance(stopping_criteria, MaxTimeCriteria):\n", " stop_criterion_2 = stopping_criteria(input_ids, scores)\n", " else:\n", " # Other cases will be handled on CPU\n", " batch_size, _ = input_ids.shape\n", " input_ids_cpu = input_ids.to(\"cpu\")\n", " mask = torch.cat(\n", " [torch.ones(batch_size, cur_len), torch.zeros(batch_size, input_ids.shape[1] - cur_len)], dim=1\n", " 
).bool()\n", " input_ids_cpu = torch.masked_select(input_ids_cpu, mask).reshape((batch_size, cur_len))\n", " scores_cpu = scores.to(\"cpu\") if torch.is_tensor(scores) else scores\n", " stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu)\n", "\n", " if stop_criterion_1 or stop_criterion_2:\n", " if not synced_gpus:\n", " break\n", " else:\n", " this_peer_finished = True\n", "\n", " sequence_outputs = beam_scorer.finalize(\n", " input_ids.to(\"cpu\"),\n", " beam_scores.to(\"cpu\"),\n", " next_tokens.to(\"cpu\"),\n", " next_indices.to(\"cpu\"),\n", " pad_token_id=pad_token_id,\n", " eos_token_id=eos_token_id,\n", " max_length=stopping_criteria.max_length,\n", " beam_indices=beam_indices,\n", " )\n", "\n", " for k, v in sequence_outputs.items():\n", " if type(v) == torch.Tensor:\n", " sequence_outputs[k] = sequence_outputs[k].to(input_ids.device)\n", "\n", " return sequence_outputs[\"sequences\"]\n", "\n", "\n", " def greedy_search(\n", " self,\n", " input_ids: torch.LongTensor,\n", " logits_processor: Optional[LogitsProcessorList] = None,\n", " stopping_criteria: Optional[StoppingCriteriaList] = None,\n", " max_length: Optional[int] = None,\n", " pad_token_id: Optional[int] = None,\n", " eos_token_id: Optional[Union[int, List[int]]] = None,\n", " output_attentions: Optional[bool] = None,\n", " output_hidden_states: Optional[bool] = None,\n", " output_scores: Optional[bool] = None,\n", " return_dict_in_generate: Optional[bool] = None,\n", " seq_length: Optional[int] = int,\n", " streamer: Optional[\"BaseStreamer\"] = None,\n", " **model_kwargs,\n", " ) -> Union[GreedySearchOutput, torch.LongTensor]:\n", " \"\"\"\n", " Overriding greedy sampling to use next tokens returned from neuron device instead of logits.\n", " \"\"\"\n", " # init values\n", " logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()\n", " use_cache = model_kwargs[\"use_cache\"] if \"use_cache\" in model_kwargs else False\n", " stopping_criteria = 
stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()\n", " pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id\n", " eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id\n", " if isinstance(eos_token_id, int):\n", " eos_token_id = [eos_token_id]\n", " eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None\n", " output_scores = output_scores if output_scores is not None else self.generation_config.output_scores\n", " output_attentions = (\n", " output_attentions if output_attentions is not None else self.generation_config.output_attentions\n", " )\n", " output_hidden_states = (\n", " output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states\n", " )\n", "\n", " # init attention / hidden states / scores tuples\n", " scores = () if (return_dict_in_generate and output_scores) else None\n", " decoder_attentions = () if (return_dict_in_generate and output_attentions) else None\n", " cross_attentions = () if (return_dict_in_generate and output_attentions) else None\n", " decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None\n", "\n", "\n", " # keep track of which sequences are already finished\n", " unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)\n", "\n", " this_peer_finished = False # used by synced_gpus only\n", " while True:\n", "\n", " # prepare model inputs\n", " # From max_length-sized input_ids, select first\n", " # seq_length - 1 values.\n", "\n", " if model_kwargs.get(\"past_key_values\") is None:\n", " input_ids_ = input_ids[:, :seq_length]\n", " else:\n", " update_indices = torch.stack(\n", " [torch.arange(input_ids.size(0)), torch.tensor(seq_length - 1).repeat(input_ids.size(0))],\n", " dim=-1,\n", " )\n", " input_ids_ = input_ids[update_indices[:, 0], 
update_indices[:, 1], None]\n", "\n", " model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)\n", " \n", " # forward pass to get next token\n", " output = self(\n", " **model_inputs,\n", " return_dict=True,\n", " output_attentions=output_attentions,\n", " output_hidden_states=output_hidden_states,\n", " )\n", " next_tokens = output[0]\n", "\n", " # finished sentences should have their next token be a padding token\n", " if eos_token_id is not None:\n", " if pad_token_id is None:\n", " raise ValueError(\"If `eos_token_id` is defined, make sure that `pad_token_id` is defined.\")\n", " next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)\n", "\n", " # update generated ids, model inputs, and length for next step\n", "\n", " batch_size, _ = input_ids.shape\n", " update_indices = torch.stack(\n", " [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1\n", " )\n", " input_ids[update_indices[:, 0], update_indices[:, 1]] = next_tokens[:]\n", " model_kwargs = self._update_model_kwargs_for_xla_generation(\n", " model_kwargs,\n", " batch_size=batch_size,\n", " is_encoder_decoder=self.config.is_encoder_decoder,\n", " max_length=stopping_criteria.max_length,\n", " seq_length=seq_length,\n", " use_cache=use_cache,\n", " )\n", "\n", " seq_length += 1\n", "\n", " # if eos_token was found in one sentence, set sentence to finished\n", " if eos_token_id_tensor is not None:\n", " unfinished_sequences = unfinished_sequences.mul(\n", " next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)\n", " )\n", "\n", " # stop when each sentence is finished, or if we exceed the maximum length\n", " stop_criterion_1 = unfinished_sequences.max() == 0\n", "\n", " if isinstance(stopping_criteria, list):\n", " if len(stopping_criteria) == 1:\n", " stopping_criteria = stopping_criteria[0]\n", "\n", " # Cases that can be handled in XLA without requiring\n", " # non-padded 
input_ids\n", " if isinstance(stopping_criteria, MaxLengthCriteria):\n", " stop_criterion_2 = seq_length >= stopping_criteria.max_length\n", " elif isinstance(stopping_criteria, MaxTimeCriteria):\n", " stop_criterion_2 = stopping_criteria(input_ids, scores)\n", " else:\n", " # Other cases will be handled on CPU\n", " batch_size, _ = input_ids.shape\n", " mask = torch.cat(\n", " [torch.ones(batch_size, seq_length), torch.zeros(batch_size, input_ids.shape[1] - seq_length)],\n", " dim=1,\n", " ).bool()\n", " input_ids_cpu = torch.masked_select(input_ids, mask).reshape((batch_size, seq_length)).to(\"cpu\")\n", " scores_cpu = scores.to(\"cpu\") if torch.is_tensor(scores) else scores\n", " stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu)\n", "\n", " if stop_criterion_1 or stop_criterion_2:\n", " this_peer_finished = True\n", "\n", " if this_peer_finished:\n", " break\n", "\n", " if streamer is not None:\n", " streamer.end()\n", "\n", " return input_ids\n", " \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's test inference on CPU with all the wrappers before tracing." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Let's set some run parameters\n", "\n", "model_name = \"t5-large\"\n", "num_beams = 1\n", "num_return_sequences = 1\n", "max_length = 128" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Results:\n", "1 Lassen Sie uns gutes Essen essen.\n" ] } ], "source": [ "from transformers import T5Tokenizer\n", "\n", "\n", "prompt=\"translate English to German: Lets eat good food.\"\n", " \n", "tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_length)\n", "model = T5Wrapper.from_pretrained(model_name)\n", "\n", "model.encoder = EncoderWrapper(model.encoder, model.decoder, model.config, num_beams, max_length, \"cpu\", num_beams)\n", "setattr(model.encoder, 'main_input_name', 'input_ids') # Attribute required by beam search\n", "\n", "model.decoder = DecoderWrapper(decoder=model.decoder,\n", " lm_head=model.lm_head,\n", " model_config=model.config,\n", " num_beams=num_beams,\n", " max_length=max_length,\n", " device=\"cpu\")\n", "\n", "output = model.generate(tokenizer=tokenizer,\n", " prompt=prompt,\n", " max_length=max_length,\n", " num_beams=num_beams,\n", " num_return_sequences=num_return_sequences,\n", " device=\"cpu\")\n", "\n", "results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]\n", "\n", "print('Results:')\n", "for i, summary in enumerate(results):\n", " print(i + 1, summary)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that the wrappers run as expected, let's trace the encoder and decoder. To trace a module, we pass it and a sample input to `torch_neuronx.trace()`. The result of tracing is a static executable: the operations to run at inference time are fixed during compilation. 
This means that at inference time the resulting Neuron model must be called with tensors of exactly the same shapes as those provided during compilation. A tensor whose shape differs from the one used at compilation time causes an error.\n", "\n", "The decoder wrapper returns the updated cache as an output, and by default that output is copied back to the CPU. Because the cache is a large tensor, copying it between the CPU and the XLA device on every decoder invocation would significantly slow down inference. Instead, we use input-output aliasing, a feature of `torch_neuronx`, to keep these tensors on the device. To enable it, we pass `torch_neuronx.trace()` an `input_output_aliases` dictionary. This dictionary maps each cache tensor to the index of the decoder output that holds its updated value. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch_neuronx\n", "\n", "from transformers import T5Tokenizer, T5ForConditionalGeneration\n", "\n", "def trace_encoder(model: T5ForConditionalGeneration,\n", " tokenizer: T5Tokenizer,\n", " max_length: int,\n", " num_beams: int):\n", " \n", " # Trace encoder\n", " batch_encoding = tokenizer(\"translate English to German: Lets go home now\",\n", " max_length=max_length, truncation=True, padding='max_length', return_tensors=\"pt\")\n", " input_ids = batch_encoding['input_ids']\n", " attention_mask = batch_encoding['attention_mask']\n", "\n", " encoder = EncoderWrapper(model.encoder, model.decoder, model.config, num_beams, max_length, \"xla\", num_beams)\n", " traced_encoder = torch_neuronx.trace(encoder, (input_ids, attention_mask), compiler_workdir=\"/tmp/encoder/\")\n", " setattr(traced_encoder, 'main_input_name', 'input_ids') # Attribute required by beam search\n", "\n", " return traced_encoder\n", "\n", "def trace_decoder(model: T5ForConditionalGeneration,\n", " num_beams: int,\n", " max_length: int):\n", "\n", " decoder = DecoderWrapper(decoder=model.decoder,\n", " 
lm_head=model.lm_head,\n", " model_config=model.config,\n", " num_beams=num_beams,\n", " max_length=max_length,\n", " device=\"xla\")\n", "\n", " # We create mock inputs so we can trace the decoder\n", " decoder_input_ids = torch.ones((num_beams, 1), dtype=torch.int64)\n", " decoder_attention_mask = torch.ones((num_beams, max_length), dtype=torch.int32)\n", " encoder_attention_mask = torch.ones((num_beams, max_length), dtype=torch.int64)\n", " encoder_hidden_states = torch.ones((num_beams, max_length, model.config.d_model), dtype=torch.float32)\n", "\n", " beam_idx = torch.arange(0, num_beams, dtype=torch.int64)\n", " beam_scores = torch.zeros((num_beams,), dtype=torch.float)\n", "\n", " # Number of non-cache outputs the decoder returns before the cache tensors\n", " num_outputs_from_trace = 3 if num_beams > 1 else 1\n", "\n", " # Map each self-attention (sa) and cross-attention (ca) cache tensor\n", " # to the index of the decoder output that holds its updated value\n", " aliases = {}\n", " for i in range(len(decoder.past_key_values_sa)):\n", " aliases[decoder.past_key_values_sa[i]] = i + num_outputs_from_trace\n", " for i in range(len(decoder.past_key_values_ca)):\n", " aliases[decoder.past_key_values_ca[i]] = len(decoder.past_key_values_sa) + i + num_outputs_from_trace\n", "\n", " traced_decoder = torch_neuronx.trace(decoder, (\n", " decoder_input_ids,\n", " decoder_attention_mask,\n", " encoder_hidden_states,\n", " encoder_attention_mask,\n", " beam_idx,\n", " beam_scores,\n", " ), input_output_aliases=aliases, compiler_workdir=\"/tmp/decoder/\")\n", "\n", " return traced_decoder\n", "\n", "\n", "tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_length)\n", "model = T5ForConditionalGeneration.from_pretrained(model_name)\n", "\n", "# Enable key-value caching in the attention layers\n", "model.config.use_cache = True\n", "\n", "traced_encoder = trace_encoder(model, tokenizer, max_length, num_beams)\n", "traced_decoder = trace_decoder(model, num_beams, max_length)\n", "\n", "torch.jit.save(traced_encoder, \"TracedEncoder.pt\")\n", "torch.jit.save(traced_decoder, \"TracedDecoder.pt\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"## Run inference with greedy decoding\n", "Now that we have the traced model, let's use it for inference. " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Results:\n", "1 Lassen Sie uns gutes Essen essen.\n" ] } ], "source": [ "runtime = torch.classes.neuron.Runtime()\n", "runtime.initialize()\n", "runtime.set_default_neuron_cores(0, 1)\n", "\n", "tokenizer = T5Tokenizer.from_pretrained(model_name)\n", "model = T5Wrapper.from_pretrained(model_name)\n", "\n", "model.encoder = torch.jit.load(\"TracedEncoder.pt\")\n", "# Attribute required by beam search\n", "setattr(model.encoder, 'main_input_name', 'input_ids') \n", "\n", "model.decoder = torch.jit.load(\"TracedDecoder.pt\")\n", "torch_neuronx.move_trace_to_device(model.decoder, 0)\n", "\n", "\n", "output = model.generate(tokenizer=tokenizer,\n", " prompt=\"translate English to German: Lets eat good food.\",\n", " max_length=max_length,\n", " num_beams=num_beams,\n", " num_return_sequences=num_return_sequences,\n", " device=\"xla\")\n", "\n", "results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]\n", "\n", "print('Results:')\n", "for i, summary in enumerate(results):\n", " print(i + 1, summary)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run inference with beam search" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's set some run parameters for beam search\n", "\n", "model_name = \"t5-large\"\n", "num_beams = 4\n", "num_return_sequences = 4\n", "max_length = 128\n", "\n", "tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=max_length)\n", "model = T5ForConditionalGeneration.from_pretrained(model_name)\n", "model.config.use_cache = True\n", "\n", "traced_encoder = trace_encoder(model, tokenizer, max_length, num_beams)\n", "traced_decoder = trace_decoder(model, num_beams, max_length)\n", "\n", 
"torch.jit.save(traced_encoder, \"TracedEncoder.pt\")\n", "torch.jit.save(traced_decoder, \"TracedDecoder.pt\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Results:\n", "1 Lassen Sie uns gutes Essen essen.\n", "2 Lassen Sie uns gutes Essen zu essen.\n", "3 Lassen Sie uns essen gutes Essen.\n", "4 Lassen Sie uns gutes Essen.\n" ] } ], "source": [ "tokenizer = T5Tokenizer.from_pretrained(model_name)\n", "model = T5Wrapper.from_pretrained(model_name)\n", "\n", "model.encoder = torch.jit.load(\"TracedEncoder.pt\")\n", "# Attribute required by beam search\n", "setattr(model.encoder, 'main_input_name', 'input_ids') \n", "\n", "model.decoder = torch.jit.load(\"TracedDecoder.pt\")\n", "torch_neuronx.move_trace_to_device(model.decoder, 0)\n", "\n", "\n", "output = model.generate(tokenizer=tokenizer,\n", " prompt=\"translate English to German: Lets eat good food.\",\n", " max_length=max_length,\n", " num_beams=num_beams,\n", " num_return_sequences=num_return_sequences,\n", " device=\"xla\")\n", "\n", "results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]\n", "\n", "print('Results:')\n", "for i, summary in enumerate(results):\n", " print(i + 1, summary)" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }