Transformers MarianMT Tutorial#
In this tutorial, you will deploy the HuggingFace MarianMT model for text translation.
This Jupyter notebook should be run on an inf1.6xlarge instance since you will be loading and compiling several large models.
Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the PyTorch Installation Guide. You can select the kernel from the “Kernel -> Change Kernel” option on the top of this Jupyter notebook page.
To generate text, you will be using the beam search algorithm to incrementally generate token candidates until the full output text has been created. Unlike simple single-pass models, this algorithm divides the work into two distinct phases:
Encoder: Convert the input text into an encoded representation. (Executed once)
Decoder: Use the encoded representation of the input text and the current output tokens to incrementally generate the set of next best candidate tokens. (Executed many times)
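The sketch below (illustrative only; encode and decode_step are hypothetical placeholders for the traced modules built later in this tutorial) shows this encode-once, decode-many structure using a simple greedy loop instead of beam search:
def generate_sketch(encode, decode_step, input_ids, start_id, eos_id, max_steps):
    # Encoder: executed once to produce the encoded representation
    encoded = encode(input_ids)
    tokens = [start_id]
    # Decoder: executed once per generated output token
    for _ in range(max_steps):
        next_token = decode_step(encoded, tokens)
        tokens.append(next_token)
        if next_token == eos_id:
            break
    return tokens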
In this tutorial you will perform the following steps:
Compile: Compile both the Encoder and Decoder for Neuron using simplified interfaces for inference.
Infer: Run on CPU and Neuron and compare results.
Install Dependencies:#
This tutorial has the following dependencies:
transformers==4.0.1
torch-neuron==1.7.*
sentencepiece
neuron-cc[tensorflow]
The following will install the required transformers version. Note that encoder/decoder API changes across minor versions require that you be specific about the version used. Also note that the torch-neuron version is pinned due to transformers compatibility issues.
[ ]:
!pip install --force-reinstall --extra-index-url=https://pip.repos.neuron.amazonaws.com "torch-neuron==1.7.*" "transformers==4.0.1" "protobuf<4" sentencepiece "neuron-cc[tensorflow]"
Parameters#
The parameters of a generative model can be tuned for different use-cases. In this example, you’ll tailor the parameters to a single inference beam search for an on-demand inference use-case. See the MarianConfig for parameter details.
Rather than varying the encoder/decoder token sizes at runtime, you must define these parameters prior to compilation. The encoder/decoder token sizes are important tunable parameters as a large token sequence will offer greater sentence length flexibility but perform worse than a small token sequence.
To maximize performance on Neuron, num_beams, max_encoder_length, and max_decoder_length should be made as small as possible for the use-case.
For this tutorial you will use a model that translates sentences of up to 32 tokens from English to German.
[ ]:
model_name = "Helsinki-NLP/opus-mt-en-de" # English -> German model
num_texts = 1 # Number of input texts to decode
num_beams = 4 # Number of beams per input text
max_encoder_length = 32 # Maximum input token length
max_decoder_length = 32 # Maximum output token length
Imports#
For text generation tasks, HuggingFace Transformers defines a GenerationMixin base class which provides standard methods and algorithms to generate text. For this tutorial, you will be using the beam search algorithm on encoder/decoder architectures.
To use these methods, you will define your own class derived from GenerationMixin to run a beam search. This invokes the encoder and decoder layers in a way that is compatible with fixed-sized inputs and traced modules, which means you must import the base class and the output objects (Seq2SeqLMOutput, BaseModelOutput) used by the beam_search algorithm.
[ ]:
import os
import torch
import numpy as np
from torch.nn import functional as F
from transformers import MarianMTModel, MarianTokenizer, MarianConfig
from transformers.generation_utils import GenerationMixin
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
import torch_neuron
CPU Model Execution#
Start by executing the model on CPU to test its execution.
The following defines the inference function which will be used to compare the Neuron and CPU output. In this example you will display all beam search sequences that were generated. In a real on-demand use case, set num_return_sequences to 1 to return only the top result.
[ ]:
def infer(model, tokenizer, text):
    # Truncate and pad to the maximum encoder length so the token size is compatible with the fixed-sized encoder (not necessary for pure CPU execution)
    batch = tokenizer(text, max_length=max_encoder_length, truncation=True, padding='max_length', return_tensors="pt")
output = model.generate(**batch, max_length=max_decoder_length, num_beams=num_beams, num_return_sequences=num_beams)
results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]
print('Texts:')
for i, summary in enumerate(results):
print(i + 1, summary)
[ ]:
model_cpu = MarianMTModel.from_pretrained(model_name)
model_cpu.eval()
tokenizer_cpu = MarianTokenizer.from_pretrained(model_name)
sample_text = "I am a small frog."
[ ]:
infer(model_cpu, tokenizer_cpu, sample_text)
Encoder & Decoder Modules#
Here you will define wrappers around the encoder and decoder portions of the generation model that are compatible with torch.jit.trace as well as fixed-sized inputs.
Important features which are distinct from the default configuration:
- A fixed-sized causal_mask in the NeuronDecoder rather than one that varies in size for each iteration. This is because Neuron requires padded input_ids rather than the default behavior where the input grows for each beam search iteration.
- Disabled return_dict. When this is enabled, the network uses dataclass type outputs which are not compatible with torch.jit.trace.
- Disabled use_cache. When this option is enabled, the network expects a collection of cache tensors which grow upon each iteration. Since Neuron requires fixed-sized inputs, this must be disabled.
- The GenerationMixin:beam_search implementation uses only the logits for the current iteration index from the original decoder layer output. Since inputs are padded, performance can be improved by selecting only a subset of the hidden state prior to the final linear layer. For efficiency on Neuron, this reduction (reduce) uses an elementwise multiply to mask out the unused hidden values and then sums along an axis.
- Since a reduction step is inserted between the decoder output and the final logit calculation, the original model attribute is not used. Instead, the NeuronDecoder class combines the decoder, reducer, and linear layers into a single forward pass. In the original model there is a clear distinction between the decoder layer and the final linear layer; these layers are fused together to obtain one large fully optimized graph.
[ ]:
def reduce(hidden, index):
_, n_length, _ = hidden.shape
# Create selection mask
mask = torch.arange(n_length, dtype=torch.float32) == index
mask = mask.view(1, -1, 1)
# Broadcast mask
masked = torch.multiply(hidden, mask)
    # Sum along the sequence dimension (dim 1)
summed = torch.sum(masked, 1)
return torch.unsqueeze(summed, 1)
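# Sanity check (illustrative, not executed): for a hidden tensor of shape
# (batch, length, features), reduce(hidden, i) selects the slice at sequence
# position i. For example:
#   example = torch.arange(24, dtype=torch.float32).view(1, 4, 6)
#   assert torch.equal(reduce(example, 2), example[:, 2:3, :])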
class NeuronEncoder(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.encoder = model.model.encoder
def forward(self, input_ids, attention_mask):
return self.encoder(input_ids, attention_mask=attention_mask, return_dict=False)
class NeuronDecoder(torch.nn.Module):
def __init__(self, model, max_length):
super().__init__()
self.weight = model.model.shared.weight.clone().detach()
self.bias = model.final_logits_bias.clone().detach()
self.decoder = model.model.decoder
self.max_length = max_length
def forward(self, input_ids, attention_mask, encoder_outputs, index):
# Build a fixed sized causal mask for the padded decoder input ids
mask = np.triu(np.ones((self.max_length, self.max_length)), 1)
mask[mask == 1] = -np.inf
causal_mask = torch.tensor(mask, dtype=torch.float)
# Invoke the decoder
hidden, = self.decoder(
input_ids=input_ids,
encoder_hidden_states=encoder_outputs,
encoder_padding_mask=attention_mask,
decoder_padding_mask=None,
decoder_causal_mask=causal_mask,
return_dict=False,
use_cache=False,
)
# Reduce decoder outputs to the specified index (current iteration)
hidden = reduce(hidden, index)
# Compute final linear layer for token probabilities
logits = F.linear(
hidden,
self.weight,
bias=self.bias
)
return logits
GenerationMixin Class#
To be able to use GenerationMixin:beam_search, you must define your own class implementation that invokes the traced NeuronEncoder and NeuronDecoder modules. The standard generator model implementation will not work by default because it is not designed to invoke the traced models with padded inputs.
Below, the NeuronGeneration:trace method uses the loaded generator model and traces both the Encoder and Decoder.
Next, the following methods are copied directly from the original class to ensure that inference behavior is identical:
- adjust_logits_during_generation
- _force_token_id_to_be_generated
To invoke the Encoder and Decoder traced modules in a way that is compatible with the GenerationMixin:beam_search implementation, the get_encoder, __call__, and prepare_inputs_for_generation methods are overridden.
Lastly, the class defines methods for serialization so that the model can be easily saved and loaded.
[ ]:
class NeuronGeneration(PreTrainedModel, GenerationMixin):
def trace(self, model, num_texts, num_beams, max_encoder_length, max_decoder_length):
"""
Traces the encoder and decoder modules for use on Neuron.
This function fixes the network to the given sizes. Once the model has been
compiled to a given size, the inputs to these networks must always be of
fixed size.
Args:
model (GenerationMixin): The transformer-type generator model to trace
num_texts (int): The number of input texts to translate at once
            num_beams (int): The number of beams to compute per text
            max_encoder_length (int): The maximum number of encoder tokens
            max_decoder_length (int): The maximum number of decoder tokens
"""
self.config.max_decoder_length = max_decoder_length
# Trace the encoder
inputs = (
torch.ones((num_texts, max_encoder_length), dtype=torch.long),
torch.ones((num_texts, max_encoder_length), dtype=torch.long),
)
encoder = NeuronEncoder(model)
self.encoder = torch_neuron.trace(encoder, inputs)
# Trace the decoder (with expanded inputs)
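        # Beam search expands each input text into num_beams candidate sequences,
        # so the decoder is traced with a batch size of num_texts * num_beams.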
batch_size = num_texts * num_beams
inputs = (
torch.ones((batch_size, max_decoder_length), dtype=torch.long),
torch.ones((batch_size, max_encoder_length), dtype=torch.long),
torch.ones((batch_size, max_encoder_length, model.config.d_model), dtype=torch.float),
torch.tensor(0),
)
decoder = NeuronDecoder(model, max_decoder_length)
self.decoder = torch_neuron.trace(decoder, inputs)
# ------------------------------------------------------------------------
# Beam Search Methods (Copied directly from transformers)
# ------------------------------------------------------------------------
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
@staticmethod
def _force_token_id_to_be_generated(scores, token_id) -> None:
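        # Set every logit except token_id to -inf so that only token_id can be
        # selected at this generation step.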
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
# ------------------------------------------------------------------------
# Encoder/Decoder Invocation
# ------------------------------------------------------------------------
def prepare_inputs_for_generation(
self,
decoder_input_ids,
encoder_outputs=None,
attention_mask=None,
**model_kwargs
):
# Pad the inputs for Neuron
current_length = decoder_input_ids.shape[1]
pad_size = self.config.max_decoder_length - current_length
return dict(
input_ids=F.pad(decoder_input_ids, (0, pad_size)),
attention_mask=attention_mask,
encoder_outputs=encoder_outputs.last_hidden_state,
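            # Index of the most recently generated token; the decoder output is
            # reduced to this position before computing the next-token logits.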
current_length=torch.tensor(current_length - 1),
)
def get_encoder(self):
"""Helper to invoke the encoder and wrap the results in the expected structure"""
def encode(input_ids, attention_mask, **kwargs):
output, = self.encoder(input_ids, attention_mask)
return BaseModelOutput(
last_hidden_state=output,
)
return encode
def __call__(self, input_ids, attention_mask, encoder_outputs, current_length, **kwargs):
"""Helper to invoke the decoder and wrap the results in the expected structure"""
logits = self.decoder(input_ids, attention_mask, encoder_outputs, current_length)
return Seq2SeqLMOutput(logits=logits)
# ------------------------------------------------------------------------
# Serialization
# ------------------------------------------------------------------------
def save_pretrained(self, directory):
if os.path.isfile(directory):
print(f"Provided path ({directory}) should be a directory, not a file")
return
os.makedirs(directory, exist_ok=True)
torch.jit.save(self.encoder, os.path.join(directory, 'encoder.pt'))
torch.jit.save(self.decoder, os.path.join(directory, 'decoder.pt'))
self.config.save_pretrained(directory)
@classmethod
def from_pretrained(cls, directory):
config = MarianConfig.from_pretrained(directory)
obj = cls(config)
obj.encoder = torch.jit.load(os.path.join(directory, 'encoder.pt'))
obj.decoder = torch.jit.load(os.path.join(directory, 'decoder.pt'))
return obj
@property
def device(self):
return torch.device('cpu')
Execution#
Putting everything together, the process to deploy the model is as follows:
Compile the model
Serialize an artifact
Load the serialized artifact
Execute the model on Neuron
[ ]:
# This is the name of the folder where the artifacts will be stored on disk
neuron_name = 'NeuronMarianMT'
[ ]:
model_neuron = NeuronGeneration(model_cpu.config)
# 1. Compile the model
# Note: This may take a couple of minutes since both the encoder/decoder will be compiled
model_neuron.trace(
model=model_cpu,
num_texts=num_texts,
num_beams=num_beams,
max_encoder_length=max_encoder_length,
max_decoder_length=max_decoder_length,
)
# 2. Serialize an artifact
# After this call you will have an `encoder.pt`, `decoder.pt` and `config.json` in the neuron_name folder
model_neuron.save_pretrained(neuron_name)
[ ]:
# 3. Load the serialized artifact
model_neuron = NeuronGeneration.from_pretrained(neuron_name)
[ ]:
# 4. Execute the model on Neuron
infer(model_neuron, tokenizer_cpu, sample_text)
Comparing the Neuron execution to the original CPU implementation, you will see the exact same generated text.
[ ]:
# CPU execution for comparison
infer(model_cpu, tokenizer_cpu, sample_text)
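If you prefer a programmatic check over visual inspection, a sketch like the following (which reuses the tokenizer settings from the infer helper above) compares the decoded beams from both models directly:
[ ]:
# Illustrative comparison: decode the beams from both models and check for equality
batch = tokenizer_cpu(sample_text, max_length=max_encoder_length, truncation=True, padding='max_length', return_tensors="pt")
cpu_output = model_cpu.generate(**batch, max_length=max_decoder_length, num_beams=num_beams, num_return_sequences=num_beams)
neuron_output = model_neuron.generate(**batch, max_length=max_decoder_length, num_beams=num_beams, num_return_sequences=num_beams)
cpu_texts = [tokenizer_cpu.decode(t, skip_special_tokens=True) for t in cpu_output]
neuron_texts = [tokenizer_cpu.decode(t, skip_special_tokens=True) for t in neuron_output]
print("Outputs match:", cpu_texts == neuron_texts)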
Appendix - BART (Mask Filling Task)#
The NeuronGeneration class can be applied to the BART model for the task of filling in mask tokens.
[ ]:
from transformers import BartForConditionalGeneration, BartTokenizer
bart_name = "facebook/bart-large"
bart_model = BartForConditionalGeneration.from_pretrained(bart_name, force_bos_token_to_be_generated=True)
bart_tokenizer = BartTokenizer.from_pretrained(bart_name)
bart_text = "UN Chief Says There Is No <mask> in Syria"
[ ]:
# CPU Execution
infer(bart_model, bart_tokenizer, bart_text)
[ ]:
# Neuron Execution
bart_neuron = NeuronGeneration(bart_model.config)
bart_neuron.trace(
model=bart_model,
num_texts=num_texts,
num_beams=num_beams,
max_encoder_length=max_encoder_length,
max_decoder_length=max_decoder_length,
)
infer(bart_neuron, bart_tokenizer, bart_text)
Appendix - Pegasus (Summarization Task)#
The NeuronGeneration class can be applied to the Pegasus model for summarization.
[ ]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
pegasus_name = 'google/pegasus-xsum'
pegasus_model = PegasusForConditionalGeneration.from_pretrained(pegasus_name)
pegasus_tokenizer = PegasusTokenizer.from_pretrained(pegasus_name)
pegasus_text = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires."
[ ]:
# CPU Execution
infer(pegasus_model, pegasus_tokenizer, pegasus_text)
[ ]:
# Neuron Execution
pegasus_neuron = NeuronGeneration(pegasus_model.config)
pegasus_neuron.trace(
model=pegasus_model,
num_texts=num_texts,
num_beams=num_beams,
max_encoder_length=max_encoder_length,
max_decoder_length=max_decoder_length,
)
infer(pegasus_neuron, pegasus_tokenizer, pegasus_text)