{ "cells": [ { "cell_type": "markdown", "id": "caff04ba", "metadata": {}, "source": [ "# Running OpenPose on Inferentia\n" ] }, { "cell_type": "markdown", "id": "09b2919a", "metadata": {}, "source": [ "## Note: this tutorial runs on tensorflow-neuron 1.x only" ] }, { "cell_type": "markdown", "id": "4dcf9bb1", "metadata": {}, "source": [ "## Introduction:\n", "\n", "In this tutorial we will compile and deploy Openpose model for Inferentia. This jupyter notebook should run on an inf1.6xlarge instance for compilation and inference. The inference part of this tutorial requires inf1.6xlarge and not the compilation itself. For simplicity we will run this tutorial on a single instance but in real life scenario the compilation can be done on a compute c5.4xlarge instance and the deployment on the inf1 instance family.\n", "\n", "In this tutorial we provide two main sections:\n", "1. Compile the OpenPose model on inf1x6large.\n", "2. Infer the same compiled model on inf1x6large.\n", "\n", "Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the [Tensorflow Installation Guide](../../../../frameworks/tensorflow/tensorflow-neuron/setup/tensorflow-install.html#install-neuron-tensorflow). You can select the Kernel from the “Kernel -> Change Kernel” option on the top of this Jupyter notebook page.\n" ] }, { "cell_type": "markdown", "id": "04ae0838", "metadata": {}, "source": [ "## Acknowledgement:\n", "\n", "Many thanks to https://github.com/ildoonet for providing pretrained model as well as the image preprocessing/pose estimating infrastructure." ] }, { "cell_type": "markdown", "id": "d0d6d08e", "metadata": {}, "source": [ "## Download tensorflow pose net frozen graph." ] }, { "cell_type": "code", "execution_count": null, "id": "1926d4e3", "metadata": { "scrolled": false }, "outputs": [], "source": [ "!wget -c --tries=2 $( wget -q -O - http://www.mediafire.com/file/qlzzr20mpocnpa3/graph_opt.pb | grep -o 'http*://download[^\"]*' | tail -n 1 ) -O graph_opt.pb\n", "\n", "!pip install tensorflow_neuron==1.15.5.2.8.9.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com/\n", "!pip install neuron_cc==1.13.5.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com" ] }, { "cell_type": "markdown", "id": "83eb578b", "metadata": {}, "source": [ "## Compile\n", "Compile the pose net frozen graph into AWS Neuron compatible form. Network input image resolution is adjustable with argument --net_resolution (e. g., --net_resolution=656x368). The compiled model can accept arbitrary batch size input at runtime." ] }, { "cell_type": "code", "execution_count": null, "id": "362f322e", "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "Usage: python convert_graph_opt.py /path/to/graph_opt.pb /path/to/graph_opt_neuron.pb\n", "\"\"\"\n", "#import argparse\n", "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto\n", "import tensorflow.neuron as tfn\n", "\n", "\n", "def compile():\n", " #parser = argparse.ArgumentParser()\n", " #parser.add_argument('input_pb_path', help='Input serialized GraphDef protobuf')\n", " #parser.add_argument('output_pb_path', help='Ouput serialized GraphDef protobuf')\n", " #parser.add_argument('--net_resolution', default='656x368', help='Network resolution in WxH format, e. 
{ "cell_type": "markdown", "id": "83eb578b", "metadata": {}, "source": [ "## Compile\n", "Compile the pose net frozen graph into an AWS Neuron compatible form. The network input image resolution is adjustable through the net_resolution variable in the script below (e.g., net_resolution = '656x368'). The compiled model can accept input of arbitrary batch size at runtime." ] }, { "cell_type": "code", "execution_count": null, "id": "362f322e", "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "Usage: python convert_graph_opt.py /path/to/graph_opt.pb /path/to/graph_opt_neuron.pb\n", "\"\"\"\n", "# The original script parsed these settings from the command line; they are\n", "# hard-coded below for notebook use.\n", "#import argparse\n", "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto\n", "import tensorflow.neuron as tfn\n", "\n", "\n", "def compile():\n", "    #parser = argparse.ArgumentParser()\n", "    #parser.add_argument('input_pb_path', help='Input serialized GraphDef protobuf')\n", "    #parser.add_argument('output_pb_path', help='Output serialized GraphDef protobuf')\n", "    #parser.add_argument('--net_resolution', default='656x368', help='Network resolution in WxH format, e.g., --net_resolution=656x368')\n", "    #parser.add_argument('--debug_verify', action='store_true')\n", "    #args = parser.parse_args()\n", "\n", "    input_pb_path = './graph_opt.pb'\n", "    net_resolution = '656x368'\n", "    output_pb_path = './graph_opt_neuron_' + net_resolution + '.pb'\n", "\n", "    debug_verify = True  # verify the compiled graph against a CPU reference run\n", "    dim_w, dim_h = net_resolution.split('x')\n", "    dim_w = int(dim_w)\n", "    dim_h = int(dim_h)\n", "    graph_def = tf.GraphDef()\n", "    with open(input_pb_path, 'rb') as f:\n", "        graph_def.ParseFromString(f.read())\n", "\n", "    if debug_verify:\n", "        # Run the unmodified graph on CPU to obtain a reference output.\n", "        np.random.seed(0)\n", "        feed_dict = {'image:0': np.random.rand(1, dim_h, dim_w, 3)}\n", "        output_name = 'Openpose/concat_stage7:0'\n", "        with tf.Session(graph=tf.Graph()) as sess:\n", "            tf.import_graph_def(graph_def, name='')\n", "            result_reference = sess.run(output_name, feed_dict)\n", "\n", "    # Keep the preprocessing ops in float32/NHWC; everything else is converted.\n", "    preprocessing_ops = {'preprocess_divide', 'preprocess_divide/y', 'preprocess_subtract', 'preprocess_subtract/y'}\n", "    graph_def = nhwc_to_nchw(graph_def, preprocessing_ops)\n", "    graph_def = inline_float32_to_float16(graph_def, preprocessing_ops)\n", "    with tf.Session(graph=tf.Graph()) as sess:\n", "        tf.import_graph_def(graph_def, name='')\n", "        no_fuse_ops = preprocessing_ops.union({'Openpose/concat_stage7'})\n", "        infer_graph = tfn.graph_util.inference_graph_from_session(\n", "            sess, shape_feed_dict={'image:0': [1, dim_h, dim_w, 3]}, output_tensors=['Openpose/concat_stage7:0'],\n", "            no_fuse_ops=no_fuse_ops, dynamic_batch_size=True,\n", "        )\n", "    with open(output_pb_path, 'wb') as f:\n", "        f.write(infer_graph.as_graph_def().SerializeToString())\n", "\n", "    if debug_verify:\n", "        with tf.Session(graph=infer_graph) as sess:\n", "            result_compiled = sess.run(output_name, feed_dict)\n", "        np.testing.assert_allclose(result_compiled, result_reference, rtol=1e-2, atol=1e-3)\n", "\n", "\n", "def inline_float32_to_float16(graph_def, preprocessing_ops):\n", "    \"\"\"Rewrite float32 constants and op dtypes to float16, inserting Cast ops\n", "    so that the graph keeps a float32 interface.\"\"\"\n", "    float32_enum = tf.float32.as_datatype_enum\n", "    float16_enum = tf.float16.as_datatype_enum\n", "    graph = tf.Graph()\n", "    with graph.as_default():\n", "        tf.import_graph_def(graph_def, name='')\n", "    graph_def = graph.as_graph_def()\n", "    for node in graph_def.node:\n", "        if node.name in preprocessing_ops or node.op == 'Placeholder':\n", "            cast_input_node_name = node.name\n", "            continue\n", "        if node.op == 'Const':\n", "            if node.attr['dtype'].type == float32_enum:\n", "                node.attr['dtype'].type = float16_enum\n", "                tensor_def = node.attr['value'].tensor\n", "                tensor_def.dtype = float16_enum\n", "                if tensor_def.tensor_content:\n", "                    const_np = np.frombuffer(tensor_def.tensor_content, dtype=np.float32).astype(np.float16)\n", "                    tensor_def.tensor_content = const_np.tobytes()\n", "                elif len(tensor_def.float_val):\n", "                    const_np = np.array(tensor_def.float_val).astype(np.float16).view(np.uint16)\n", "                    tensor_def.float_val[:] = []\n", "                    tensor_def.half_val[:] = list(const_np)\n", "                else:\n", "                    raise NotImplementedError\n", "        elif 'T' in node.attr and node.attr['T'].type == float32_enum:\n", "            node.attr['T'].type = float16_enum\n", "    for node in graph_def.node:\n", "        if node.name == cast_input_node_name:\n", "            node.name = '{}_PreCastFloat32ToFloat16'.format(node.name)\n", "            input_node = node\n", "            break\n", "    cast_input_node = _gen_cast_node_def(cast_input_node_name, tf.float16, input_node)\n", "\n", "    output_node = graph_def.node[-1]\n", "    cast_output_node_name = output_node.name\n", "    output_node.name = '{}_PreCastFloat16ToFloat32'.format(output_node.name)\n", "    cast_output_node = _gen_cast_node_def(cast_output_node_name, tf.float32, output_node)\n",
"\n", "    preprocessing_ops.add(input_node.name)\n", "    new_graph_def = tf.GraphDef()\n", "    new_graph_def.node.extend(graph_def.node)\n", "    new_graph_def.node.append(cast_input_node)\n", "    new_graph_def.node.append(cast_output_node)\n", "    graph = tf.Graph()\n", "    with graph.as_default():\n", "        tf.import_graph_def(new_graph_def, name='')\n", "    return graph.as_graph_def()\n", "\n", "\n", "def nhwc_to_nchw(graph_def, preprocessing_ops):\n", "    \"\"\"Rewrite the graph from NHWC to NCHW data format, inserting Transpose ops\n", "    so that the graph keeps an NHWC interface.\"\"\"\n", "    graph = tf.Graph()\n", "    with graph.as_default():\n", "        tf.import_graph_def(graph_def, name='')\n", "    graph_def = graph.as_graph_def()\n", "    node_name_to_node = {node.name: node for node in graph_def.node}\n", "    for node in graph_def.node:\n", "        if node.name in preprocessing_ops or node.op == 'Placeholder':\n", "            transpose_input_node_name = node.name\n", "            continue\n", "        if node.op == 'Conv2D':\n", "            node.attr['data_format'].s = b'NCHW'\n", "            strides = node.attr['strides'].list.i\n", "            strides[:] = [strides[0], strides[3], strides[1], strides[2]]\n", "        elif node.op == 'BiasAdd':\n", "            if node.name != 'probs/BiasAdd':\n", "                node.attr['data_format'].s = b'NCHW'\n", "        elif node.op == 'MaxPool':\n", "            node.attr['data_format'].s = b'NCHW'\n", "            ksize = node.attr['ksize'].list.i\n", "            ksize[:] = [ksize[0], ksize[3], ksize[1], ksize[2]]\n", "            strides = node.attr['strides'].list.i\n", "            strides[:] = [strides[0], strides[3], strides[1], strides[2]]\n", "        elif node.op in {'Concat', 'ConcatV2'}:\n", "            # The channel axis moves from 3 (NHWC) to 1 (NCHW).\n", "            node_axes = node_name_to_node[node.input[-1]]\n", "            node_axes.attr['value'].tensor.int_val[:] = [1]\n", "    for node in graph_def.node:\n", "        if node.name == transpose_input_node_name:\n", "            node.name = '{}_PreTransposeNHWC2NCHW'.format(node.name)\n", "            input_node = node\n", "            break\n", "    transpose_input_node, transpose_input_perm_node = _gen_transpose_def(transpose_input_node_name, [0, 3, 1, 2], input_node)\n", "\n", "    output_node = graph_def.node[-1]\n", "    transpose_output_node_name = output_node.name\n", "    output_node.name = '{}_PreTransposeNCHW2NHWC'.format(output_node.name)\n", "    transpose_output_node, transpose_output_perm_node = _gen_transpose_def(transpose_output_node_name, [0, 2, 3, 1], output_node)\n", "\n", "    preprocessing_ops.add(input_node.name)\n", "    preprocessing_ops.add(transpose_input_perm_node.name)\n", "    new_graph_def = tf.GraphDef()\n", "    new_graph_def.node.extend(graph_def.node)\n", "    new_graph_def.node.append(transpose_input_perm_node)\n", "    new_graph_def.node.append(transpose_input_node)\n", "    new_graph_def.node.append(transpose_output_perm_node)\n", "    new_graph_def.node.append(transpose_output_node)\n", "    graph = tf.Graph()\n", "    with graph.as_default():\n", "        tf.import_graph_def(new_graph_def, name='')\n", "    return graph.as_graph_def()\n", "\n", "\n", "def _gen_cast_node_def(name, target_dtype, input_node):\n", "    cast_node = tf.NodeDef(name=name, op='Cast')\n", "    cast_node.input.append(input_node.name)\n", "    cast_node.attr['DstT'].type = target_dtype.as_datatype_enum\n", "    cast_node.attr['SrcT'].type = input_node.attr['T'].type\n", "    cast_node.attr['Truncate'].b = False\n", "    return cast_node\n", "\n", "\n", "def _gen_transpose_def(name, perm, input_node):\n", "    perm_node = tf.NodeDef(name='{}/perm'.format(name), op='Const')\n", "    perm_node.attr['dtype'].type = tf.int32.as_datatype_enum\n", "    tensor_def = perm_node.attr['value'].tensor\n", "    tensor_def.dtype = tf.int32.as_datatype_enum\n", "    tensor_def.tensor_shape.dim.append(TensorShapeProto.Dim(size=4))\n", "    tensor_def.tensor_content = np.array(perm, dtype=np.int32).tobytes()\n", "    transpose_node = tf.NodeDef(name=name, op='Transpose')\n", "    transpose_node.input.append(input_node.name)\n", "    transpose_node.input.append(perm_node.name)\n", "    transpose_node.attr['T'].type = input_node.attr['T'].type\n", "    transpose_node.attr['Tperm'].type = tf.int32.as_datatype_enum\n", "    return transpose_node, perm_node\n" ] }, { "cell_type": "code", "execution_count": null, "id": "88c41e01", "metadata": { "scrolled": true }, "outputs": [], "source": [ "compile()\n", "\n", "# Sample output will look like below:\n", "# WARNING:tensorflow:From :47: inference_graph_from_session (from tensorflow_neuron.python.graph_util) is deprecated and will be removed in a future version.\n", "# Instructions for updating:\n", "# Please refer to AWS documentation on Neuron integrated TensorFlow 2.0.\n", "# INFO:tensorflow:Froze 0 variables.\n", "# INFO:tensorflow:Converted 0 variables to const ops.\n", "# INFO:tensorflow:fusing subgraph {subgraph neuron_op_ed41d2deb8c54255 with input tensors [\"\"], output tensors [\"\"]} with neuron-cc\n", "# INFO:tensorflow:Number of operations in TensorFlow session: 474\n", "# INFO:tensorflow:Number of operations after tf.neuron optimizations: 474\n", "# INFO:tensorflow:Number of operations placed on Neuron runtime: 465" ] },
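{ "cell_type": "markdown", "id": "b1f3a003", "metadata": {}, "source": [ "As an optional check, you can reload the compiled frozen graph and count the fused Neuron subgraphs. The cell below is a minimal sketch, not part of the original tutorial; it assumes compile() wrote ./graph_opt_neuron_656x368.pb and relies on tensorflow-neuron 1.x representing each fused subgraph as a NeuronOp node." ] }, { "cell_type": "code", "execution_count": null, "id": "b1f3a004", "metadata": {}, "outputs": [], "source": [ "# Sketch: reload the compiled graph and count fused Neuron subgraphs.\n", "# Assumes compile() above wrote ./graph_opt_neuron_656x368.pb; in\n", "# tensorflow-neuron 1.x each fused subgraph appears as a 'NeuronOp' node.\n", "from collections import Counter\n", "import tensorflow as tf\n", "\n", "compiled_graph_def = tf.GraphDef()\n", "with open('./graph_opt_neuron_656x368.pb', 'rb') as f:\n", "    compiled_graph_def.ParseFromString(f.read())\n", "op_counts = Counter(node.op for node in compiled_graph_def.node)\n", "print('NeuronOp subgraphs:', op_counts.get('NeuronOp', 0))\n", "print('Op histogram:', dict(op_counts))" ] },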
{ "cell_type": "markdown", "id": "5a9af0c7", "metadata": {}, "source": [ "## Deploy\n", "We use the same instance to deploy the model.\n", "If you deploy on a different instance, launch an inf1 instance and copy the AWS Neuron optimized TensorFlow frozen graph graph_opt_neuron_656x368.pb to it. The smallest instance type, inf1.xlarge, is sufficient for this demo.\n", "\n", "Your graph_opt_neuron_656x368.pb can now be plugged into https://github.com/ildoonet seamlessly if you have tensorflow-neuron installed. When it is used at runtime, please ensure that the image resolution is the same as the compile-time image resolution, i.e., 656x368.\n" ] },
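{ "cell_type": "markdown", "id": "b1f3a005", "metadata": {}, "source": [ "As a quick smoke test, the cell below, a minimal sketch that is not part of the original tutorial, runs a single inference on the compiled graph with a zero-filled dummy image at the compile-time resolution and prints the shape of the output heat map." ] }, { "cell_type": "code", "execution_count": null, "id": "b1f3a006", "metadata": {}, "outputs": [], "source": [ "# Sketch: a single inference on the compiled graph. The input must use the\n", "# compile-time resolution (656x368); replace the zero-filled array with a\n", "# real preprocessed frame in an actual application.\n", "import numpy as np\n", "import tensorflow as tf\n", "\n", "graph_def = tf.GraphDef()\n", "with open('./graph_opt_neuron_656x368.pb', 'rb') as f:\n", "    graph_def.ParseFromString(f.read())\n", "with tf.Session(graph=tf.Graph()) as sess:\n", "    tf.import_graph_def(graph_def, name='')\n", "    dummy_image = np.zeros([1, 368, 656, 3], dtype=np.float32)  # NHWC\n", "    heat_map = sess.run('Openpose/concat_stage7:0', {'image:0': dummy_image})\n", "    print('Output shape:', heat_map.shape)" ] },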
{ "cell_type": "markdown", "id": "b1f3a007", "metadata": {}, "source": [ "Measure performance on the compiled frozen graph using dummy inputs. The next cell defines the measurement utilities; the benchmark itself follows." ] }, { "cell_type": "code", "execution_count": null, "id": "0481d049", "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "Copyright (C) 2020, Amazon.com. All Rights Reserved\n", "\"\"\"\n", "import os\n", "import atexit\n", "import time\n", "import math\n", "import json\n", "from collections import OrderedDict, Counter\n", "from contextlib import contextmanager, ContextDecorator\n", "from functools import wraps\n", "from tensorflow.python.client import session\n", "from tensorflow.python.platform import tf_logging as logging\n", "\n", "\n", "class measure_performance(ContextDecorator):\n", "    \"\"\"Convenient tool for performance measurements.\n", "\n", "    Can be applied to tensorflow session.run, tf-serving unary gRPC calls, or a given custom function.\n", "\n", "    Usage:\n", "    To generate a performance report for the entire Python or gRPC-client process, insert\n", "    the following function call before running inferences:\n", "    `tfn.measure_performance()`\n", "    Then a latency/throughput report will be generated when the process terminates.\n", "\n", "    Alternatively, it is possible to use `tfn.measure_performance` programmatically\n", "    as a context manager. Performance measurement will be done for all inferences\n", "    happening under this context. The report will be displayed as an INFO level log when exiting\n", "    the context. It is also possible to obtain a JSON format report in Python.\n", "    For example:\n", "    ```\n", "    with tfn.measure_performance() as perf:\n", "        ... (run some inferences) ...\n", "    report_json = perf.report()\n", "    report_full_json = perf.report(verbosity=1)\n", "    ```\n", "    \"\"\"\n", "\n", "    def __init__(self, func=None, window_size=1):\n", "        self.perf_tracker = PerformanceTracker(window_size)\n", "        atexit.register(self.perf_tracker.report)\n", "        self._original_run = session.Session.run\n", "        self._original_grpc_call = None\n", "        if callable(func):\n", "            # Time every call of the given function.\n", "            self.perf_tracker.register_func(self._track_performance(func))\n", "        else:\n", "            # Monkey-patch tf.Session.run (and, if available, tf-serving's\n", "            # unary gRPC call) so that every inference is timed.\n", "            session.Session.run = self._track_performance(session.Session.run)\n", "            try:\n", "                import grpc\n", "                from tensorflow_serving.apis import prediction_service_pb2_grpc\n", "                dummy_stub = prediction_service_pb2_grpc.PredictionServiceStub(grpc.insecure_channel(''))\n", "                self._grpc_callable_type = type(dummy_stub.Predict)\n", "                self._original_grpc_call = self._grpc_callable_type.__call__\n", "            except ImportError:\n", "                pass\n", "            if callable(self._original_grpc_call):\n", "                self._grpc_callable_type.__call__ = self._track_performance(\n", "                    grpc._channel._UnaryUnaryMultiCallable.__call__\n", "                )\n", "\n", "    def __enter__(self):\n", "        return self.perf_tracker\n", "\n", "    def __exit__(self, *exc):\n", "        atexit.unregister(self.perf_tracker.report)\n", "        self.perf_tracker.report()\n", "        session.Session.run = self._original_run\n", "        if self._original_grpc_call is not None:\n", "            self._grpc_callable_type.__call__ = self._original_grpc_call\n", "        return False\n", "\n", "    def _track_performance(self, func):\n", "        @wraps(func)\n", "        def wrapper(*args, **kwargs):\n", "            start = time.time()\n", "            result = func(*args, **kwargs)\n", "            end = time.time()\n", "            self.perf_tracker.add_timestamps(start, end)\n", "            return result\n", "        return wrapper\n", "\n", "\n",
"class PerformanceTracker(ContextDecorator):\n", "\n", "    description = (\n", "        \"Latency unit: second. Throughput unit: number of batched inferences per second. \"\n", "        \"Reported throughput is a lower bound of the actual throughput as inferences \"\n", "        \"spanning across window boundaries are not counted towards any of the windows. \"\n", "        \"'Quiet' periods (i.e., window buckets where the inference function is not called) \"\n", "        \"are not counted towards the reported average throughput.\"\n", "    )\n", "\n", "    def __init__(self, window_size):\n", "        self.window_size = window_size\n", "        self.timestamps_list = []\n", "        self._func = None\n", "\n", "    def __call__(self, *args, **kwargs):\n", "        return self._func(*args, **kwargs)\n", "\n", "    def register_func(self, func):\n", "        self._func = func\n", "\n", "    def add_timestamps(self, start, end):\n", "        self.timestamps_list.append([start, end])\n", "\n", "    def report(self, verbosity=0):\n", "        if self.timestamps_list:\n", "            latency_list = [end - start for start, end in self.timestamps_list]\n", "            latency_json = {\n", "                'p50': percentile(latency_list, 50),\n", "                'p90': percentile(latency_list, 90),\n", "                'p99': percentile(latency_list, 99),\n", "                'p100': percentile(latency_list, 100),\n", "            }\n", "            # Assign each inference to a fixed-size time window; inferences that\n", "            # span a window boundary are dropped, making throughput a lower bound.\n", "            bucketed_timestamps = [self._get_bucket(start, end) for start, end in self.timestamps_list]\n", "            counted_buckets = Counter(item for item in bucketed_timestamps if item is not None)\n", "            bucket_throughputs = [(key, value / self.window_size) for key, value in sorted(counted_buckets.items())]\n", "            busy_throughputs = list(OrderedDict((key, value) for key, value in bucket_throughputs).values())\n", "            throughput_json = {\n", "                'peak': max(busy_throughputs),\n", "                'median': percentile(busy_throughputs, 50),\n", "                'average': sum(busy_throughputs) / len(busy_throughputs),\n", "            }\n", "            if verbosity > 0:\n", "                throughput_json['trend'] = busy_throughputs\n", "            report_json = {\n", "                'pid': os.getpid(),\n", "                'throughput': throughput_json,\n", "                'latency': latency_json,\n", "                'description': PerformanceTracker.description,\n", "            }\n", "            with _logging_show_info():\n", "                logging.info('performance report:\\n{}'.format(json.dumps(report_json, indent=4)))\n", "            return report_json\n", "\n", "    def _get_bucket(self, start, end):\n", "        bucketed_start = math.floor(start / self.window_size) * self.window_size\n", "        bucketed_end = math.ceil(end / self.window_size) * self.window_size\n", "        if bucketed_end - bucketed_start == self.window_size:\n", "            return bucketed_start\n", "        else:\n", "            return None\n", "\n", "\n", "def percentile(number_list, percent):\n", "    pos_float = len(number_list) * percent / 100\n", "    max_pos = len(number_list) - 1\n", "    pos_floor = min(math.floor(pos_float), max_pos)\n", "    pos_ceil = min(math.ceil(pos_float), max_pos)\n", "    number_list = sorted(number_list)\n", "    return number_list[pos_ceil] if pos_float - pos_floor > 0.5 else number_list[pos_floor]\n", "\n", "\n", "@contextmanager\n", "def _logging_show_info():\n", "    # Temporarily raise TF logging verbosity so the report is visible.\n", "    try:\n", "        verbosity = logging.get_verbosity()\n", "        logging.set_verbosity(logging.INFO)\n", "        yield\n", "    finally:\n", "        logging.set_verbosity(verbosity)" ] },
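{ "cell_type": "markdown", "id": "b1f3a008", "metadata": {}, "source": [ "To see what the helper produces, the cell below is a minimal sketch, not part of the original tutorial, that times a hypothetical dummy_infer function instead of a TensorFlow session; a latency/throughput report is logged automatically when the context exits." ] }, { "cell_type": "code", "execution_count": null, "id": "b1f3a009", "metadata": {}, "outputs": [], "source": [ "# Sketch: exercise the measure_performance helper defined above on a dummy\n", "# function. dummy_infer is a made-up stand-in for a real inference call.\n", "import time\n", "\n", "def dummy_infer():\n", "    time.sleep(0.01)  # pretend each inference takes ~10 ms\n", "\n", "with measure_performance(dummy_infer) as perf:\n", "    for _ in range(200):\n", "        perf()  # each call is timed by the underlying PerformanceTracker\n", "# The report is printed as an INFO log when the context exits." ] },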
{ "cell_type": "code", "execution_count": null, "id": "960c6aa9", "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", "Inputs for benchmarking the compiled frozen graph:\n", "\n", "pb_path: path to the compiled graph, e.g., ./graph_opt_neuron_656x368.pb\n", "num_thread: number of threads issuing requests to the tensorflow session (8)\n", "batch_size: input batch size (1)\n", "net_resolution: network resolution in WxH format (default 656x368)\n", "num_inferences: number of inferences issued per thread (200)\n", "\"\"\"\n", "import os\n", "from concurrent import futures\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow.neuron as tfn\n", "\n", "\n", "def run_with_dummy(sess, dummy_feed_dict, num_inferences):\n", "    # Issue a stream of inferences with the same dummy input.\n", "    for _ in range(num_inferences):\n", "        sess.run('Openpose/concat_stage7:0', dummy_feed_dict)\n", "\n", "\n", "def main():\n", "    NUM_NEURON_CORES = 16  # inf1.6xlarge provides 16 NeuronCores\n", "    pb_path = './graph_opt_neuron_656x368.pb'\n", "    num_thread = 8\n", "    batch_size = 1\n", "    net_resolution = '656x368'\n", "    num_inferences = 200\n", "    dim_w, dim_h = net_resolution.split('x')\n", "    dim_w = int(dim_w)\n", "    dim_h = int(dim_h)\n", "    graph_def = tf.GraphDef()\n", "    with open(pb_path, 'rb') as f:\n", "        graph_def.ParseFromString(f.read())\n", "\n", "    # Replicate the compiled subgraph so that all NeuronCores are used.\n", "    graph_def = tfn.graph_util.tag_multicore(graph_def, NUM_NEURON_CORES)\n", "\n", "    with tfn.measure_performance() as perf:\n", "        with tf.Session(graph=tf.Graph()) as sess:\n", "            tf.import_graph_def(graph_def, name='')\n", "            input_name = 'image:0'\n", "            input_shape = sess.graph.get_tensor_by_name(input_name).shape.as_list()\n", "            input_shape[0] = batch_size\n", "            input_shape[1] = dim_h\n", "            input_shape[2] = dim_w\n", "            dummy_feed_dict = {input_name: np.zeros(input_shape).astype(np.float32)}\n", "            with futures.ThreadPoolExecutor(max_workers=num_thread) as executor:\n", "                fut_list = [executor.submit(run_with_dummy, sess, dummy_feed_dict, num_inferences) for _ in range(num_thread)]\n", "                res_list = [fut.result() for fut in fut_list]\n", "\n", "\n", "main()\n", "\n", "# Sample output will look like below:\n", "# INFO:tensorflow:performance report:\n", "# {\n", "#     \"pid\": 17713,\n", "#     \"throughput\": {\n", "#         \"peak\": 66.0,\n", "#         \"median\": 64.0,\n", "#         \"average\": 61.56521739130435\n", "#     },\n", "#     \"latency\": {\n", "#         \"p50\": 0.1106414794921875,\n", "#         \"p90\": 0.11212301254272461,\n", "#         \"p99\": 0.11337876319885254,\n", "#         \"p100\": 7.08282732963562\n", "#     },\n", "#     \"description\": \"Latency unit: second. Throughput unit: number of batched inferences per second. Reported throughput is a lower bound of the actual throughput as inferences spanning across window boundaries are not counted towards any of the windows. 'Quiet' periods (i.e., window buckets where the inference function is not called) are not counted towards the reported average throughput.\"\n", "# }" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.9 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.9" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 5 }