Deploy a pretrained PyTorch BERT model from HuggingFace on Amazon SageMaker with Neuron container#

Overview#

In this tutorial we will deploy a pretrained BERT Base model from HuggingFace Transformers on SageMaker, using the AWS Deep Learning Containers. We will use the same model as in the Neuron tutorial “PyTorch - HuggingFace Pretrained BERT Tutorial”. We will compile the model with AWS Neuron and build a custom AWS Deep Learning Container that includes the HuggingFace Transformers library.

This Jupyter Notebook should run on an ml.c5.4xlarge SageMaker Notebook instance. You can set up your SageMaker Notebook instance by following the Get Started with Amazon SageMaker Notebook Instances documentation.

We recommend increasing the size of the root volume of your SageMaker notebook instance to accommodate the models and containers built locally. A root volume of 10 GB should suffice.

Install Dependencies#

This tutorial requires the following pip packages:

  • torch-neuron

  • neuron-cc[tensorflow]

  • transformers

[ ]:
# Setting this variable explicitly suppresses tokenizer warnings, making errors easier to detect
%env TOKENIZERS_PARALLELISM=True
!pip install --upgrade --no-cache-dir torch-neuron neuron-cc[tensorflow] torchvision torch --extra-index-url=https://pip.repos.neuron.amazonaws.com
!pip install --upgrade --no-cache-dir 'transformers==4.6.0'
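To confirm which versions actually ended up in the notebook's environment (you may need to restart the kernel after upgrading), the standard library's importlib.metadata can report them. A minimal sketch; the distribution names below are assumed to match the pip package names used above, and installed_versions is an illustrative helper name:

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names, taken from the pip command above
REQUIRED = ["torch-neuron", "neuron-cc", "transformers"]


def installed_versions(packages=REQUIRED):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```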

Compile the model into an AWS Neuron optimized TorchScript#

[ ]:
import torch
import torch_neuron

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
[ ]:
# Build tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc", return_dict=False)

# Setup some example inputs
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

max_length=128
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")

# Run the original PyTorch model on the compilation example
paraphrase_classification_logits = model(**paraphrase)[0]

# Convert example inputs to a format that is compatible with TorchScript tracing
example_inputs_paraphrase = paraphrase['input_ids'], paraphrase['attention_mask'], paraphrase['token_type_ids']
example_inputs_not_paraphrase = not_paraphrase['input_ids'], not_paraphrase['attention_mask'], not_paraphrase['token_type_ids']
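Neuron compiles the traced graph for fixed input shapes, which is why both examples are padded and truncated to max_length=128. The following pure-Python sketch mirrors what padding='max_length' does to a single sequence of token IDs. It is simplified: pad_id=0 is assumed (BERT's [PAD] ID), and unlike the real tokenizer it truncates by dropping the tail rather than preserving special tokens such as [SEP].

```python
def pad_to_max_length(token_ids, max_length=128, pad_id=0):
    """Simplified sketch of padding='max_length' plus truncation for one sequence.

    Returns (input_ids, attention_mask), both exactly max_length long.
    """
    ids = list(token_ids)[:max_length]
    mask = [1] * len(ids)              # 1 marks real tokens
    padding = max_length - len(ids)
    ids += [pad_id] * padding          # fill the remainder with the pad token
    mask += [0] * padding              # 0 tells attention to ignore padding
    return ids, mask


ids, mask = pad_to_max_length([101, 1109, 1419, 102], max_length=8)
print(ids)   # [101, 1109, 1419, 102, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

Because every request must match the traced shape, the same max_length has to be used again at inference time.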
[ ]:
%%time
# Run torch.neuron.trace to generate a TorchScript that is optimized by AWS Neuron
# This step may need 3-5 min
model_neuron = torch.neuron.trace(model, example_inputs_paraphrase, verbose=1, compiler_workdir='./compilation_artifacts')

You may inspect model_neuron.graph to see which parts run on CPU and which run on the accelerator. All native aten operators in the graph run on CPU.

[ ]:
# See which parts run on CPU versus on the accelerator
print(model_neuron.graph)
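For a BERT graph the printout is long, so a small helper that tallies node kinds by namespace can make the CPU/accelerator split easier to see: aten entries are the CPU operators mentioned above. This is a sketch; graph.nodes() and node.kind() are standard TorchScript graph methods, but the namespace and name under which the Neuron-compiled subgraph appears are assumptions to verify against your own printout.

```python
from collections import Counter


def summarize_op_kinds(kinds):
    """Count TorchScript node kinds by namespace, e.g. 'aten::add' -> 'aten'."""
    return Counter(kind.split("::", 1)[0] for kind in kinds)


# In the notebook (requires torch and the traced model):
#   kinds = [node.kind() for node in model_neuron.graph.nodes()]
#   print(summarize_op_kinds(kinds))

# Illustrative kind strings only; the Neuron op name here is hypothetical
example_kinds = ["prim::Constant", "aten::embedding", "aten::add", "neuron::forward"]
print(summarize_op_kinds(example_kinds))
```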

Save the compiled model, so it can be packaged and sent to S3.

[ ]:
# Save the TorchScript for later use
model_neuron.save('neuron_compiled_model.pt')

Package the pre-trained model and upload it to S3#

To make the model available for the SageMaker deployment, package the saved TorchScript file into a model.tar.gz archive and upload it to the default Amazon S3 bucket for your SageMaker session.

[ ]:
# Create the model.tar.gz archive to be used by the SageMaker endpoint
!tar -czvf model.tar.gz neuron_compiled_model.pt
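If you would rather build the archive from Python than shell out to tar, the standard library's tarfile module produces an equivalent model.tar.gz. A minimal sketch; package_model is an illustrative name and its defaults assume the file names used in this notebook:

```python
import tarfile
from pathlib import Path


def package_model(model_file="neuron_compiled_model.pt", archive="model.tar.gz"):
    """Python equivalent of `tar -czvf model.tar.gz neuron_compiled_model.pt`.

    Returns the member names of the finished archive for a quick check.
    """
    with tarfile.open(archive, "w:gz") as tar:
        # arcname keeps the file at the archive root; SageMaker extracts the
        # archive into the model directory (/opt/ml/model) on the endpoint
        tar.add(model_file, arcname=Path(model_file).name)
    with tarfile.open(archive, "r:gz") as tar:
        return tar.getnames()
```

Calling package_model() in the notebook directory should return ['neuron_compiled_model.pt'], confirming that the TorchScript file sits at the root of the archive where model_fn will look for it.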
[ ]:
import boto3
import time
from sagemaker.utils import name_from_base
import sagemaker
[ ]:
# upload model to S3
role = sagemaker.get_execution_role()
sess=sagemaker.Session()
region=sess.boto_region_name
bucket=sess.default_bucket()
sm_client=boto3.client('sagemaker')
[ ]:
model_key = '{}/model/model.tar.gz'.format('inf1_compiled_model')
model_path = 's3://{}/{}'.format(bucket, model_key)
boto3.resource('s3').Bucket(bucket).upload_file('model.tar.gz', model_key)
print("Uploaded model to S3:")
print(model_path)

Build and Push the container#

The following shell code shows how to build the container image using docker build and push the container image to ECR using docker push. The Dockerfile in this example is available in the container folder. Here’s an example of the Dockerfile:

FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference-neuron:1.7.1-neuron-py36-ubuntu18.04

# Install packages
RUN pip install "transformers==4.7.0"
[ ]:
!cat container/Dockerfile

Before running the next cell, make sure your SageMaker IAM role has access to ECR. If it does not, attach the AmazonEC2ContainerRegistryPowerUser managed policy to your IAM role, which allows it to push image layers to ECR.

Building the Docker image and pushing it to ECR takes about 5 minutes.

[ ]:
%%sh

# The name of our algorithm
algorithm_name=neuron-py36-inference

cd container

account=$(aws sts get-caller-identity --query Account --output text)

# Get the region defined in the current configuration (default to us-west-2 if none defined)
region=$(aws configure get region)
region=${region:-us-west-2}

fullname="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:latest"

# If the repository doesn't exist in ECR, create it.

aws ecr describe-repositories --repository-names "${algorithm_name}" > /dev/null 2>&1

if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${algorithm_name}" > /dev/null
fi

# Log in to the AWS Deep Learning Containers registry (us-east-1) to pull the base PyTorch Neuron image
aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-east-1.amazonaws.com
# Build the docker image locally with the image name and then push it to ECR
# with the full name.
docker build -t ${algorithm_name} . --build-arg REGION=${region}
docker tag ${algorithm_name} ${fullname}

# Get the login command from ECR and execute it directly
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin ${account}.dkr.ecr.${region}.amazonaws.com
docker push ${fullname}
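The shell cell checks for the repository with aws ecr describe-repositories and creates it when that command fails. If you prefer to do this step from Python, a roughly equivalent sketch with the boto3 ECR client follows. One design difference: it creates the repository only when the repository is reported missing, instead of on any non-zero exit code. ensure_ecr_repository is an illustrative name.

```python
def ensure_ecr_repository(ecr_client, repository_name):
    """Return the URI of an ECR repository, creating the repository if needed.

    ecr_client is expected to be a boto3 ECR client, e.g. boto3.client("ecr").
    """
    try:
        response = ecr_client.describe_repositories(repositoryNames=[repository_name])
        return response["repositories"][0]["repositoryUri"]
    except ecr_client.exceptions.RepositoryNotFoundException:
        # Only a missing repository triggers creation; other errors propagate
        response = ecr_client.create_repository(repositoryName=repository_name)
        return response["repository"]["repositoryUri"]


# Usage in the notebook:
#   import boto3
#   uri = ensure_ecr_repository(boto3.client("ecr"), "neuron-py36-inference")
#   print(uri)  # matches ${fullname} without the :latest tag
```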

Deploy Container and run inference based on the pretrained model#

To deploy the pretrained, compiled model, use the SageMaker Python SDK's PyTorchModel class and point its entry_point at a custom inference script.

You’ll use the PyTorchModel object to deploy a PyTorchPredictor. This creates a SageMaker Endpoint – a hosted prediction service that we can use to perform inference.

[ ]:
import sys

!{sys.executable} -m pip install Transformers
[ ]:
import os
import boto3
import sagemaker

role = sagemaker.get_execution_role()
sess = sagemaker.Session()

bucket = sess.default_bucket()
prefix = "inf1_compiled_model/model"

# Get container name in ECR
client=boto3.client('sts')
account=client.get_caller_identity()['Account']

my_session=boto3.session.Session()
region=my_session.region_name

algorithm_name="neuron-py36-inference"
ecr_image='{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, algorithm_name)
print(ecr_image)

The inference script must implement model_fn. We implement our own model_fn and predict_fn for HuggingFace BERT, and use the default implementations of input_fn and output_fn defined in sagemaker-pytorch-containers.

In this example, the inference script is in the code folder. Run the next cell to see it:

[ ]:
!pygmentize code/inference.py
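The contents of code/inference.py are not reproduced in this text. As a rough, hypothetical sketch of what a script of this shape could contain, the handlers below follow the SageMaker PyTorch handler names (model_fn, input_fn, predict_fn, output_fn) and reuse the tokenizer settings from the compilation step. Assumptions to check against the real script: the compiled file name, the label order of this MRPC checkpoint, that each request is a JSON list of exactly two sentences, and that the endpoint can download the tokenizer from the HuggingFace Hub. The sketch defines input_fn and output_fn explicitly to make the JSON handling visible; if you rely on the toolkit's default implementations as described above, leave those two out.

```python
# Hypothetical sketch of code/inference.py; the notebook's real script may differ.
import json
import math
import os

MODEL_FILE = "neuron_compiled_model.pt"            # name used by model_neuron.save() above
TOKENIZER_NAME = "bert-base-cased-finetuned-mrpc"  # checkpoint the model was traced from
MAX_LENGTH = 128                                   # must match the traced input shape
LABELS = ["not paraphrase", "paraphrase"]          # assumed label order


def model_fn(model_dir):
    """Load the Neuron-compiled TorchScript and the tokenizer."""
    import torch
    import torch_neuron  # noqa: F401 (registers the Neuron ops used by the saved graph)
    from transformers import AutoTokenizer

    model = torch.jit.load(os.path.join(model_dir, MODEL_FILE))
    # Assumes the endpoint has network access to the HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    return model, tokenizer


def input_fn(request_body, content_type="application/json"):
    """Decode a JSON list of two sentences."""
    if content_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")
    return json.loads(request_body)


def postprocess(logits):
    """Turn two raw logits into a label and its softmax probability."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]  # subtract the max for numerical stability
    probs = [e / sum(exps) for e in exps]
    best = probs.index(max(probs))
    return {"label": LABELS[best], "score": probs[best]}


def predict_fn(input_data, model_and_tokenizer):
    """Classify whether the two input sentences are paraphrases."""
    import torch

    model, tokenizer = model_and_tokenizer
    sequence_0, sequence_1 = input_data  # expects exactly two sentences
    encoded = tokenizer.encode_plus(
        sequence_0, sequence_1, max_length=MAX_LENGTH, padding="max_length",
        truncation=True, return_tensors="pt",
    )
    with torch.no_grad():
        # Same positional order as the example inputs used for tracing
        logits = model(encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"])[0]
    return postprocess(logits[0].tolist())


def output_fn(prediction, accept="application/json"):
    """Encode the prediction dict as a JSON string."""
    return json.dumps(prediction)
```

The heavy imports (torch, torch_neuron, transformers) sit inside the handlers, so the pure-Python parts such as postprocess can be exercised without the Neuron stack installed.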

Path of compiled pretrained model in S3:

[ ]:
key = os.path.join(prefix, "model.tar.gz")
pretrained_model_data = "s3://{}/{}".format(bucket, key)
print(pretrained_model_data)

The model object is defined with the SageMaker Python SDK's PyTorchModel, passing in the S3 path of the packaged model (model_data) and the inference script (entry_point). The endpoint's inference logic lives in inference.py, printed by the previous code block: its model_fn loads the compiled model and the required tokenizer.

Note that image_uri must point to your own ECR image, the one you built and pushed earlier in this notebook.

[ ]:
from sagemaker.pytorch.model import PyTorchModel

pytorch_model = PyTorchModel(
    model_data=pretrained_model_data,
    role=role,
    source_dir="code",
    framework_version="1.7.1",
    entry_point="inference.py",
    image_uri=ecr_image
)

# Let SageMaker know that we've already compiled the model via neuron-cc
pytorch_model._is_compiled_model = True

The arguments to the deploy function allow us to set the number and type of instances that will be used for the Endpoint.

Here you will deploy the model to a single ml.inf1.2xlarge instance. It may take 6-10 min to deploy.

[ ]:
%%time

predictor = pytorch_model.deploy(initial_instance_count=1, instance_type="ml.inf1.2xlarge")
[ ]:
print(predictor.endpoint_name)

Since the input_fn accepts JSON-encoded requests, we need a JSON serializer to encode the input data into a JSON string. Likewise, since the response content type is JSON, we need a JSON deserializer to parse the response.

[ ]:
predictor.serializer = sagemaker.serializers.JSONSerializer()
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
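Under the hood, the serializer turns the Python list into a JSON request body and the deserializer parses the JSON response. The same round trip can be made without the SDK's Predictor by calling the SageMaker Runtime InvokeEndpoint API through boto3, which is handy in applications that don't depend on the sagemaker package. A minimal sketch, assuming the endpoint accepts and returns application/json as described above; invoke_json is an illustrative name:

```python
import json


def invoke_json(runtime_client, endpoint_name, sentences):
    """Send a JSON list of sentences to an endpoint and decode the JSON reply.

    runtime_client is expected to be boto3.client("sagemaker-runtime").
    """
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",  # what JSONSerializer declares
        Accept="application/json",       # what JSONDeserializer expects back
        Body=json.dumps(sentences),
    )
    # The response body is a stream of bytes; json.loads accepts bytes directly
    return json.loads(response["Body"].read())


# Usage in the notebook:
#   import boto3
#   runtime = boto3.client("sagemaker-runtime")
#   print(invoke_json(runtime, predictor.endpoint_name, ["First sentence.", "Second sentence."]))
```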

Now invoke the SageMaker endpoint with a list of two sentences to get a prediction.

[ ]:
%%time
result = predictor.predict(
    [
        "Never allow the same bug to bite you twice.",
        "The best part of Amazon SageMaker is that it makes machine learning easy.",
    ]
)
print(result)
[ ]:
%%time
result = predictor.predict(
    [
        "The company HuggingFace is based in New York City",
        "HuggingFace's headquarters are situated in Manhattan",
    ]
)
print(result)

Benchmarking your endpoint#

The following cells create a load test for your endpoint. First, define some helper functions: inference_latency runs an endpoint request and collects the client-side latency and any errors; random_sentence builds random sentences to send to the endpoint.

[ ]:
import datetime
import math
import random
import time

import boto3
import matplotlib.pyplot as plt
import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm
[ ]:
def inference_latency(model, *inputs):
    """
    inference_latency is a simple method to return the latency of a model inference.

        Parameters:
            model: callable that runs one inference, e.g. predictor.predict
            inputs: arguments passed to model()

        Returns:
            dict with the latency in seconds, an error flag and the result
    """
    error = False
    start = time.time()
    try:
        results = model(*inputs)
    except Exception:
        # Record the failure instead of aborting the whole load test
        error = True
        results = []
    return {'latency': time.time() - start, 'error': error, 'result': results}
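The statistics printed by the load test can also be computed by a small helper that consumes the dictionaries inference_latency returns. A sketch using only the standard library: its 'inclusive' quantile method uses the same linear interpolation as numpy.quantile's default, so the percentiles should agree with the np.quantile calls in the load-test cell. summarize_latencies is an illustrative name, and it needs at least two results.

```python
from statistics import quantiles


def summarize_latencies(results, elapsed_seconds):
    """Summarize a list of inference_latency() results.

    Returns throughput in requests per second, latency percentiles in
    milliseconds and the error percentage.
    """
    latencies_ms = [r["latency"] * 1000 for r in results]
    # 99 cut points: index 49 is the 50th percentile, 89 the 90th, 94 the 95th
    cuts = quantiles(latencies_ms, n=100, method="inclusive")
    errors = sum(1 for r in results if r["error"])
    return {
        "throughput_rps": len(results) / elapsed_seconds,
        "p50_ms": cuts[49],
        "p90_ms": cuts[89],
        "p95_ms": cuts[94],
        "error_pct": 100.0 * errors / len(results),
    }
```

Importing quantiles directly, rather than the statistics module, avoids a clash with the statistics list that the CloudWatch part of the load-test cell defines.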
[ ]:
def random_sentence():
    """Build a random nonsense sentence to use as an endpoint payload."""
    s_nouns = ["A dude", "My mom", "The king", "Some guy", "A cat with rabies", "A sloth", "Your homie", "This cool guy my gardener met yesterday", "Superman"]
    p_nouns = ["These dudes", "Both of my moms", "All the kings of the world", "Some guys", "All of a cattery's cats", "The multitude of sloths living under your bed", "Your homies", "Like, these, like, all these people", "Supermen"]
    s_verbs = ["eats", "kicks", "gives", "treats", "meets with", "creates", "hacks", "configures", "spies on", "retards", "meows on", "flees from", "tries to automate", "explodes"]
    p_verbs = ["eat", "kick", "give", "treat", "meet with", "create", "hack", "configure", "spy on", "retard", "meow on", "flee from", "try to automate", "explode"]
    infinitives = ["to make a pie.", "for no apparent reason.", "because the sky is green.", "for a disease.", "to be able to make toast explode.", "to know more about archeology."]

    # Singular subject and verb; the object may be a singular or plural noun phrase.
    # (The original `a + b or c + d` expression never appended the infinitive.)
    obj = random.choice(s_nouns + p_nouns).lower()
    return ' '.join([random.choice(s_nouns), random.choice(s_verbs), obj, random.choice(infinitives)])

print([random_sentence(), random_sentence()])

The following cell creates number_of_clients concurrent threads to run number_of_runs requests. Once completed, a boto3 CloudWatch client will query for the server side latency metrics for comparison.

[ ]:
# Defining auxiliary variables
number_of_clients = 2
number_of_runs = 1000
t = tqdm(range(number_of_runs), position=0, leave=True)

# Starting parallel clients
cw_start = datetime.datetime.utcnow()

# Each request sends two random sentences; the threads share the same predictor
results = Parallel(n_jobs=number_of_clients, prefer="threads")(
    delayed(inference_latency)(predictor.predict, [random_sentence(), random_sentence()]) for _ in t
)
# Completed requests per second of wall-clock time, as measured by tqdm
avg_throughput = t.total / t.format_dict['elapsed']

cw_end = datetime.datetime.utcnow()

# Computing metrics and printing them
latencies = [res['latency'] for res in results]
errors = [res['error'] for res in results]
error_p = sum(errors) / len(errors) * 100
p50 = np.quantile(latencies, 0.50) * 1000
p90 = np.quantile(latencies, 0.90) * 1000
p95 = np.quantile(latencies, 0.95) * 1000

print(f'Avg Throughput: {avg_throughput:.1f} requests/s\n')
print(f'50th Percentile Latency:{p50:.1f} ms')
print(f'90th Percentile Latency:{p90:.1f} ms')
print(f'95th Percentile Latency:{p95:.1f} ms\n')
print(f'Errors percentage: {error_p:.1f} %\n')

# Querying CloudWatch
print('Getting CloudWatch metrics:')
cloudwatch = boto3.client('cloudwatch')
statistics = ['SampleCount', 'Average', 'Minimum', 'Maximum']
extended = ['p50', 'p90', 'p95', 'p100']

# Give a 5-minute buffer after the end of the test
cw_end += datetime.timedelta(minutes=5)

# Period must be 1, 5, 10, 30, or a multiple of 60 seconds
# Round the total elapsed time up to the next multiple of 60
factor = math.ceil((cw_end - cw_start).total_seconds() / 60)
period = factor * 60
print('Time elapsed: {} seconds'.format((cw_end - cw_start).total_seconds()))
print('Using period of {} seconds\n'.format(period))

cloudwatch_ready = False
# Keep polling CloudWatch metrics until datapoints are available
while not cloudwatch_ready:
    print('Waiting 30 seconds ...')
    time.sleep(30)
    # ModelLatency is reported in microseconds
    model_latency_metrics = cloudwatch.get_metric_statistics(
        MetricName='ModelLatency',
        Dimensions=[{'Name': 'EndpointName', 'Value': predictor.endpoint_name},
                    {'Name': 'VariantName', 'Value': "AllTraffic"}],
        Namespace="AWS/SageMaker",
        StartTime=cw_start,
        EndTime=cw_end,
        Period=period,
        Statistics=statistics,
        ExtendedStatistics=extended
    )
    datapoints = model_latency_metrics['Datapoints']
    if len(datapoints) > 0:
        # SampleCount should eventually equal number_of_runs
        print('{} latency datapoints ready'.format(datapoints[0]['SampleCount']))
        # Convert microseconds to milliseconds
        side_avg = datapoints[0]['Average'] / 1000
        side_p50 = datapoints[0]['ExtendedStatistics']['p50'] / 1000
        side_p90 = datapoints[0]['ExtendedStatistics']['p90'] / 1000
        side_p95 = datapoints[0]['ExtendedStatistics']['p95'] / 1000
        side_p100 = datapoints[0]['ExtendedStatistics']['p100'] / 1000

        print(f'50th Percentile Latency:{side_p50:.1f} ms')
        print(f'90th Percentile Latency:{side_p90:.1f} ms')
        print(f'95th Percentile Latency:{side_p95:.1f} ms\n')

        cloudwatch_ready = True
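Two details in the CloudWatch query are easy to get wrong: for a standard-resolution metric the period is rounded up to a multiple of 60 seconds, and ModelLatency is reported in microseconds, so converting to milliseconds means dividing by 1000 regardless of how many requests were sent. The helpers below isolate both steps as a sketch; the function names are illustrative, and the datapoint layout follows the get_metric_statistics response used in the cell above.

```python
import datetime
import math


def cloudwatch_period(start, end):
    """Smallest multiple of 60 seconds that is at least as long as [start, end]."""
    seconds = (end - start).total_seconds()
    return max(60, math.ceil(seconds / 60) * 60)


def microseconds_to_ms(datapoint, stats=("p50", "p90", "p95", "p100")):
    """Convert ModelLatency extended statistics (microseconds) to milliseconds."""
    return {s: datapoint["ExtendedStatistics"][s] / 1000 for s in stats}


start = datetime.datetime(2021, 1, 1)
print(cloudwatch_period(start, start + datetime.timedelta(seconds=61)))  # 120
```

Importing the datetime module (rather than the datetime class) keeps the notebook's datetime.datetime.utcnow() calls working if this cell is run first.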



Cleanup#

Endpoints should be deleted when no longer in use, to avoid ongoing charges.

[ ]:
# Deletes the endpoint and, by default, its endpoint configuration
predictor.delete_endpoint()
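Deleting the endpoint leaves the SageMaker model object created by deploy in place. As an alternative to the SDK call, a sketch using the low-level boto3 SageMaker client can remove the endpoint, its endpoint configuration and its model(s) together. Use it instead of, not after, the cell above: once the endpoint is gone, describe_endpoint fails. delete_endpoint_resources is an illustrative name.

```python
def delete_endpoint_resources(sm_client, endpoint_name):
    """Delete an endpoint plus its endpoint configuration and model(s).

    sm_client is expected to be boto3.client("sagemaker").
    Returns (endpoint config name, list of deleted model names).
    """
    # Look up the resources before deleting anything
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    config_name = endpoint["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    model_names = [v["ModelName"] for v in config["ProductionVariants"]]

    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    for name in model_names:
        sm_client.delete_model(ModelName=name)
    return config_name, model_names


# Usage in the notebook (sm_client was created earlier with boto3.client('sagemaker')):
#   delete_endpoint_resources(sm_client, predictor.endpoint_name)
```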