This document is relevant for: Inf2, Trn1, Trn2
Autobucketing for Inference (torch-neuronx)#
Introduction#
Autobucketing is a feature that enables you to use multiple bucket models, each compiled for a static input shape, together with a bucket kernel function that routes inputs to the appropriate bucket. The bucket models are packaged into a single traced PyTorch model that can accept multiple different input shapes.
This gives you increased flexibility in the inputs you can pass to Neuron models without having to manage multiple Neuron models yourself. The applications are extensive, from selecting the optimal model based on image resolution to efficient sampling for token generation in language models.
Beyond flexibility, autobucketing is also useful for latency-sensitive applications, since the bucket kernel function can route small inputs to small bucket models and large inputs to large ones.
This Developer Guide discusses best practices for implementing autobucketing for your use case. A BERT model is used as the running example, with bucketing on the sequence length dimension.
Before continuing, it is recommended to familiarize yourself with the Autobucketing APIs, which can be found here.
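At a high level, the workflow looks like the following sketch. The names get_bucket_kernel, get_model, example_inputs_s128, example_inputs_s512, and some_inputs are placeholders that the rest of this guide fills in with concrete definitions.
import torch_neuronx

# 1. Provide a getter for a TorchScript bucket kernel (defined later in this guide).
bucket_config = torch_neuronx.BucketModelConfig(get_bucket_kernel)

# 2. Trace the model once per bucket shape; the result is a single traced model
#    that dispatches to the matching bucket at runtime.
neuron_model = torch_neuronx.bucket_model_trace(
    get_model, [example_inputs_s128, example_inputs_s512], bucket_config
)

# 3. Call it like a normal traced model, with any supported input shape.
output = neuron_model(*some_inputs)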
Bucket Kernels#
Bucket kernels are user-defined functions that take the model input as their argument and return a tuple containing the (potentially manipulated) model input and a tensor representing the bucket index.
An important requirement is that the function must be convertible to a TorchScript representation using torch.jit.script(). This is because, to support saving a traced bucket model with torch.jit.save() and loading it with torch.jit.load(), every element of the model must be in TorchScript.
The example below shows a bucket kernel that can be scripted in this way.
import torch
from typing import List

def sequence_length_bucket_kernel(tensor_list: List[torch.Tensor]):
    x = tensor_list[0]
    bucket_dim = 1
    x_shape = x.shape
    tensor_sequence_length = x_shape[bucket_dim]
    batch_size = x_shape[bucket_dim - 1]
    buckets = [128, 512]
    idx = 0
    num_inputs = 3
    bucket = buckets[0]
    reshaped_tensors: List[torch.Tensor] = []
    bucket_idx = 0
    # Select the smallest bucket that fits the incoming sequence length,
    # then right pad every input tensor up to that bucket's sequence length.
    for idx, bucket in enumerate(buckets):
        if tensor_sequence_length <= bucket:
            bucket_idx = idx
            for tensor in tensor_list:
                if num_inputs == 0:
                    break
                delta = bucket - tensor_sequence_length
                padding_shape: List[int] = [batch_size, delta]
                zeros = torch.zeros(padding_shape, dtype=x.dtype)
                reshaped_tensors.append(torch.cat([tensor, zeros], dim=bucket_dim))
                num_inputs -= 1
            break
    return reshaped_tensors, torch.tensor([bucket_idx])

def get_bucket_kernel(*_):
    bk = torch.jit.script(sequence_length_bucket_kernel)
    return bk
In the above example, we define a bucket kernel that takes the input to a transformers model, which is [input_ids, attention_mask, token_type_ids]. We first obtain the first tensor in that list, since it contains sequence_length as a dimension, and retrieve the sequence_length and batch_size. We also define the sequence length buckets. The next major part of the code is the for loop, which first finds the smallest matching sequence length bucket and then iterates through the tensors in the list, right padding each one to the bucket's sequence length. After this is done, we return the padded inputs as a list of tensors together with a tensor containing the bucket index. Finally, we create a function get_bucket_kernel, which returns a version of the bucket kernel that has been adapted to TorchScript using torch.jit.script(). We can use this bucket kernel to pass in a tokenized input of sequence length 1-512, which is padded to the nearest bucket size, rounded up.
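Before tracing, you can sanity check the scripted kernel directly on CPU tensors. Below is a minimal sketch, assuming the sequence_length_bucket_kernel and get_bucket_kernel definitions above are in scope; the dummy tensors and variable names are illustrative only.
import torch
from typing import List

# Three dummy tensors stand in for input_ids, attention_mask, and token_type_ids,
# with batch_size=1 and sequence_length=100, which should round up to the 128 bucket.
bucket_kernel = get_bucket_kernel()
dummy_inputs: List[torch.Tensor] = [torch.ones(1, 100, dtype=torch.long) for _ in range(3)]

padded_inputs, bucket_idx = bucket_kernel(dummy_inputs)
print([t.shape for t in padded_inputs])  # three tensors of shape [1, 128]
print(bucket_idx)                        # tensor([0]): the 128 bucket was selected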
Note that we call torch.jit.script() instead of torch.jit.trace(). This is because we rely on control flow logic evaluating correctly for all inputs. This results in certain challenges when writing compatible and accurate bucket kernels. We discuss these challenges and their resolutions in the next section.
TorchScript Best Practices for Bucket Kernels#
Below are some recommendations when creating these bucket kernels (a short sketch illustrating the last two points follows the list):
Type annotate non-tensor-like data types: Functions that have been adapted to the TorchScript representation using torch.jit.script() treat variables that are defined from another variable as tensor-like when they might not be. This can be seen when defining padding_shape in the above bucket kernel.
Index selection support is limited: Functions that have been adapted to the TorchScript representation using torch.jit.script() do not support the use of variables for indexing very well. It can work in some scenarios, but there is no discernible pattern to it, so for more reliable TorchScript-adapted functions that rely on indexes, use an enumerated for loop or literals where possible.
Initializing variables with literals: The TorchScript compiler often incorrectly removes a variable if it finds another variable initialized with the same literal, such as 0. The compiler might also reuse variables initialized with a literal for other operations, such as indexing or function parameters. This can cause inaccurate results for certain inputs. Therefore, always validate the function by testing it with the expected inputs. If the lowering does not behave as expected, you can inspect the lowered representation by calling bucket_kernel.graph, where bucket_kernel is the return value of get_bucket_kernel, and analyze the graph for inaccurate lowerings.
Use of aten functions might be necessary to guarantee correct lowering: The TorchScript interpreter supports certain operations, such as slicing, but can lower them in unexpected ways when normal syntax is used. For example, the most common way to slice is with indexing syntax such as tensor[:,:2,:]. However, this can cause lowering issues for the reasons mentioned above. To mitigate this, it might be necessary to call the respective aten function directly, for example in a shared_state_buffer_preprocessor.
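The following is a minimal sketch of the last two points, assuming the sequence_length_bucket_kernel defined earlier is in scope. The slice_first_two helper is hypothetical and only illustrates calling an aten op directly instead of using indexing syntax.
import torch

# Inspect the lowered TorchScript graph to check that literals and indices were
# not collapsed or reused in unexpected ways by the scripting compiler.
bucket_kernel = torch.jit.script(sequence_length_bucket_kernel)
print(bucket_kernel.graph)

# Hypothetical helper: call the aten slice op directly rather than writing
# tensor[:, :2, :]. torch.ops.aten.slice(self, dim, start, end, step) here
# slices dim=1 to the range [0, 2).
def slice_first_two(t: torch.Tensor) -> torch.Tensor:
    return torch.ops.aten.slice(t, 1, 0, 2, 1)

scripted_slice = torch.jit.script(slice_first_two)
print(scripted_slice(torch.arange(24).reshape(2, 3, 4)).shape)  # torch.Size([2, 2, 4])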
Bucket Model Config#
Given the bucket kernel getter defined above (get_bucket_kernel) and, optionally, a getter for a shared state buffer preprocessor (here called get_state_preprocessor), we can initialize a torch_neuronx.BucketModelConfig object as follows:
import torch
import torch_neuronx
from typing import List

# get_bucket_kernel and get_state_preprocessor definitions go here

bucket_config = torch_neuronx.BucketModelConfig(
    get_bucket_kernel,
    shared_state_buffer_preprocessor=get_state_preprocessor,
)
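The shared_state_buffer_preprocessor argument is optional. If no shared state buffer preprocessing is needed, the config can be created from the bucket kernel alone, which is what the end-to-end example below does:
bucket_config = torch_neuronx.BucketModelConfig(get_bucket_kernel)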
Putting it all Together#
Here is a simple example using the BERT model:
import torch
import torch_neuronx
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import List

def encode(tokenizer, *inputs, max_length=128, batch_size=1):
    tokens = tokenizer.encode_plus(
        *inputs,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )
    return (
        torch.repeat_interleave(tokens['input_ids'], batch_size, 0),
        torch.repeat_interleave(tokens['attention_mask'], batch_size, 0),
    )

def get_bert_model(*args):
    name = "bert-base-cased-finetuned-mrpc"
    model = AutoModelForSequenceClassification.from_pretrained(name, torchscript=True)
    return model, {}

def sequence_length_bucket_kernel(tensor_list: List[torch.Tensor]):
    x = tensor_list[0]
    bucket_dim = 1
    x_shape = x.shape
    tensor_sequence_length = x_shape[bucket_dim]
    batch_size = x_shape[bucket_dim - 1]
    buckets = [128, 512]
    idx = 0
    num_inputs = 3
    bucket = buckets[0]
    reshaped_tensors: List[torch.Tensor] = []
    bucket_idx = 0
    for idx, bucket in enumerate(buckets):
        if tensor_sequence_length <= bucket:
            bucket_idx = idx
            for tensor in tensor_list:
                if num_inputs == 0:
                    break
                delta = bucket - tensor_sequence_length
                padding_shape: List[int] = [batch_size, delta]
                zeros = torch.zeros(padding_shape, dtype=x.dtype)
                reshaped_tensors.append(torch.cat([tensor, zeros], dim=bucket_dim))
                num_inputs -= 1
            break
    return reshaped_tensors, torch.tensor([bucket_idx])

def get_bucket_kernel(*_):
    bk = torch.jit.script(sequence_length_bucket_kernel)
    return bk

if __name__ == '__main__':
    name = "bert-base-cased-finetuned-mrpc"

    # Build tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name, torchscript=True)

    # Setup some example inputs
    sequence_0 = "The company HuggingFace is based in New York City"
    sequence_1 = "HuggingFace is named after the huggingface emoji"
    sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

    paraphrase_s128 = encode(tokenizer, sequence_0, sequence_2)
    paraphrase_s122 = encode(tokenizer, sequence_0, sequence_2, max_length=122)
    paraphrase_s512 = encode(tokenizer, sequence_0, sequence_1, max_length=512)
    paraphrase_s444 = encode(tokenizer, sequence_0, sequence_1, max_length=444)

    # Note: Run on CPU before trace. Avoids running with XLA allocated params
    paraphrase_expected_s128 = torch.argmax(model(*paraphrase_s128)[0])
    paraphrase_expected_s512 = torch.argmax(model(*paraphrase_s512)[0])

    # Trace model
    bucket_config = torch_neuronx.BucketModelConfig(get_bucket_kernel)
    bucket_trace_neuron = torch_neuronx.bucket_model_trace(
        get_bert_model, [paraphrase_s128, paraphrase_s512], bucket_config
    )

    # Run traced model with shorter inputs to test bucket rounding
    paraphrase_actual_s128 = torch.argmax(bucket_trace_neuron(*paraphrase_s122)[0])
    paraphrase_actual_s512 = torch.argmax(bucket_trace_neuron(*paraphrase_s444)[0])

    # Compare outputs
    assert paraphrase_expected_s128 == paraphrase_actual_s128
    assert paraphrase_expected_s512 == paraphrase_actual_s512
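Because every element of the bucket model, including the bucket kernel, is in TorchScript, the traced bucket model can be saved with torch.jit.save() and reloaded with torch.jit.load(), as noted earlier. Below is a minimal sketch continuing from the example above; the file name bert_buckets.pt is illustrative.
# Save the traced bucket model and reload it later (for example, on a serving host).
torch.jit.save(bucket_trace_neuron, 'bert_buckets.pt')
restored_model = torch.jit.load('bert_buckets.pt')

# The reloaded model still routes inputs through the bucket kernel.
assert torch.argmax(restored_model(*paraphrase_s122)[0]) == paraphrase_expected_s128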
Autobucketing for Neuronx-Distributed#
To see this same example applied to Neuronx-Distributed, refer to the corresponding section of the Neuronx-Distributed Inference Developer Guide.
This document is relevant for: Inf2, Trn1, Trn2