Running ResNet50 on Inferentia#

Note: this tutorial runs on tensorflow-neuron 1.x only#

Introduction:#

In this tutorial we will compile and deploy a ResNet50 model for Inferentia. The tutorial has two main sections: 1. Compile the ResNet50 model. 2. Run inference with the compiled model.

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the Tensorflow Installation Guide. You can select the kernel from the “Kernel -> Change Kernel” option at the top of this Jupyter notebook page.

Instructions for setting up the Neuron TensorFlow environment and running this tutorial as a Jupyter notebook are available in the Tensorflow Quick Setup.

[ ]:
!pip install tensorflow_neuron==1.15.5.2.8.9.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com/
!pip install neuron_cc==1.13.5.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com
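
To confirm that the notebook kernel is picking up the packages installed above, a quick version check can help. This is a minimal sketch; the exact versions printed will depend on your installation:

[ ]:
# Sanity-check the environment (output depends on your installation)
import tensorflow as tf
import tensorflow.neuron as tfn  # fails if tensorflow-neuron is not installed in this kernel
print('TensorFlow version:', tf.__version__)
!pip list | grep -i neuron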

Compile for Neuron#

A trained model must be compiled to an Inferentia target before it can be deployed on Inferentia instances. In this step we compile the Keras ResNet50 model and export it as a SavedModel, an interchange format for TensorFlow models. At the end of compilation, the compiled SavedModel is saved in the local directory ws_resnet50/resnet50_neuron:

[ ]:
import os
import time
import shutil
import tensorflow as tf
import tensorflow.neuron as tfn
import tensorflow.compat.v1.keras as keras
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input

# Create a workspace
WORKSPACE = './ws_resnet50'
os.makedirs(WORKSPACE, exist_ok=True)

# Prepare export directory (old one removed)
model_dir = os.path.join(WORKSPACE, 'resnet50')
compiled_model_dir = os.path.join(WORKSPACE, 'resnet50_neuron')
shutil.rmtree(model_dir, ignore_errors=True)
shutil.rmtree(compiled_model_dir, ignore_errors=True)

# Instantiate Keras ResNet50 model
keras.backend.set_learning_phase(0)
keras.backend.set_image_data_format('channels_last')

model = ResNet50(weights='imagenet')

# Export SavedModel
tf.saved_model.simple_save(
    session            = keras.backend.get_session(),
    export_dir         = model_dir,
    inputs             = {'input': model.inputs[0]},
    outputs            = {'output': model.outputs[0]})

# Compile using Neuron
tfn.saved_model.compile(model_dir, compiled_model_dir)

[ ]:
!ls
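
The input and output tensor names of the compiled SavedModel are the keys used in the inference feed dict later in this tutorial. They can be confirmed with the saved_model_cli tool that ships with TensorFlow; a minimal example:

[ ]:
# Show the signatures of the compiled SavedModel, including input/output tensor names
!saved_model_cli show --dir ./ws_resnet50/resnet50_neuron --all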

Deploy on Inferentia#

In this tutorial the model is deployed on the same instance that was used for compilation. If you are deploying on a separate instance, launch a deployment Inf1 instance and copy the compiled model directory to it, for example as shown below.
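
A minimal sketch of copying the compiled SavedModel to a separate deployment instance with standard tools; the key file my-key.pem and the host placeholder <inf1-host> are illustrative values, not part of this tutorial:

[ ]:
# Archive the compiled SavedModel and copy it to the deployment Inf1 instance
# (my-key.pem and <inf1-host> are placeholders for your own key and hostname)
!tar -czf resnet50_neuron.tgz -C ./ws_resnet50 resnet50_neuron
!scp -i my-key.pem resnet50_neuron.tgz ubuntu@<inf1-host>:~/
# On the deployment instance, unpack with: tar -xzf resnet50_neuron.tgz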

Download the example image, and install the Pillow module, which is needed for loading images on the deployment instance.

[ ]:
!curl -O https://raw.githubusercontent.com/awslabs/mxnet-model-server/master/docs/images/kitten_small.jpg
!pip install pillow  # Necessary for loading images

After downloading the example image, run inference.#

[ ]:
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import resnet50

tf.keras.backend.set_image_data_format('channels_last')

# Create input from image
img_sgl = image.load_img('kitten_small.jpg', target_size=(224, 224))
img_arr = image.img_to_array(img_sgl)
img_arr2 = np.expand_dims(img_arr, axis=0)
img_arr3 = resnet50.preprocess_input(img_arr2)

# Load model
COMPILED_MODEL_DIR = './ws_resnet50/resnet50_neuron/'
predictor_inferentia = tf.contrib.predictor.from_saved_model(COMPILED_MODEL_DIR)

# Run inference
model_feed_dict = {'input': img_arr3}
infa_rslts = predictor_inferentia(model_feed_dict)

# Display results
print(resnet50.decode_predictions(infa_rslts["output"], top=5)[0])

# Sample output will look similar to the following:
#[('n02123045', 'tabby', 0.68817204), ('n02127052', 'lynx', 0.12701613), ('n02123159', 'tiger_cat', 0.08736559), ('n02124075', 'Egyptian_cat', 0.063844085), ('n02128757', 'snow_leopard', 0.009240591)]
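
As a rough follow-up, the same predictor can be timed over repeated requests to estimate per-request latency. This is a minimal sketch reusing predictor_inferentia and model_feed_dict from the cell above; the numbers will vary with instance type and software versions:

[ ]:
import time

# Warm-up request so one-time initialization is not included in the measurement
predictor_inferentia(model_feed_dict)

# Time a fixed number of sequential requests and report the average latency
num_requests = 100
start = time.time()
for _ in range(num_requests):
    predictor_inferentia(model_feed_dict)
elapsed = time.time() - start
print('Average latency: {:.2f} ms per request'.format(elapsed / num_requests * 1000))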