Transformers MarianMT Tutorial#

In this tutorial, you will deploy the HuggingFace MarianMT model for text translation.

This Jupyter notebook should be run on an inf1.6xlarge instance since you will be loading and compiling several large models.

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the PyTorch Installation Guide. You can select the kernel from the “Kernel -> Change Kernel” option on the top of this Jupyter notebook page.

To generate text, you will be using the beam search algorithm to incrementally generate token candidates until the full output text has been created. Unlike simple single-pass models, this algorithm divides the work into two distinct phases:

  • Encoder: Convert the input text into an encoded representation. (Executed once)

  • Decoder: Use the encoded representation of the input text and the current output tokens to incrementally generate the set of next best candidate tokens. (Executed many times)

In this tutorial you will perform the following steps:

  • Compile: Compile both the Encoder and Decoder for Neuron using simplified interfaces for inference.

  • Infer: Run on CPU and Neuron and compare results.

Finally, a completely unrolled decoder will be built which simplifies the implementation at the cost of performing fixed-length inferences.

Install Dependencies#

This tutorial has the following dependencies:

  • transformers==4.26.1

  • torch-neuron

  • sentencepiece

  • neuron-cc[tensorflow]

The following cell installs the required transformers version. Note that encoder/decoder API changes across minor versions require that you pin the specific version used. Also note that the torch-neuron version is pinned due to transformers compatibility issues.

[ ]:
!pip install sentencepiece transformers==4.26.1
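
Optionally, you can confirm the versions present in your environment before continuing (this assumes torch-neuron and neuron-cc were already installed according to the PyTorch Installation Guide):

[ ]:
!pip list | grep -E "transformers|torch-neuron|neuron-cc|sentencepiece"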

Parameters#

The parameters of a generative model can be tuned for different use-cases. In this example, you’ll tailor the parameters to a beam search over a single input text, as in an on-demand inference use-case. See the MarianConfig for parameter details.

Rather than varying the encoder/decoder token sizes at runtime, you must define these parameters prior to compilation. The encoder/decoder token sizes are important tunable parameters as a large token sequence will offer greater sentence length flexibility but perform worse than a small token sequence.

To maximize performance on Neuron, the num_beams, max_encoder_length and max_decoder_length should be made as small as possible for the use-case.

For this tutorial you will use a model that translates sentences of up to 32 tokens from English to German.

[ ]:
%env TOKENIZERS_PARALLELISM=True #Suppresses tokenizer warnings, making errors easier to detect
model_name = "Helsinki-NLP/opus-mt-en-de" # English -> German model
num_texts = 1                             # Number of input texts to decode
num_beams = 4                             # Number of beams per input text
max_encoder_length = 32                   # Maximum input token length
max_decoder_length = 32                   # Maximum output token length

CPU Model Inference#

Start by executing the model on CPU to test its execution.

The following defines the inference function which will be used to compare the Neuron and CPU output. In this example you will display all beam search sequences that were generated. For a real on-demand use case, set num_beams to 1 to return only the top result (a sketch of this variant appears after the following cell).

[ ]:
def infer(model, tokenizer, text):

    # Truncate and pad to the maximum length to ensure that the token size is compatible with the fixed-sized encoder (not necessary for pure CPU execution)
    batch = tokenizer(text, max_length=max_encoder_length, truncation=True, padding='max_length', return_tensors="pt")
    output = model.generate(**batch, max_length=max_decoder_length, num_beams=num_beams, num_return_sequences=num_beams)
    results = [tokenizer.decode(t, skip_special_tokens=True) for t in output]

    print('Texts:')
    for i, summary in enumerate(results):
        print(i + 1, summary)
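
For reference, a single-result variant for an on-demand use case could look like the sketch below (the infer_top1 name is illustrative and not used elsewhere in this tutorial). Note that for the Neuron-compiled models built later, the num_beams used at generation time must match the value the decoder was compiled with, so this variant assumes the model is also compiled with num_beams set to 1.

[ ]:
def infer_top1(model, tokenizer, text):
    # Single-beam, single-result inference for a latency-sensitive, on-demand use-case
    batch = tokenizer(text, max_length=max_encoder_length, truncation=True, padding='max_length', return_tensors="pt")
    output = model.generate(**batch, max_length=max_decoder_length, num_beams=1, num_return_sequences=1)
    return tokenizer.decode(output[0], skip_special_tokens=True)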

Note that after loading the model, we also set the maximum length. This will later be used to limit the size of the compiled model.

[ ]:
from transformers import MarianMTModel, MarianTokenizer

model_cpu = MarianMTModel.from_pretrained(model_name)
model_cpu.config.max_length = max_decoder_length
model_cpu.eval()

tokenizer = MarianTokenizer.from_pretrained(model_name)

sample_text = "I am a small frog."
[ ]:
infer(model_cpu, tokenizer, sample_text)

Padded Model#

In order to perform inference on Neuron, the model must be changed so that it supports tracing and fixed-sized inputs. One way to do this is to pad the model inputs to the maximum possible tensor sizes. The benefit of a padded model is that it supports variable-length text generation up to the specified max_decoder_length. A consequence of padding is that it can negatively impact performance due to large data transfers.
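
To make the padding concrete, the sketch below right-pads a partially generated token sequence to a fixed length of 32 using torch.nn.functional.pad, which is the same mechanism used later by the PaddedGenerator class (the token ids are arbitrary example values):

[ ]:
import torch
from torch.nn import functional as F

tokens = torch.tensor([[812, 21, 503]])            # 3 tokens generated so far (arbitrary example ids)
padded = F.pad(tokens, (0, 32 - tokens.shape[1]))  # right-pad with zeros up to max_decoder_length
index = tokens.shape[1] - 1                        # position whose logits are actually needed
print(padded, index)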

PaddedEncoder & PaddedDecoder Modules#

Here you will define wrappers around the encoder and decoder portions of the generation model that are compatible with torch.jit.trace as well as fixed-sized inputs.

The following are important features which are distinct from the default configuration:

  1. Disabled return_dict. When this is enabled, the network uses dataclass type outputs which are not compatible with torch.jit.trace.

  2. Disabled use_cache. When this option is enabled, the network expects a collection of cache tensors which grow upon each iteration. Since Neuron requires fixed sized inputs, this must be disabled.

  3. The GenerationMixin:beam_search implementation uses only the logits for the current iteration index from the original decoder layer output. Since inputs must be padded, performance can be improved by selecting only a subset of the hidden state prior to the final linear layer. For efficiency on Neuron, this reduction uses an elementwise-multiply to mask out the unused hidden values and then sums along an axis.

  4. Since a reduction step is inserted between the decoder output and the final logit calculation, the original model attribute is not used. Instead, the PaddedDecoder class combines the decoder, reducer, and linear layers into a single forward pass. In the original model there is a clear distinction between the decoder layer and the final linear layer; here these layers are fused together to produce one large, fully optimized graph. (A small numeric check of the mask-and-sum reduction from item 3 appears after this list.)
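
Before looking at the implementation, the following small check illustrates the reduction from item 3: masking the hidden states elementwise and summing over the sequence axis selects the same vector as directly indexing the current position (the shapes below are illustrative):

[ ]:
import torch

hidden = torch.randn(1, 32, 512)                      # [batch, sequence length, hidden size] (illustrative)
index = torch.tensor(5)                               # current decoder position
mask = (torch.arange(32, dtype=torch.float32) == index).view(1, -1, 1)
reduced = torch.sum(hidden * mask, 1, keepdim=True)   # mask-and-sum reduction
print(torch.allclose(reduced, hidden[:, index, :].unsqueeze(1)))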

[ ]:
import torch
from torch.nn import functional as F


class PaddedEncoder(torch.nn.Module):

    def __init__(self, model):
        super().__init__()
        self.encoder = model.model.encoder
        self.main_input_name = 'input_ids'

    def forward(self, input_ids, attention_mask):
        return self.encoder(input_ids, attention_mask=attention_mask, return_dict=False)


class PaddedDecoder(torch.nn.Module):

    def __init__(self, model):
        super().__init__()
        self.weight = model.model.shared.weight.clone().detach()
        self.bias = model.final_logits_bias.clone().detach()
        self.decoder = model.model.decoder

    def forward(self, input_ids, attention_mask, encoder_outputs, index):

        # Invoke the decoder
        hidden, = self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_outputs,
            encoder_attention_mask=attention_mask,
            return_dict=False,
            use_cache=False,
        )

        _, n_length, _ = hidden.shape

        # Create selection mask
        mask = torch.arange(n_length, dtype=torch.float32) == index
        mask = mask.view(1, -1, 1)

        # Broadcast mask
        masked = torch.multiply(hidden, mask)

        # Reduce along 1st dimension
        hidden = torch.sum(masked, 1, keepdims=True)

        # Compute final linear layer for token probabilities
        logits = F.linear(
            hidden,
            self.weight,
            bias=self.bias
        )
        return logits

PaddedGenerator - GenerationMixin Class#

For text generation tasks, HuggingFace Transformers defines a GenerationMixin base class which provides standard methods and algorithms to generate text. For this tutorial, you will be using the beam search algorithm on an encoder/decoder architecture.

To be able to use these methods, you will be defining your own class derived from the GenerationMixin class to run a beam search. This will invoke the encoder and decoder layers in a way that is compatible with fixed sized inputs and traced modules. This means you must import the base class and the output objects (Seq2SeqLMOutput, BaseModelOutput) used by the beam_search algorithm.

The GenerationMixin:generate method will use GenerationMixin:beam_search, which requires that you define your own class implementation that invokes the PaddedEncoder and PaddedDecoder modules using padded inputs. The standard generator model implementation will not work by default because it is intended to infer with variable-sized (growing) input tensors.

The from_model method is defined to create the PaddedGenerator from an existing pretrained generator class.

To invoke the Encoder and Decoder traced modules in a way that is compatible with the GenerationMixin:beam_search implementation, the get_encoder, __call__, and prepare_inputs_for_generation methods are overridden.

Lastly, the class defines methods for serialization so that the model can be easily saved and loaded.

[ ]:
import os

from transformers import GenerationMixin, AutoConfig
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.modeling_utils import PreTrainedModel


class PaddedGenerator(PreTrainedModel, GenerationMixin):

    @classmethod
    def from_model(cls, model):
        generator = cls(model.config)
        generator.encoder = PaddedEncoder(model)
        generator.decoder = PaddedDecoder(model)
        return generator

    def prepare_inputs_for_generation(
            self,
            input_ids,
            encoder_outputs=None,
            attention_mask=None,
            **kwargs,
    ):
        # Pad the inputs for Neuron
        current_length = input_ids.shape[1]
        pad_size = self.config.max_length - current_length
        return dict(
            input_ids=F.pad(input_ids, (0, pad_size)),
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs.last_hidden_state,
            current_length=torch.tensor(current_length - 1),
        )

    def get_encoder(self):
        def encode(input_ids, attention_mask, **kwargs):
            output, = self.encoder(input_ids, attention_mask)
            return BaseModelOutput(
                last_hidden_state=output,
            )
        return encode

    def forward(self, input_ids, attention_mask, encoder_outputs, current_length, **kwargs):
        logits = self.decoder(input_ids, attention_mask, encoder_outputs, current_length)
        return Seq2SeqLMOutput(logits=logits)

    @property
    def device(self):  # Attribute required by beam search
        return torch.device('cpu')

    def save_pretrained(self, directory):
        if os.path.isfile(directory):
            print(f"Provided path ({directory}) should be a directory, not a file")
            return
        os.makedirs(directory, exist_ok=True)
        torch.jit.save(self.encoder, os.path.join(directory, 'encoder.pt'))
        torch.jit.save(self.decoder, os.path.join(directory, 'decoder.pt'))
        self.config.save_pretrained(directory)

    @classmethod
    def from_pretrained(cls, directory):
        config = AutoConfig.from_pretrained(directory)
        obj = cls(config)
        obj.encoder = torch.jit.load(os.path.join(directory, 'encoder.pt'))
        obj.decoder = torch.jit.load(os.path.join(directory, 'decoder.pt'))
        setattr(obj.encoder, 'main_input_name', 'input_ids')  # Attribute required by beam search
        return obj

Padded CPU Inference#

To start, it is important to ensure that the transformations made to the model were successful. Using the classes defined above, we can test that the padded model executed on CPU produces output identical to that of the original model, also running on CPU.

[ ]:
padded_model_cpu = PaddedGenerator.from_model(model_cpu)
infer(padded_model_cpu, tokenizer, sample_text)

Padded Neuron Tracing & Inference#

Now that the padded version of the model is confirmed to produce the same outputs as the non-padded version, the model can be compiled for Neuron.

[ ]:
import torch
import torch_neuron


def trace(model, num_texts, num_beams, max_decoder_length, max_encoder_length):
    """
    Traces the encoder and decoder modules for use on Neuron.

    This function fixes the network to the given sizes. Once the model has been
    compiled to a given size, the inputs to these networks must always be of
    fixed size.

    Args:
        model (PaddedGenerator): The padded generator to compile for Neuron
        num_texts (int): The number of input texts to translate at once
        num_beams (int): The number of beams to compute per text
        max_decoder_length (int): The maximum number of tokens to be generated
        max_encoder_length (int): The maximum number of input tokens that will be encoded
    """

    # Trace the encoder
    inputs = (
        torch.ones((num_texts, max_encoder_length), dtype=torch.long),
        torch.ones((num_texts, max_encoder_length), dtype=torch.long),
    )
    encoder = torch_neuron.trace(model.encoder, inputs)

    # Trace the decoder (with expanded inputs)
    batch_size = num_texts * num_beams
    inputs = (
        torch.ones((batch_size, max_decoder_length), dtype=torch.long),
        torch.ones((batch_size, max_encoder_length), dtype=torch.long),
        torch.ones((batch_size, max_encoder_length, model.config.d_model), dtype=torch.float),
        torch.tensor(0),
    )
    decoder = torch_neuron.trace(model.decoder, inputs)

    traced = PaddedGenerator(model.config)
    traced.encoder = encoder
    traced.decoder = decoder
    setattr(encoder, 'main_input_name', 'input_ids')  # Attribute required by beam search
    return traced
[ ]:
padded_model_neuron = trace(padded_model_cpu, num_texts, num_beams, max_decoder_length, max_encoder_length)

Comparing the Neuron execution to the original CPU implementation, you will see the exact same generated text.

[ ]:
# Neuron execution
infer(padded_model_neuron, tokenizer, sample_text)
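
For a stricter check than comparing the generated strings, you can also compare the raw decoder logits of the padded CPU and Neuron models directly. The sketch below builds fixed-size decoder inputs by hand; since Neuron may use reduced-precision arithmetic internally, the maximum absolute difference is printed rather than asserting exact equality:

[ ]:
# Optional numerical comparison of the padded CPU and Neuron decoders (sketch)
batch = tokenizer(sample_text, max_length=max_encoder_length, truncation=True, padding='max_length', return_tensors="pt")
encoder_out, = padded_model_cpu.encoder(batch['input_ids'], batch['attention_mask'])

batch_size = num_texts * num_beams
decoder_ids = F.pad(torch.full((batch_size, 1), model_cpu.config.decoder_start_token_id, dtype=torch.long), (0, max_decoder_length - 1))
attention = batch['attention_mask'].repeat(batch_size, 1)
encoder_rep = encoder_out.repeat(batch_size, 1, 1)

logits_cpu = padded_model_cpu.decoder(decoder_ids, attention, encoder_rep, torch.tensor(0))
logits_neuron = padded_model_neuron.decoder(decoder_ids, attention, encoder_rep, torch.tensor(0))
print('max logit difference:', (logits_cpu - logits_neuron).abs().max().item())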

Padded Neuron Serialization#

Finally, we can test that we can serialize and reload the model so that it can be used later in its precompiled format.

[ ]:
padded_model_neuron.save_pretrained('NeuronPaddedMarianMT')
padded_model_loaded = PaddedGenerator.from_pretrained('NeuronPaddedMarianMT')
infer(padded_model_loaded, tokenizer, sample_text)
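
The saved directory now contains the two traced modules (encoder.pt and decoder.pt) along with the configuration written by save_pretrained; you can list it to verify:

[ ]:
!ls NeuronPaddedMarianMT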

Greedy Unrolled Model#

An unrolled version of the model can achieve better performance in some cases since all operations are executed on the Neuron hardware without returning to CPU. The trade-off is that, because the generation loop never returns to CPU, the entire sequence up to max_decoder_length is generated in a single fixed-length forward pass.

The following module performs greedy text generation. Unlike the original beam search text generation, this implementation always selects the most probable token and does not generate multiple result texts.

GreedyUnrolledGenerator Module#

[ ]:
class GreedyUnrolledGenerator(torch.nn.Module):

    def __init__(self, model):
        super().__init__()
        self.config = model.config
        self.model = model

    def forward(self, input_ids, attention_mask):

        # Generate the encoder state for the input tokens. This is only done once and the state is reused.
        encoder_outputs, = self.model.model.encoder(input_ids, attention_mask=attention_mask, return_dict=False)

        # Set the initial state for the decode loop. This will grow with each decoder iteration
        tokens = torch.full((input_ids.size(0), 2), self.config.decoder_start_token_id)

        # Iteratively invoke the decoder on incrementally generated `tokens` to generate a `next_token`.
        # Note that unlike the GenerationMixin.generate function, there is no early-exit if the stop token
        # has been reached. This will always run a fixed number of iterations.
        for i in range(self.config.max_length):

            hidden, = self.model.model.decoder(
                input_ids=tokens,
                encoder_hidden_states=encoder_outputs,
                encoder_attention_mask=attention_mask,
                return_dict=False,
                use_cache=False,
            ) # size: [batch, current_length, d_model]

            logits = F.linear(
                hidden[:, -1, :],
                self.model.model.shared.weight,
                bias=self.model.final_logits_bias
            )
            next_tokens = torch.argmax(logits, dim=1, keepdims=True)
            tokens = torch.cat([tokens, next_tokens], dim=1)

        return tokens

Greedy CPU Inference#

The inference code must be updated since the generate method is no longer used. This is because the entire generative inference loop occurs within the GreedyUnrolledGenerator.forward method.

[ ]:
def infer_greedy(model, tokenizer, text):
    batch = tokenizer(text, max_length=max_encoder_length, truncation=True, padding='max_length', return_tensors="pt")
    inputs = batch['input_ids'], batch['attention_mask']
    tokens = model(*inputs)
    print('Texts:')
    for i, t in enumerate(tokens):
        result = tokenizer.decode(t, skip_special_tokens=True)
        print(i + 1, result)

As in the previous sections of this tutorial, the greedy model is first executed on CPU to validate that it produces correct results. In this example, the generated text matches the first result of the original beam search.

[ ]:
model_cpu.config.max_length = 8 # This controls the number of decoder loops. Reduced to improve compilation speed.
greedy_cpu = GreedyUnrolledGenerator(model_cpu)
infer_greedy(greedy_cpu, tokenizer, sample_text)

Greedy Neuron Tracing & Inference#

Similarly, tracing is simplified since the entire GreedyUnrolledGenerator.forward can now be compiled as a single unit.

For compilation efficiency, two changes are made compared to a normal compilation:

  • torch.jit.freeze is used because it can sometimes speed up compilation in cases where a module is re-used multiple times. Here it is more efficient because self.model.model.decoder is invoked in a loop.

  • The torch_neuron.trace option fallback is set to False. This forces all operations to execute on Neuron. Most of the time this is not recommended or efficient; in this case it is more efficient because it produces a single subgraph rather than many. Otherwise one subgraph would be produced per decoder iteration, since aten::embedding is executed in a loop. The aten::embedding operation is placed on CPU by default because this is usually more efficient than executing it on Neuron.

You may notice that compilation will take significantly longer with the unrolled model since the model inserts new operations into the compute graph for every single decoder iteration. This creates a much larger model graph even though the weights are re-used.

[ ]:
example = (
    torch.ones((num_texts, max_encoder_length), dtype=torch.long),
    torch.ones((num_texts, max_encoder_length), dtype=torch.long),
)
greedy_cpu.eval()
greedy_trace = torch.jit.trace(greedy_cpu, example)
greedy_frozen = torch.jit.freeze(greedy_trace)
greedy_neuron = torch_neuron.trace(greedy_frozen, example, fallback=False)
[ ]:
infer_greedy(greedy_neuron, tokenizer, sample_text)

Greedy Neuron Serialization#

Unlike the previous version of the model, which used the GenerationMixin base class, this greedy version of the model can be serialized using the regular torch.jit.save and torch.jit.load utilities since it is a pure TorchScript module.

[ ]:
torch.jit.save(greedy_neuron, 'greedy_neuron.pt')
loaded_greedy_neuron = torch.jit.load('greedy_neuron.pt')
infer_greedy(loaded_greedy_neuron, tokenizer, sample_text)

Appendix#

BART (Mask Filling Task)#

The PaddedGenerator class can be applied to the BART model for the task of filling in mask tokens.

[ ]:
from transformers import BartForConditionalGeneration, BartTokenizer
bart_name = "facebook/bart-large"
bart_model = BartForConditionalGeneration.from_pretrained(bart_name)
bart_model.config.max_length = max_decoder_length
bart_tokenizer = BartTokenizer.from_pretrained(bart_name)
bart_text = "UN Chief Says There Is No <mask> in Syria"
[ ]:
# CPU Execution
infer(bart_model, bart_tokenizer, bart_text)
[ ]:
# Neuron Execution
padded_bart = PaddedGenerator.from_model(bart_model)
bart_neuron = trace(padded_bart, num_texts, num_beams, max_decoder_length, max_encoder_length)
infer(bart_neuron, bart_tokenizer, bart_text)

Pegasus (Summarization Task)#

The PaddedGenerator class can be applied to the Pegasus model for summarization.

[ ]:
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
pegasus_name = 'google/pegasus-xsum'
pegasus_model = PegasusForConditionalGeneration.from_pretrained(pegasus_name)
pegasus_model.config.max_length = max_decoder_length
pegasus_tokenizer = PegasusTokenizer.from_pretrained(pegasus_name)
pegasus_text = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires."
[ ]:
# CPU Execution
infer(pegasus_model, pegasus_tokenizer, pegasus_text)
[ ]:
# Neuron Execution
padded_pegasus = PaddedGenerator.from_model(pegasus_model)
pegasus_neuron = trace(padded_pegasus, num_texts, num_beams, max_decoder_length, max_encoder_length)
infer(pegasus_neuron, pegasus_tokenizer, pegasus_text)