This document is relevant for: Inf2, Trn1, Trn1n

BERT TorchServe Tutorial#

Overview#

Update 10/02: This tutorial is currently broken and the AWS Neuron team is working on providing a fix.

This tutorial demonstrates the use of TorchServe with Neuron, the SDK for EC2 Inf2 and Trn1 instances. By the end of this tutorial, you will understand how TorchServe can be used to serve a model backed by EC2 Inf2/Trn1 instances. We will use a pretrained BERT-Base model to determine if one sentence is a paraphrase of another.

Run the tutorial#

Open a terminal, log into your remote instance, and activate a PyTorch Neuron virtual environment (see Install PyTorch Neuron (setup-torch-neuronx)). To complete this tutorial, you will also need a compiled BERT model. You can run trace_bert_neuronx.py to obtain a traced BERT model.
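The exact contents of trace_bert_neuronx.py are not reproduced here, but the tracing step looks roughly like the sketch below. The model name, batch size, and sequence length are assumptions chosen to match the config.json used later in this tutorial.

# Rough outline of producing bert_neuron_b6.pt with torch_neuronx (names/shapes are assumptions).
import torch
import torch_neuronx
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = 'bert-base-cased-finetuned-mrpc'   # assumed; matches config.json below
batch_size, max_length = 6, 128                 # assumed; matches config.json below

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, torchscript=True)
model.eval()

# Build an example batch with the same shapes the serving handler will use.
inputs = tokenizer(['A placeholder sentence.'] * batch_size,
                   ['Another placeholder sentence.'] * batch_size,
                   max_length=max_length, padding='max_length',
                   truncation=True, return_tensors='pt')
example = (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids'])

traced = torch_neuronx.trace(model, example)    # compile for NeuronCores
torch.jit.save(traced, 'bert_neuron_b6.pt')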

You should now have a compiled bert_neuron_b6.pt file, which is required going forward.

Open a shell on the instance you prepared earlier and create a new directory named torchserve. Copy your compiled model from the previous tutorial into this new directory.

cd torchserve
ls
bert_neuron_b6.pt

Prepare a new Python virtual environment with the necessary Neuron and TorchServe components. Using a virtual environment keeps (most of) the tutorial components isolated from the rest of the system.

pip install transformers==4.26.0 torchserve==0.7.0 torch-model-archiver==0.7.0 captum==0.6.0

Install the system requirements for TorchServe. On Amazon Linux, use yum and select the Java 11 runtime with alternatives:

sudo yum -y install jq java-11-amazon-corretto-headless
sudo alternatives --config java
sudo alternatives --config javac

On Ubuntu, use apt:

sudo apt -y install openjdk-11-jdk

Confirm that Java 11 is active:

java -version
openjdk version "11.0.17" 2022-10-18
OpenJDK Runtime Environment (build 11.0.17+8-post-Ubuntu-1ubuntu218.04)
OpenJDK 64-Bit Server VM (build 11.0.17+8-post-Ubuntu-1ubuntu218.04, mixed mode, sharing)
javac -version
javac 11.0.17

Verify that TorchServe is now available.

torchserve --version
TorchServe Version is 0.7.0

Setup TorchServe#

During this tutorial you will need to download a few files onto your instance. The simplest way to accomplish this is to paste the download links provided above each file into a wget command. (We don’t provide the links directly because they are subject to change.) For example, right-click and copy the download link for config.json shown below.

{
    "model_name": "bert-base-cased-finetuned-mrpc",
    "max_length": 128,
    "batch_size": 6
}

Now execute the following in your shell:

wget <paste link here>
ls
bert_neuron_b6.pt  config.json

Download the custom handler script that will eventually respond to inference requests.

import os
import json
import sys
import logging
from abc import ABC

import torch
import torch_neuronx

from transformers import AutoTokenizer
from ts.torch_handler.base_handler import BaseHandler


# one core per worker
os.environ['NEURON_RT_NUM_CORES'] = '1'

logger = logging.getLogger(__name__)

class BertEmbeddingHandler(BaseHandler, ABC):
    """
    Handler class for Bert Embedding computations.
    """
    def __init__(self):
        super(BertEmbeddingHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        self.device = 'cpu'
        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)

        # load settings from the config.json bundled with the model archive
        with open('config.json') as fp:
            config = json.load(fp)
        self.max_length = config['max_length']
        self.batch_size = config['batch_size']
        self.classes = ['not paraphrase', 'paraphrase']

        self.model = torch.jit.load(model_pt_path)
        logger.debug(f'Model loaded from {model_dir}')
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
        self.initialized = True

    def preprocess(self, input_data):
        """
        Tokenization pre-processing
        """

        input_ids = []
        attention_masks = []
        token_type_ids = []
        for row in input_data:
            seq_0 = row['seq_0'].decode('utf-8')
            seq_1 = row['seq_1'].decode('utf-8')
            logger.debug(f'Received text: "{seq_0}", "{seq_1}"')

            inputs = self.tokenizer.encode_plus(
                    seq_0,
                    seq_1,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                    )

            input_ids.append(inputs['input_ids'])
            attention_masks.append(inputs['attention_mask'])
            token_type_ids.append(inputs['token_type_ids'])

        batch = (torch.cat(input_ids, 0),
                torch.cat(attention_masks, 0),
                torch.cat(token_type_ids, 0))

        return batch

    def inference(self, inputs):
        """
        Predict the class of a text using a trained transformer model.
        """

        # sanity check dimensions
        assert(len(inputs) == 3)
        num_inferences = len(inputs[0])
        assert(num_inferences <= self.batch_size)

        # insert padding if we received a partial batch
        padding = self.batch_size - num_inferences
        if padding > 0:
            pad = torch.nn.ConstantPad1d((0, 0, 0, padding), value=0)
            inputs = [pad(x) for x in inputs]

        outputs = self.model(*inputs)[0]
        predictions = []
        for i in range(num_inferences):
            prediction = self.classes[outputs[i].argmax(dim=-1).item()]
            predictions.append([prediction])
            logger.debug("Model predicted: '%s'", prediction)
        return predictions

    def postprocess(self, inference_output):
        return inference_output

Next, we need to associate the handler script with the compiled model using torch-model-archiver. Run the following commands in your terminal:

mkdir model_store
MAX_LENGTH=$(jq '.max_length' config.json)
BATCH_SIZE=$(jq '.batch_size' config.json)
MODEL_NAME=bert-max_length$MAX_LENGTH-batch_size$BATCH_SIZE
torch-model-archiver --model-name "$MODEL_NAME" --version 1.0 --serialized-file ./bert_neuron_b6.pt --handler "./handler_bert_neuronx.py" --extra-files "./config.json" --export-path model_store

Note

If you modify your model or a dependency, you will need to rerun the archiver command with the -f flag appended to update the archive.

The result of the above will be a .mar file inside the model_store directory.

$ ls model_store
bert-max_length128-batch_size6.mar

This file is essentially an archive associated with a fixed version of your model along with its dependencies (e.g. the handler code).
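Under the hood, a .mar file is a standard zip archive, so you can peek inside to verify that the model, handler, and extra files were bundled as expected. Below is a quick sketch using Python's zipfile module; the archive name assumes the MODEL_NAME built above.

# List the contents of the generated model archive (.mar files are zip archives).
import zipfile

with zipfile.ZipFile('model_store/bert-max_length128-batch_size6.mar') as archive:
    for name in archive.namelist():
        print(name)
# Expect to see bert_neuron_b6.pt, handler_bert_neuronx.py, config.json,
# and a MAR-INF/MANIFEST.json describing the archive.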

Note

The version specified in the torch-model-archiver command can be appended to REST API requests to access a specific version of your model. For example, if your model was hosted locally on port 8080 and named “bert”, the latest version of your model would be available at http://localhost:8080/predictions/bert, while version 1.0 would be accessible at http://localhost:8080/predictions/bert/1.0. We will see how to perform inference using this API later in this tutorial.
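As an illustration, once the model has been registered later in this tutorial, both the default and the version-pinned prediction URLs can be exercised with a few lines of requests. The model name and example payload below are assumptions that mirror the rest of this tutorial.

# Hit the default (latest) and the version-pinned prediction endpoints.
import requests

name = 'bert-max_length128-batch_size6'   # matches MODEL_NAME in this tutorial
payload = {'seq_0': 'The quick brown fox jumps over the lazy dog.',
           'seq_1': 'A fast brown fox leaps over a lazy dog.'}

latest = requests.post(f'http://localhost:8080/predictions/{name}', data=payload)
pinned = requests.post(f'http://localhost:8080/predictions/{name}/1.0', data=payload)
print(latest.json(), pinned.json())        # both should return the same prediction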

Create a custom config file named torchserve.config to set some parameters. This file will be used to configure the server at launch when we run torchserve --start.

# bind inference API to all network interfaces
inference_address=http://0.0.0.0:8080
default_workers_per_model=1

Note

This will cause TorchServe to bind on all interfaces. For security in real-world applications, you’ll probably want to use port 8443 and enable SSL.

Run TorchServe#

It’s time to start the server. Typically we’d want to launch this in a separate console, but for this demo we’ll just redirect output to a file.

torchserve --start --ncs --model-store model_store --ts-config torchserve.config > torchserve.log 2>&1

Verify that the server seems to have started okay.

curl http://127.0.0.1:8080/ping
{
  "status": "Healthy"
}

Note

If you get an error when trying to ping the server, you may have tried before the server was fully launched. Check torchserve.log for details.
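If you prefer to wait programmatically rather than retrying by hand, a small polling loop against the same ping endpoint avoids racing the server startup. A minimal sketch:

# Poll the TorchServe ping endpoint until it reports healthy (or give up).
import time
import requests

for attempt in range(30):
    try:
        status = requests.get('http://127.0.0.1:8080/ping', timeout=1).json()
        if status.get('status') == 'Healthy':
            print('TorchServe is up')
            break
    except requests.exceptions.ConnectionError:
        pass  # server not accepting connections yet
    time.sleep(2)
else:
    raise RuntimeError('TorchServe did not become healthy; check torchserve.log')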

Use the Management API to instruct TorchServe to load our model.

First, determine the number of NeuronCores available based on your instance size.

Inf2

Instance Size      # of NeuronCores
xlarge             2
8xlarge            2
24xlarge           12
48xlarge           24

Trn1

Instance Size      # of NeuronCores
2xlarge            2
32xlarge           32

MAX_BATCH_DELAY=5000 # ms timeout before a partial batch is processed
INITIAL_WORKERS=<number of NeuronCores from table above>
curl -X POST "http://localhost:8081/models?url=$MODEL_NAME.mar&batch_size=$BATCH_SIZE&initial_workers=$INITIAL_WORKERS&max_batch_delay=$MAX_BATCH_DELAY"
{
  "status": "Model \"bert-max_length128-batch_size6\" Version: 1.0 registered with X initial workers"
}

Warning

You shouldn’t set INITIAL_WORKERS above the number of NeuronCores. If you attempt to load more model workers than there are NeuronCores available, one of two things will occur: either the extra models will fit in device memory but performance will suffer, or you will encounter an error on your initial inference. However, you may want to use fewer cores if you are using the NeuronCore Pipeline feature.

Note

Any additional attempts to configure the model after the initial curl request will cause the server to return a 409 error. You’ll need to stop/start/configure the server to realize any changes.

The MAX_BATCH_DELAY is a timeout value that determines how long to wait before processing a partial batch. This is why the handler code needs to check the batch dimension and potentially add padding. TorchServe will instantiate the number of model handlers indicated by INITIAL_WORKERS, so this value controls how many copies of the model we load onto the NeuronCores in parallel. If you want to control worker scaling more dynamically, see the docs.
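To confirm what the server actually registered (batch size, max batch delay, and the assigned workers), you can describe the model through the Management API on port 8081. A minimal sketch using requests; the model name assumes the MODEL_NAME built earlier.

# Describe the registered model via the Management API (port 8081).
import json
import requests

name = 'bert-max_length128-batch_size6'
info = requests.get(f'http://localhost:8081/models/{name}').json()
print(json.dumps(info, indent=2))  # includes batchSize, maxBatchDelay, and the worker list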

It looks like everything is running successfully at this point, so it’s time for an inference.

Create the infer_bert.py file below on your instance.

import json
import concurrent.futures
import requests

with open('config.json') as fp:
    config = json.load(fp)
max_length = config['max_length']
batch_size = config['batch_size']
name = f'bert-max_length{max_length}-batch_size{batch_size}'

# dispatch requests in parallel
url = f'http://localhost:8080/predictions/{name}'
paraphrase = {'seq_0': "HuggingFace's headquarters are situated in Manhattan",
        'seq_1': "The company HuggingFace is based in New York City"}
not_paraphrase = {'seq_0': paraphrase['seq_0'], 'seq_1': 'This is total nonsense.'}

with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
    def worker_thread(worker_index):
        # we'll send half the requests as not_paraphrase examples for sanity
        data = paraphrase if worker_index < batch_size//2 else not_paraphrase
        response = requests.post(url, data=data)
        print(worker_index, response.json())

    for worker_index in range(batch_size):
        executor.submit(worker_thread, worker_index)

This script will send a batch_size number of requests to our model. In this example, we are using a model that estimates the probability that one sentence is a paraphrase of another. The script sends positive examples in the first half of the batch and negative examples in the second half.

Execute the script in your terminal.

$ python infer_bert.py
1 ['paraphrase']
3 ['not paraphrase']
4 ['not paraphrase']
0 ['paraphrase']
5 ['not paraphrase']
2 ['paraphrase']

We can see that the first three threads (0, 1, 2) all report paraphrase, as expected. If we instead modify the script to send an incomplete batch and then wait for the timeout to expire, the excess padding results will be discarded.
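For example, a quick way to exercise the partial-batch path is to submit fewer requests than the batch size and let MAX_BATCH_DELAY expire before the results come back. A minimal sketch reusing the names from infer_bert.py:

# Send an incomplete batch; the server waits up to MAX_BATCH_DELAY (5000 ms here)
# before running inference, and the handler's padding rows are never returned.
import concurrent.futures
import requests

url = 'http://localhost:8080/predictions/bert-max_length128-batch_size6'
data = {'seq_0': "HuggingFace's headquarters are situated in Manhattan",
        'seq_1': 'The company HuggingFace is based in New York City'}

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    pending = [executor.submit(requests.post, url, data=data) for _ in range(2)]
    for future in pending:
        print(future.result().json())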

Benchmark TorchServe#

We’ve seen how to perform a single batched inference, but how many inferences can we process per second? A separate upcoming tutorial will document performance tuning to maximize throughput. In the meantime, we can still perform a simple naïve stress test. The code below will spawn 64 worker threads, with each thread repeatedly sending a full batch of data to process. A separate thread will periodically print throughput and latency measurements.

import os
import argparse
import time
import numpy as np
import requests
import sys
from concurrent import futures

import torch


parser = argparse.ArgumentParser()
parser.add_argument('--url', help='Torchserve model URL', type=str, default='http://127.0.0.1:8080/predictions/bert-max_length128-batch_size6')
parser.add_argument('--num_thread', type=int, default=64, help='Number of threads invoking the model URL')
parser.add_argument('--batch_size', type=int, default=6)
parser.add_argument('--sequence_length', type=int, default=128)
parser.add_argument('--latency_window_size', type=int, default=1000)
parser.add_argument('--throughput_time', type=int, default=300)
parser.add_argument('--throughput_interval', type=int, default=10)
args = parser.parse_args()

data = { 'seq_0': 'A completely made up sentence.',
    'seq_1': 'Well, I suppose they are all made up.' }
live = True
num_infer = 0
latency_list = []


def one_thread(pred, feed_data):
    global latency_list
    global num_infer
    global live
    session = requests.Session()
    while True:
        start = time.time()
        result = session.post(pred, data=feed_data)
        latency = time.time() - start
        latency_list.append(latency)
        num_infer += 1
        if not live:
            break


def current_performance():
    global live
    last_num_infer = num_infer
    for _ in range(args.throughput_time // args.throughput_interval):
        current_num_infer = num_infer
        throughput = (current_num_infer - last_num_infer) / args.throughput_interval
        p50 = 0.0
        p90 = 0.0
        if latency_list:
            p50 = np.percentile(latency_list[-args.latency_window_size:], 50)
            p90 = np.percentile(latency_list[-args.latency_window_size:], 90)
        print('pid {}: current throughput {}, latency p50={:.3f} p90={:.3f}'.format(os.getpid(), throughput, p50, p90))
        sys.stdout.flush()
        last_num_infer = current_num_infer
        time.sleep(args.throughput_interval)
    live = False


with futures.ThreadPoolExecutor(max_workers=args.num_thread+1) as executor:
    executor.submit(current_performance)
    for _ in range(args.num_thread):
        executor.submit(one_thread, args.url, data)

Run the benchmarking script.

python benchmark_bert.py
pid 1214554: current throughput 0.0, latency p50=0.000 p90=0.000
pid 1214554: current throughput 713.9, latency p50=0.071 p90=0.184
pid 1214554: current throughput 737.9, latency p50=0.071 p90=0.184
pid 1214554: current throughput 731.6, latency p50=0.068 p90=0.192
pid 1214554: current throughput 732.2, latency p50=0.070 p90=0.194
pid 1214554: current throughput 733.9, latency p50=0.070 p90=0.187
pid 1214554: current throughput 739.3, latency p50=0.071 p90=0.184
...

Note

Your throughput numbers may differ from these based on instance type and size.

Congratulations! By now you should have successfully served a batched model over TorchServe.

You can now shut down TorchServe.

torchserve --stop

This document is relevant for: Inf2, Trn1, Trn1n