Using Data Parallel Mode with Gluon MXNet#

In this tutorial, you will compile a Gluon BERT model and run it in data-parallel mode to fully utilize the NeuronCores. You will benchmark a multi-worker setup and compare it with a single-worker setup.

This tutorial is intended only for MXNet-1.8.

In this tutorial, we will be using an inf1.2xlarge with the latest AWS Deep Learning AMI (DLAMI). The inf1.2xlarge instance has 1 AWS Inferentia Chip with 4 NeuronCores.

Setting up your environment#

To run this tutorial, please make sure you deactivate any existing MXNet conda environments you are already using. Install MXNet 1.8 by following the instructions at the MXNet Setup Guide. You will also need to change your kernel to use the correct Python environment set up earlier by clicking Kernel->Change Kernel->Python (Neuron MXNet).
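You can also verify that the Inferentia device and its NeuronCores are visible from the instance by running the neuron-ls tool. This assumes the Neuron system tools that ship with the DLAMI are installed; the exact output format may vary between Neuron releases.

[ ]:
!neuron-ls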

Install dependencies#

We need to install gluonnlp to get the BERT model. Run the following command to install it:

[ ]:
!python -m pip install gluonnlp

Compiling BERT Model#

Next, we compile the Gluon BERT model and save it. Once the model is compiled, we use the same compiled model across the entire tutorial. In this tutorial, we will be using a BERT model with sequence length 32.

[ ]:
import os
import mxnet as mx
import mx_neuron
import gluonnlp as nlp
[ ]:
BERT_MODEL = 'bert_12_768_12'
BERT_DATA = 'book_corpus_wiki_en_uncased'
batch_size = 1
seq_len = 32
num_cores = 1
dtype = 'float32'

compiled_model_path = '{}.compiled.{}.{}'.format(BERT_MODEL, batch_size, seq_len)

model, vocab = nlp.model.get_model(BERT_MODEL,
                                   dataset_name=BERT_DATA,
                                   use_classifier=False,
                                   use_decoder=False, ctx=mx.cpu())

# Create sample inputs for compilation
words = mx.nd.ones([batch_size, seq_len], name='words', dtype=dtype)
valid_len = mx.nd.ones([batch_size,], name='valid_len', dtype=dtype)
segments = mx.nd.ones([batch_size, seq_len], name='segments', dtype=dtype)
inputs = {'data0': words, 'data1': segments, 'data2': valid_len}

# Compiler args
options = {}
embeddingNames = ['bertmodel0_word_embed_embedding0_fwd', 'bertmodel0_token_type_embed_embedding0_fwd', 'bertencoder0_embedding0']
options.update({'force_incl_node_names': embeddingNames})
options.update({'flags': ['--fp32-cast matmult']})

# Compile for Neuron and save the compiled model
model = mx_neuron.compile(model, inputs=inputs, **options)
model.export(compiled_model_path)
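Optionally, you can sanity-check the exported artifacts by loading them back and running a single forward pass before benchmarking. The cell below is a minimal sketch, assuming the standard Gluon export layout (model.export(prefix) writes prefix-symbol.json and prefix-0000.params) and that mx_neuron is already imported so the Neuron-specific operators are available; it reuses the sample inputs created above, and the input names data0/data1/data2 mirror the dictionary passed to the compiler.

[ ]:
# Minimal sanity check (sketch): reload the compiled model and run one inference.
block = mx.gluon.SymbolBlock.imports(compiled_model_path + '-symbol.json',
                                     ['data0', 'data1', 'data2'],
                                     compiled_model_path + '-0000.params',
                                     ctx=mx.cpu())

out = block(words, segments, valid_len)  # same order as data0, data1, data2
# BERT without decoder/classifier typically returns the sequence encoding (and pooled output)
if isinstance(out, (list, tuple)):
    print([o.shape for o in out])
else:
    print(out.shape)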

Data Parallel Mode#

Data Parallel Mode is a setup in which you launch multiple copies of the same model, such that each model runs independently of the others. In other words, each model has its own resources to run inference.

On an inf1.2xlarge instance, we have 4 NeuronCores. Hence, we can launch 4 models such that each model is loaded on a single NeuronCore. This enables us to process 4 requests concurrently without a linear increase in latency. As a result, the throughput of the system increases compared to single-model inference. It also allows us to utilize all 4 NeuronCores on the instance.

Run through the next set of cells to see the difference in throughput as we scale from one model to 4 models running in parallel.

[ ]:
import numpy as np

def get_sample_inputs(batch_size, seq_len):
    words = np.ones([batch_size, seq_len], dtype=np.float32)
    valid_len = np.ones([batch_size,], dtype=np.float32)
    segments = np.ones([batch_size, seq_len], dtype=np.float32)
    inputs = {'data0': words, 'data1': segments, 'data2': valid_len}
    return inputs

Next, for comparison purposes, we run the setup with 1 worker. To do this, we set num_cores=1, which launches only 1 model running on a single NeuronCore. After running the cell below, note down the latency and throughput of the system.

[ ]:
from parallel import NeuronSimpleDataParallel
from benchmark_utils import Results
import time
import functools
import os
import numpy as np
import warnings

num_cores = 1
batch_size = 1

# Each worker process should use one NeuronCore
os.environ["NEURON_RT_NUM_CORES"] = "1"

# Result aggregation class (code in benchmark_utils.py)
results = Results(batch_size, num_cores)
def result_handler(output, start, end):
    elapsed = end - start
    results.add_result([elapsed], [end], [start])

inputs = get_sample_inputs(batch_size, seq_len)
parallel_neuron_model = NeuronSimpleDataParallel(compiled_model_path, num_cores, inputs)

# Start the inference workers
parallel_neuron_model.start_continuous_inference()

# Warm up the cores
for _ in range(num_cores*4):
    parallel_neuron_model.warmup(inputs)

# Run a large number of iterations to benchmark the model
for _ in range(1000):
    parallel_neuron_model.infer(inputs)
    # Passing the result_handler as a callback function
    parallel_neuron_model.add_result(result_handler)

# Stop inference
parallel_neuron_model.stop()
# Since we are using multi-process execution with a shared queue, some inferences
# may still be in flight. We need to wait until all inputs are processed;
# add_all_results() collects the results of any such outstanding requests.
parallel_neuron_model.add_all_results(result_handler)


with open("benchmark.txt", "w") as f:
    results.report(f, window_size=1)

with open("benchmark.txt", "r") as f:
    for line in f:
        print(line)

Now we run the setup with 4 workers. To do this, we set num_cores=4. This launches 4 models, each running on its own NeuronCore. All 4 models run in separate processes; in other words, the models run in parallel.

To feed the models efficiently, we use a producer-consumer setup in which each process running a model acts as a consumer. All consumers are fed from a shared input queue.
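The rough idea can be sketched as follows. This is a simplified illustration, not the actual parallel.py implementation: the worker function, queue handling, core pinning via NEURON_RT_NUM_CORES inside the child process, and the sentinel value are assumptions made for the sketch.

[ ]:
# Simplified producer-consumer sketch (illustration only; not the actual parallel.py code).
import multiprocessing as mp
import os
import time

def worker(model_path, input_queue, output_queue):
    # Each consumer process restricts itself to a single NeuronCore before loading the model.
    os.environ["NEURON_RT_NUM_CORES"] = "1"
    import mxnet as mx
    import mx_neuron  # noqa: F401 (registers the Neuron operators)
    block = mx.gluon.SymbolBlock.imports(model_path + '-symbol.json',
                                         ['data0', 'data1', 'data2'],
                                         model_path + '-0000.params',
                                         ctx=mx.cpu())
    while True:
        item = input_queue.get()
        if item is None:  # sentinel: no more work
            break
        start = time.time()
        block(mx.nd.array(item['data0']),
              mx.nd.array(item['data1']),
              mx.nd.array(item['data2']))
        mx.nd.waitall()  # block until the asynchronous inference has completed
        output_queue.put((start, time.time()))

# The producer puts requests on the shared input queue and the consumer
# processes drain it in parallel, e.g.:
#   in_q, out_q = mp.Queue(), mp.Queue()
#   consumers = [mp.Process(target=worker, args=(compiled_model_path, in_q, out_q))
#                for _ in range(4)]

The NeuronSimpleDataParallel helper used in these cells wraps this pattern behind start_continuous_inference(), infer(), add_result(), and stop().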

Now run the setup below. You may notice that the throughput increases by more than 2x compared to the single-worker setup.

[ ]:
from parallel import NeuronSimpleDataParallel
from benchmark_utils import Results
import time
import functools
import os
import numpy as np

num_cores = 4
batch_size = 1

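# Each worker process should use one NeuronCore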
os.environ["NEURON_RT_NUM_CORES"] = "1"

# Result aggregation class (code in benchmark_utils.py)
results = Results(batch_size, num_cores)
def result_handler(output, start, end):
    elapsed = end - start
    results.add_result([elapsed], [end], [start])

inputs = get_sample_inputs(batch_size, seq_len)
parallel_neuron_model = NeuronSimpleDataParallel(compiled_model_path, num_cores, inputs)

# Start the inference workers
parallel_neuron_model.start_continuous_inference()

# Warm up the cores
for _ in range(num_cores*4):
    parallel_neuron_model.warmup(inputs)

# Run a large number of iterations to benchmark the models
for _ in range(5000):
    parallel_neuron_model.infer(inputs)
    # Passing the result_handler as a callback function
    parallel_neuron_model.add_result(result_handler)

# Stop inference
parallel_neuron_model.stop()
# Since we are using multi-process execution with a shared queue, some inferences
# may still be in flight. We need to wait until all inputs are processed;
# add_all_results() collects the results of any such outstanding requests.
parallel_neuron_model.add_all_results(result_handler)


with open("benchmark.txt", "w") as f:
    results.report(f, window_size=1)

with open("benchmark.txt", "r") as f:
    for line in f:
        print(line)