{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluate YOLO v4 on Inferentia" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "This tutorial walks through compiling and evaluating a YOLO v4 model implemented in PyTorch on Inferentia.\n", "\n", "The tutorial has five main sections:\n", "\n", "1. Define YOLO v4 model in PyTorch\n", "2. Download the COCO 2017 evaluation dataset and define the data loader function\n", "3. Build, Compile, and Save Neuron-Optimized YOLO v4 TorchScript\n", "4. Evaluate Accuracy on the COCO 2017 Dataset\n", "5. Benchmark COCO Dataset Performance of the Neuron-Optimized TorchScript\n", "\n", "Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the [PyTorch Installation Guide](../../../frameworks/torch/torch-neuron/setup/pytorch-install.html). You can select the kernel from the \"Kernel -> Change Kernel\" option at the top of this Jupyter notebook page." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install Dependencies\n", "This tutorial requires the following pip packages:\n", "\n", "- `torch-neuron`\n", "- `torchvision`\n", "- `pillow`\n", "- `pycocotools`\n", "- `neuron-cc[tensorflow]`\n", "\n", "Most of these packages are installed by default when you configure your environment using the Neuron PyTorch setup guide. The remaining dependencies, `pillow` and `pycocotools`, are installed by the cell below." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install --upgrade pillow pycocotools" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1: Define YOLO v4 model in PyTorch\n", "The following PyTorch model definition is from https://github.com/Tianxiaomo/pytorch-YOLOv4/."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import torch.neuron\n", "from torch import nn\n", "import torch.nn.functional as F\n", "import os\n", "import sys\n", "import warnings\n", "\n", "# Setting up NeuronCore groups for inf1.6xlarge with 16 cores\n", "n_cores = 16 # This value should be 4 on inf1.xlarge and inf1.2xlarge\n", "os.environ['NEURON_RT_NUM_CORES'] = str(n_cores)\n", "\n", "\n", "class Mish(torch.nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", "\n", " def forward(self, x):\n", " x = x * (torch.tanh(torch.nn.functional.softplus(x)))\n", " return x\n", "\n", "\n", "class Upsample(nn.Module):\n", " def __init__(self):\n", " super(Upsample, self).__init__()\n", "\n", " def forward(self, x, target_size, inference=False):\n", " assert (x.data.dim() == 4)\n", "\n", " if inference:\n", "\n", " return x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1).\\\n", " expand(x.size(0), x.size(1), x.size(2), target_size[2] // x.size(2), x.size(3), target_size[3] // x.size(3)).\\\n", " contiguous().view(x.size(0), x.size(1), target_size[2], target_size[3])\n", " else:\n", " return F.interpolate(x, size=(target_size[2], target_size[3]), mode='nearest')\n", "\n", "\n", "class Conv_Bn_Activation(nn.Module):\n", " def __init__(self, in_channels, out_channels, kernel_size, stride, activation, bn=True, bias=False):\n", " super().__init__()\n", " pad = (kernel_size - 1) // 2\n", "\n", " self.conv = nn.ModuleList()\n", " if bias:\n", " self.conv.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad))\n", " else:\n", " self.conv.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad, bias=False))\n", " if bn:\n", " self.conv.append(nn.BatchNorm2d(out_channels))\n", " if activation == \"mish\":\n", " self.conv.append(Mish())\n", " elif activation == \"relu\":\n", " self.conv.append(nn.ReLU(inplace=True))\n", " elif activation == \"leaky\":\n", " 
self.conv.append(nn.LeakyReLU(0.1, inplace=True))\n", " elif activation == \"linear\":\n", " pass\n", " else:\n", " print(\"activate error !!! {} {} {}\".format(sys._getframe().f_code.co_filename,\n", " sys._getframe().f_code.co_name, sys._getframe().f_lineno))\n", "\n", " def forward(self, x):\n", " for l in self.conv:\n", " x = l(x)\n", " return x\n", "\n", "\n", "class ResBlock(nn.Module):\n", " \"\"\"\n", " Sequential residual blocks each of which consists of \\\n", " two convolution layers.\n", " Args:\n", " ch (int): number of input and output channels.\n", " nblocks (int): number of residual blocks.\n", " shortcut (bool): if True, residual tensor addition is enabled.\n", " \"\"\"\n", "\n", " def __init__(self, ch, nblocks=1, shortcut=True):\n", " super().__init__()\n", " self.shortcut = shortcut\n", " self.module_list = nn.ModuleList()\n", " for i in range(nblocks):\n", " resblock_one = nn.ModuleList()\n", " resblock_one.append(Conv_Bn_Activation(ch, ch, 1, 1, 'mish'))\n", " resblock_one.append(Conv_Bn_Activation(ch, ch, 3, 1, 'mish'))\n", " self.module_list.append(resblock_one)\n", "\n", " def forward(self, x):\n", " for module in self.module_list:\n", " h = x\n", " for res in module:\n", " h = res(h)\n", " x = x + h if self.shortcut else h\n", " return x\n", "\n", "\n", "class DownSample1(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.conv1 = Conv_Bn_Activation(3, 32, 3, 1, 'mish')\n", "\n", " self.conv2 = Conv_Bn_Activation(32, 64, 3, 2, 'mish')\n", " self.conv3 = Conv_Bn_Activation(64, 64, 1, 1, 'mish')\n", " # [route]\n", " # layers = -2\n", " self.conv4 = Conv_Bn_Activation(64, 64, 1, 1, 'mish')\n", "\n", " self.conv5 = Conv_Bn_Activation(64, 32, 1, 1, 'mish')\n", " self.conv6 = Conv_Bn_Activation(32, 64, 3, 1, 'mish')\n", " # [shortcut]\n", " # from=-3\n", " # activation = linear\n", "\n", " self.conv7 = Conv_Bn_Activation(64, 64, 1, 1, 'mish')\n", " # [route]\n", " # layers = -1, -7\n", " self.conv8 = 
Conv_Bn_Activation(128, 64, 1, 1, 'mish')\n", "\n", " def forward(self, input):\n", " x1 = self.conv1(input)\n", " x2 = self.conv2(x1)\n", " x3 = self.conv3(x2)\n", " # route -2\n", " x4 = self.conv4(x2)\n", " x5 = self.conv5(x4)\n", " x6 = self.conv6(x5)\n", " # shortcut -3\n", " x6 = x6 + x4\n", "\n", " x7 = self.conv7(x6)\n", " # [route]\n", " # layers = -1, -7\n", " x7 = torch.cat([x7, x3], dim=1)\n", " x8 = self.conv8(x7)\n", " return x8\n", "\n", "\n", "class DownSample2(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.conv1 = Conv_Bn_Activation(64, 128, 3, 2, 'mish')\n", " self.conv2 = Conv_Bn_Activation(128, 64, 1, 1, 'mish')\n", " # r -2\n", " self.conv3 = Conv_Bn_Activation(128, 64, 1, 1, 'mish')\n", "\n", " self.resblock = ResBlock(ch=64, nblocks=2)\n", "\n", " # s -3\n", " self.conv4 = Conv_Bn_Activation(64, 64, 1, 1, 'mish')\n", " # r -1 -10\n", " self.conv5 = Conv_Bn_Activation(128, 128, 1, 1, 'mish')\n", "\n", " def forward(self, input):\n", " x1 = self.conv1(input)\n", " x2 = self.conv2(x1)\n", " x3 = self.conv3(x1)\n", "\n", " r = self.resblock(x3)\n", " x4 = self.conv4(r)\n", "\n", " x4 = torch.cat([x4, x2], dim=1)\n", " x5 = self.conv5(x4)\n", " return x5\n", "\n", "\n", "class DownSample3(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.conv1 = Conv_Bn_Activation(128, 256, 3, 2, 'mish')\n", " self.conv2 = Conv_Bn_Activation(256, 128, 1, 1, 'mish')\n", " self.conv3 = Conv_Bn_Activation(256, 128, 1, 1, 'mish')\n", "\n", " self.resblock = ResBlock(ch=128, nblocks=8)\n", " self.conv4 = Conv_Bn_Activation(128, 128, 1, 1, 'mish')\n", " self.conv5 = Conv_Bn_Activation(256, 256, 1, 1, 'mish')\n", "\n", " def forward(self, input):\n", " x1 = self.conv1(input)\n", " x2 = self.conv2(x1)\n", " x3 = self.conv3(x1)\n", "\n", " r = self.resblock(x3)\n", " x4 = self.conv4(r)\n", "\n", " x4 = torch.cat([x4, x2], dim=1)\n", " x5 = self.conv5(x4)\n", " return x5\n", "\n", "\n", "class DownSample4(nn.Module):\n", " 
def __init__(self):\n", " super().__init__()\n", " self.conv1 = Conv_Bn_Activation(256, 512, 3, 2, 'mish')\n", " self.conv2 = Conv_Bn_Activation(512, 256, 1, 1, 'mish')\n", " self.conv3 = Conv_Bn_Activation(512, 256, 1, 1, 'mish')\n", "\n", " self.resblock = ResBlock(ch=256, nblocks=8)\n", " self.conv4 = Conv_Bn_Activation(256, 256, 1, 1, 'mish')\n", " self.conv5 = Conv_Bn_Activation(512, 512, 1, 1, 'mish')\n", "\n", " def forward(self, input):\n", " x1 = self.conv1(input)\n", " x2 = self.conv2(x1)\n", " x3 = self.conv3(x1)\n", "\n", " r = self.resblock(x3)\n", " x4 = self.conv4(r)\n", "\n", " x4 = torch.cat([x4, x2], dim=1)\n", " x5 = self.conv5(x4)\n", " return x5\n", "\n", "\n", "class DownSample5(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.conv1 = Conv_Bn_Activation(512, 1024, 3, 2, 'mish')\n", " self.conv2 = Conv_Bn_Activation(1024, 512, 1, 1, 'mish')\n", " self.conv3 = Conv_Bn_Activation(1024, 512, 1, 1, 'mish')\n", "\n", " self.resblock = ResBlock(ch=512, nblocks=4)\n", " self.conv4 = Conv_Bn_Activation(512, 512, 1, 1, 'mish')\n", " self.conv5 = Conv_Bn_Activation(1024, 1024, 1, 1, 'mish')\n", "\n", " def forward(self, input):\n", " x1 = self.conv1(input)\n", " x2 = self.conv2(x1)\n", " x3 = self.conv3(x1)\n", "\n", " r = self.resblock(x3)\n", " x4 = self.conv4(r)\n", "\n", " x4 = torch.cat([x4, x2], dim=1)\n", " x5 = self.conv5(x4)\n", " return x5\n", "\n", "\n", "class Neck(nn.Module):\n", " def __init__(self, inference=False):\n", " super().__init__()\n", " self.inference = inference\n", "\n", " self.conv1 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')\n", " self.conv2 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')\n", " self.conv3 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')\n", " # SPP\n", " self.maxpool1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=5 // 2)\n", " self.maxpool2 = nn.MaxPool2d(kernel_size=9, stride=1, padding=9 // 2)\n", " self.maxpool3 = nn.MaxPool2d(kernel_size=13, stride=1, padding=13 // 2)\n", "\n", 
" # R -1 -3 -5 -6\n", " # SPP\n", " self.conv4 = Conv_Bn_Activation(2048, 512, 1, 1, 'leaky')\n", " self.conv5 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')\n", " self.conv6 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')\n", " self.conv7 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')\n", " # UP\n", " self.upsample1 = Upsample()\n", " # R 85\n", " self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')\n", " # R -1 -3\n", " self.conv9 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')\n", " self.conv10 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')\n", " self.conv11 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')\n", " self.conv12 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')\n", " self.conv13 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')\n", " self.conv14 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')\n", " # UP\n", " self.upsample2 = Upsample()\n", " # R 54\n", " self.conv15 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')\n", " # R -1 -3\n", " self.conv16 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')\n", " self.conv17 = Conv_Bn_Activation(128, 256, 3, 1, 'leaky')\n", " self.conv18 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')\n", " self.conv19 = Conv_Bn_Activation(128, 256, 3, 1, 'leaky')\n", " self.conv20 = Conv_Bn_Activation(256, 128, 1, 1, 'leaky')\n", "\n", " def forward(self, input, downsample4, downsample3, inference=False):\n", " x1 = self.conv1(input)\n", " x2 = self.conv2(x1)\n", " x3 = self.conv3(x2)\n", " # SPP\n", " m1 = self.maxpool1(x3)\n", " m2 = self.maxpool2(x3)\n", " m3 = self.maxpool3(x3)\n", " spp = torch.cat([m3, m2, m1, x3], dim=1)\n", " # SPP end\n", " x4 = self.conv4(spp)\n", " x5 = self.conv5(x4)\n", " x6 = self.conv6(x5)\n", " x7 = self.conv7(x6)\n", " # UP\n", " up = self.upsample1(x7, downsample4.size(), self.inference)\n", " # R 85\n", " x8 = self.conv8(downsample4)\n", " # R -1 -3\n", " x8 = torch.cat([x8, up], dim=1)\n", "\n", " x9 = self.conv9(x8)\n", " x10 = self.conv10(x9)\n", " x11 = self.conv11(x10)\n", " x12 = self.conv12(x11)\n", " x13 
= self.conv13(x12)\n", " x14 = self.conv14(x13)\n", "\n", " # UP\n", " up = self.upsample2(x14, downsample3.size(), self.inference)\n", " # R 54\n", " x15 = self.conv15(downsample3)\n", " # R -1 -3\n", " x15 = torch.cat([x15, up], dim=1)\n", "\n", " x16 = self.conv16(x15)\n", " x17 = self.conv17(x16)\n", " x18 = self.conv18(x17)\n", " x19 = self.conv19(x18)\n", " x20 = self.conv20(x19)\n", " return x20, x13, x6\n", "\n", "\n", "class Yolov4Head(nn.Module):\n", " def __init__(self, output_ch, n_classes, inference=False):\n", " super().__init__()\n", " self.inference = inference\n", "\n", " self.conv1 = Conv_Bn_Activation(128, 256, 3, 1, 'leaky')\n", " self.conv2 = Conv_Bn_Activation(256, output_ch, 1, 1, 'linear', bn=False, bias=True)\n", "\n", " self.yolo1 = YoloLayer(\n", " anchor_mask=[0, 1, 2], num_classes=n_classes,\n", " anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],\n", " num_anchors=9, stride=8)\n", "\n", " # R -4\n", " self.conv3 = Conv_Bn_Activation(128, 256, 3, 2, 'leaky')\n", "\n", " # R -1 -16\n", " self.conv4 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')\n", " self.conv5 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')\n", " self.conv6 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')\n", " self.conv7 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')\n", " self.conv8 = Conv_Bn_Activation(512, 256, 1, 1, 'leaky')\n", " self.conv9 = Conv_Bn_Activation(256, 512, 3, 1, 'leaky')\n", " self.conv10 = Conv_Bn_Activation(512, output_ch, 1, 1, 'linear', bn=False, bias=True)\n", " \n", " self.yolo2 = YoloLayer(\n", " anchor_mask=[3, 4, 5], num_classes=n_classes,\n", " anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],\n", " num_anchors=9, stride=16)\n", "\n", " # R -4\n", " self.conv11 = Conv_Bn_Activation(256, 512, 3, 2, 'leaky')\n", "\n", " # R -1 -37\n", " self.conv12 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')\n", " self.conv13 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')\n", " 
self.conv14 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')\n", " self.conv15 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')\n", " self.conv16 = Conv_Bn_Activation(1024, 512, 1, 1, 'leaky')\n", " self.conv17 = Conv_Bn_Activation(512, 1024, 3, 1, 'leaky')\n", " self.conv18 = Conv_Bn_Activation(1024, output_ch, 1, 1, 'linear', bn=False, bias=True)\n", " \n", " self.yolo3 = YoloLayer(\n", " anchor_mask=[6, 7, 8], num_classes=n_classes,\n", " anchors=[12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401],\n", " num_anchors=9, stride=32)\n", "\n", " def forward(self, input1, input2, input3):\n", " x1 = self.conv1(input1)\n", " x2 = self.conv2(x1)\n", "\n", " x3 = self.conv3(input1)\n", " # R -1 -16\n", " x3 = torch.cat([x3, input2], dim=1)\n", " x4 = self.conv4(x3)\n", " x5 = self.conv5(x4)\n", " x6 = self.conv6(x5)\n", " x7 = self.conv7(x6)\n", " x8 = self.conv8(x7)\n", " x9 = self.conv9(x8)\n", " x10 = self.conv10(x9)\n", "\n", " # R -4\n", " x11 = self.conv11(x8)\n", " # R -1 -37\n", " x11 = torch.cat([x11, input3], dim=1)\n", "\n", " x12 = self.conv12(x11)\n", " x13 = self.conv13(x12)\n", " x14 = self.conv14(x13)\n", " x15 = self.conv15(x14)\n", " x16 = self.conv16(x15)\n", " x17 = self.conv17(x16)\n", " x18 = self.conv18(x17)\n", " \n", " if self.inference:\n", " y1 = self.yolo1(x2)\n", " y2 = self.yolo2(x10)\n", " y3 = self.yolo3(x18)\n", "\n", " return get_region_boxes([y1, y2, y3])\n", " \n", " else:\n", " return [x2, x10, x18]\n", "\n", "\n", "class Yolov4(nn.Module):\n", " def __init__(self, yolov4conv137weight=None, n_classes=80, inference=False):\n", " super().__init__()\n", "\n", " output_ch = (4 + 1 + n_classes) * 3\n", "\n", " # backbone\n", " self.down1 = DownSample1()\n", " self.down2 = DownSample2()\n", " self.down3 = DownSample3()\n", " self.down4 = DownSample4()\n", " self.down5 = DownSample5()\n", " # neck\n", " self.neek = Neck(inference)\n", " # yolov4conv137\n", " if yolov4conv137weight:\n", " _model = 
nn.Sequential(self.down1, self.down2, self.down3, self.down4, self.down5, self.neek)\n", " pretrained_dict = torch.load(yolov4conv137weight)\n", "\n", " model_dict = _model.state_dict()\n", " # 1. filter out unnecessary keys\n", " pretrained_dict = {k1: v for (k, v), k1 in zip(pretrained_dict.items(), model_dict)}\n", " # 2. overwrite entries in the existing state dict\n", " model_dict.update(pretrained_dict)\n", " _model.load_state_dict(model_dict)\n", " \n", " # head\n", " self.head = Yolov4Head(output_ch, n_classes, inference)\n", "\n", "\n", " def forward(self, input):\n", " d1 = self.down1(input)\n", " d2 = self.down2(d1)\n", " d3 = self.down3(d2)\n", " d4 = self.down4(d3)\n", " d5 = self.down5(d4)\n", "\n", " x20, x13, x6 = self.neek(d5, d4, d3)\n", "\n", " output = self.head(x20, x13, x6)\n", " return output\n", "\n", "\n", "def yolo_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,\n", " validation=False):\n", " # Output would be invalid if it does not satisfy this assert\n", " # assert (output.size(1) == (5 + num_classes) * num_anchors)\n", "\n", " # print(output.size())\n", "\n", " # Slice the second dimension (channel) of output into:\n", " # [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]\n", " # And then into\n", " # bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]\n", " # batch = output.size(0)\n", " # H = output.size(2)\n", " # W = output.size(3)\n", "\n", " bxy_list = []\n", " bwh_list = []\n", " det_confs_list = []\n", " cls_confs_list = []\n", "\n", " for i in range(num_anchors):\n", " begin = i * (5 + num_classes)\n", " end = (i + 1) * (5 + num_classes)\n", " \n", " bxy_list.append(output[:, begin : begin + 2])\n", " bwh_list.append(output[:, begin + 2 : begin + 4])\n", " det_confs_list.append(output[:, begin + 4 : begin + 5])\n", " cls_confs_list.append(output[:, begin + 5 : end])\n", "\n", " # Shape: [batch, num_anchors * 2, H, W]\n", " bxy = 
torch.cat(bxy_list, dim=1)\n", " # Shape: [batch, num_anchors * 2, H, W]\n", " bwh = torch.cat(bwh_list, dim=1)\n", "\n", " # Shape: [batch, num_anchors, H, W]\n", " det_confs = torch.cat(det_confs_list, dim=1)\n", " # Shape: [batch, num_anchors * H * W]\n", " det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3))\n", "\n", " # Shape: [batch, num_anchors * num_classes, H, W]\n", " cls_confs = torch.cat(cls_confs_list, dim=1)\n", " # Shape: [batch, num_anchors, num_classes, H * W]\n", " cls_confs = cls_confs.view(output.size(0), num_anchors, num_classes, output.size(2) * output.size(3))\n", " # Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes] \n", " cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(output.size(0), num_anchors * output.size(2) * output.size(3), num_classes)\n", "\n", " # Apply sigmoid(), exp() and softmax() to slices\n", " #\n", " bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)\n", " bwh = torch.exp(bwh)\n", " det_confs = torch.sigmoid(det_confs)\n", " cls_confs = torch.sigmoid(cls_confs)\n", "\n", " # Prepare C-x, C-y, P-w, P-h (None of them are torch related)\n", " grid_x = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(3) - 1, output.size(3)), axis=0).repeat(output.size(2), 0), axis=0), axis=0)\n", " grid_y = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(2) - 1, output.size(2)), axis=1).repeat(output.size(3), 1), axis=0), axis=0)\n", " # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)\n", " # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)\n", "\n", " anchor_w = []\n", " anchor_h = []\n", " for i in range(num_anchors):\n", " anchor_w.append(anchors[i * 2])\n", " anchor_h.append(anchors[i * 2 + 1])\n", "\n", " device = None\n", " cuda_check = output.is_cuda\n", " if cuda_check:\n", " device = output.get_device()\n", "\n", " bx_list = []\n", " by_list = 
[]\n", " bw_list = []\n", " bh_list = []\n", "\n", " # Apply C-x, C-y, P-w, P-h\n", " for i in range(num_anchors):\n", " ii = i * 2\n", " # Shape: [batch, 1, H, W]\n", " bx = bxy[:, ii : ii + 1] + torch.tensor(grid_x, device=device, dtype=torch.float32) # grid_x.to(device=device, dtype=torch.float32)\n", " # Shape: [batch, 1, H, W]\n", " by = bxy[:, ii + 1 : ii + 2] + torch.tensor(grid_y, device=device, dtype=torch.float32) # grid_y.to(device=device, dtype=torch.float32)\n", " # Shape: [batch, 1, H, W]\n", " bw = bwh[:, ii : ii + 1] * anchor_w[i]\n", " # Shape: [batch, 1, H, W]\n", " bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]\n", "\n", " bx_list.append(bx)\n", " by_list.append(by)\n", " bw_list.append(bw)\n", " bh_list.append(bh)\n", "\n", "\n", " ########################################\n", " # Figure out bboxes from slices #\n", " ########################################\n", " \n", " # Shape: [batch, num_anchors, H, W]\n", " bx = torch.cat(bx_list, dim=1)\n", " # Shape: [batch, num_anchors, H, W]\n", " by = torch.cat(by_list, dim=1)\n", " # Shape: [batch, num_anchors, H, W]\n", " bw = torch.cat(bw_list, dim=1)\n", " # Shape: [batch, num_anchors, H, W]\n", " bh = torch.cat(bh_list, dim=1)\n", "\n", " # Shape: [batch, 2 * num_anchors, H, W]\n", " bx_bw = torch.cat((bx, bw), dim=1)\n", " # Shape: [batch, 2 * num_anchors, H, W]\n", " by_bh = torch.cat((by, bh), dim=1)\n", "\n", " # normalize coordinates to [0, 1]\n", " bx_bw /= output.size(3)\n", " by_bh /= output.size(2)\n", "\n", " # Shape: [batch, num_anchors * H * W, 1]\n", " bx = bx_bw[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)\n", " by = by_bh[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)\n", " bw = bx_bw[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)\n", " bh = by_bh[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)\n", "\n", " bx1 = bx - bw * 
0.5\n", " by1 = by - bh * 0.5\n", " bx2 = bx1 + bw\n", " by2 = by1 + bh\n", "\n", " # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]\n", " boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4)\n", " # boxes = boxes.repeat(1, 1, num_classes, 1)\n", "\n", " # boxes: [batch, num_anchors * H * W, 1, 4]\n", " # cls_confs: [batch, num_anchors * H * W, num_classes]\n", " # det_confs: [batch, num_anchors * H * W]\n", "\n", " det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)\n", " confs = cls_confs * det_confs\n", "\n", " # boxes: [batch, num_anchors * H * W, 1, 4]\n", " # confs: [batch, num_anchors * H * W, num_classes]\n", "\n", " return boxes, confs\n", "\n", "class YoloLayer(nn.Module):\n", " \"\"\"\n", " Yolo layer\n", " model_out: while inference,is post-processing inside or outside the model\n", " true:outside\n", " \"\"\"\n", " def __init__(self, anchor_mask=[], num_classes=0, anchors=[], num_anchors=1, stride=32, model_out=False):\n", " super(YoloLayer, self).__init__()\n", " self.anchor_mask = anchor_mask\n", " self.num_classes = num_classes\n", " self.anchors = anchors\n", " self.num_anchors = num_anchors\n", " self.anchor_step = len(anchors) // num_anchors\n", " self.coord_scale = 1\n", " self.noobject_scale = 1\n", " self.object_scale = 5\n", " self.class_scale = 1\n", " self.thresh = 0.6\n", " self.stride = stride\n", " self.seen = 0\n", " self.scale_x_y = 1\n", "\n", " self.model_out = model_out\n", "\n", " def forward(self, output, target=None):\n", " if self.training:\n", " return output\n", " masked_anchors = []\n", " for m in self.anchor_mask:\n", " masked_anchors += self.anchors[m * self.anchor_step:(m + 1) * self.anchor_step]\n", " masked_anchors = [anchor / self.stride for anchor in masked_anchors]\n", "\n", " return yolo_forward_dynamic(output, self.thresh, self.num_classes, masked_anchors, 
len(self.anchor_mask), scale_x_y=self.scale_x_y)\n", "\n", "\n", "def get_region_boxes(boxes_and_confs):\n", "\n", " # print('Getting boxes from boxes and confs ...')\n", "\n", " boxes_list = []\n", " confs_list = []\n", "\n", " for item in boxes_and_confs:\n", " boxes_list.append(item[0])\n", " confs_list.append(item[1])\n", "\n", " # boxes: [batch, num1 + num2 + num3, 1, 4]\n", " # confs: [batch, num1 + num2 + num3, num_classes]\n", " boxes = torch.cat(boxes_list, dim=1)\n", " confs = torch.cat(confs_list, dim=1)\n", " \n", " return boxes, confs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2: Download the COCO 2017 evaluation dataset and define the data loader function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "!curl -LO http://images.cocodataset.org/zips/val2017.zip\n", "!curl -LO http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n", "!unzip -q val2017.zip\n", "!unzip annotations_trainval2017.zip" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!ls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define data loader" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import time\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "import torchvision.datasets as dset\n", "from pycocotools.coco import COCO\n", "\n", "\n", "def get_image_filenames(root=os.getcwd()):\n", " \"\"\"\n", " Generate paths to the COCO dataset image files.\n", " \n", " Args:\n", " root (str): The root folder that contains the `val2017` image directory.\n", " \n", " Yields:\n", " filename (str): The path to an image file.\n", " \"\"\"\n", " image_path = os.path.join(root, 'val2017')\n", " for root, dirs, files in os.walk(image_path):\n", " for filename in 
files:\n", " yield os.path.join(image_path, filename)\n", "\n", " \n", "def get_coco_dataloader(coco2017_root, transform, subset_indices=None):\n", " \"\"\"\n", " Create the dataset loader and ground truth coco dataset.\n", " \n", " Arguments:\n", " coco2017_root (str): The root directory to load the data/labels from.\n", " transform (torchvision.Transform): A transform to apply to the images.\n", " subset_indices (list): Indices used to create a subset of the dataset.\n", "\n", " Returns: \n", " loader (iterable): Produces transformed images and labels.\n", " cocoGt (pycocotools.coco.COCO): Contains the ground truth in coco \n", " format.\n", " label_info (dict): A mapping from label id to the human-readable name.\n", " \"\"\"\n", "\n", " # Create the dataset\n", " coco2017_img_path = os.path.join(coco2017_root, 'val2017')\n", " coco2017_ann_path = os.path.join(\n", " coco2017_root, 'annotations/instances_val2017.json')\n", "\n", " # check the number of images in val2017 - Should be 5000\n", " num_files = len(list(get_image_filenames(coco2017_root)))\n", " print('\\nNumber of images in val2017 = {}\\n'.format(num_files))\n", "\n", " # load annotations to decode classification results\n", " with open(coco2017_ann_path) as f:\n", " annotate_json = json.load(f)\n", " label_info = {label[\"id\"]: label[\"name\"]\n", " for label in annotate_json['categories']}\n", "\n", " # initialize COCO ground truth dataset\n", " cocoGt = COCO(coco2017_ann_path)\n", "\n", " # create the dataset using torchvision's coco detection dataset\n", " coco_val_data = dset.CocoDetection(\n", " root=coco2017_img_path, \n", " annFile=coco2017_ann_path, \n", " transform=transform\n", " )\n", "\n", " if subset_indices is not None:\n", " # Create a smaller subset of the data for testing - e.g. 
to pinpoint error at image 516\n", " coco_val_data = torch.utils.data.Subset(coco_val_data, subset_indices)\n", "\n", " # create the dataloader using torch dataloader\n", " loader = torch.utils.data.DataLoader(coco_val_data, batch_size=1, shuffle=False)\n", "\n", " return loader, cocoGt, label_info\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load dataset\n", "Two data loaders are created here, and the tensor shape of the first image from each is printed:\n", "\n", "- `orig_coco_val_data_loader`: Contains the original, unmodified images\n", "- `coco_val_data_loader`: Contains images resized to 608x608 pixels" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "coco2017_root = './'\n", "orig_coco_val_data_loader, *_ = get_coco_dataloader(coco2017_root, transforms.ToTensor())\n", "transform = transforms.Compose([transforms.Resize([608, 608]), transforms.ToTensor()])\n", "coco_val_data_loader, cocoGt, label_info = get_coco_dataloader(coco2017_root, transform)\n", "image_orig, _ = next(iter(orig_coco_val_data_loader))\n", "print(image_orig.shape)\n", "image, image_info = next(iter(coco_val_data_loader))\n", "image_id = image_info[0][\"image_id\"].item()\n", "print(image.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define some helper functions for deployment (inference)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def postprocess(boxes, scores, score_threshold=0.05, iou_threshold=0.5):\n", " \"\"\"\n", " Classifies and filters bounding boxes from Yolo V4 output.\n", " \n", " Performs classification, filtering, and non-maximum suppression to remove\n", " boxes that are irrelevant. 
The result is the filtered set of boxes, the \n", " associated label confidence score, and the predicted label.\n", " \n", " See: https://pytorch.org/docs/stable/torchvision/ops.html#torchvision.ops.nms\n", " \n", " Args:\n", " boxes (torch.Tensor): The Yolo V4 bounding boxes.\n", " scores (torch.Tensor): The category scores for each box.\n", " score_threshold (float): Ignore boxes with scores below threshold.\n", " iou_threshold (float): Discard boxes whose IoU with a higher-scoring box of the same category exceeds the threshold.\n", " \n", " Returns:\n", " boxes (torch.Tensor): The filtered Yolo V4 bounding boxes.\n", " scores (torch.Tensor): The label score for each box.\n", " labels (torch.Tensor): The label for each box.\n", " \"\"\"\n", " \n", " # shape: [n_batch, n_boxes, 1, 4] => [n_boxes, 4] # Assumes n_batch size is 1\n", " boxes = boxes.squeeze()\n", "\n", " # shape: [n_batch, n_boxes, 80] => [n_boxes, 80] # Assumes n_batch size is 1\n", " scores = scores.squeeze()\n", "\n", " # Classify each box according to the maximum category score\n", " score, column = torch.max(scores, dim=1)\n", "\n", " # Filter out rows for scores which are below threshold\n", " mask = score > score_threshold\n", "\n", " # Filter model output data\n", " boxes = boxes[mask]\n", " score = score[mask]\n", " idxs = column[mask]\n", "\n", " # Perform non-max suppression on all categories at once. 
shape: [n_keep,]\n", " keep = torchvision.ops.batched_nms(\n", " boxes=boxes, \n", " scores=score, \n", " idxs=idxs,\n", " iou_threshold=iou_threshold,\n", " )\n", "\n", " # The image category id associated with each column\n", " categories = torch.tensor([\n", " 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16,\n", " 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31,\n", " 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43,\n", " 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,\n", " 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72,\n", " 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85,\n", " 86, 87, 88, 89, 90\n", " ])\n", " \n", " boxes = boxes[keep] # shape: [n_keep, 4]\n", " score = score[keep] # shape: [n_keep,]\n", " idxs = idxs[keep]\n", " label = categories[idxs] # shape: [n_keep,]\n", " \n", " return boxes, score, label\n", "\n", "\n", "def get_results_as_dict(boxes, scores, labels, image_orig):\n", " \"\"\"\n", " Transforms post-processed output into dictionary output.\n", " \n", " This translates the model coordinate bounding boxes (x1, y1, x2, y2) \n", " into a rectangular description (x, y, width, height) scaled to the \n", " original image size.\n", " \n", " Args:\n", " boxes (torch.Tensor): The Yolo V4 bounding boxes.\n", " scores (torch.Tensor): The label score for each box.\n", " labels (torch.Tensor): The label for each box.\n", " image_orig (torch.Tensor): The image to scale the bounding boxes to.\n", " \n", " Returns:\n", " output (dict): The dictionary of rectangle bounding boxes.\n", " \"\"\"\n", " h_size, w_size = image_orig.shape[-2:]\n", "\n", " x1 = boxes[:, 0] * w_size\n", " y1 = boxes[:, 1] * h_size\n", " x2 = boxes[:, 2] * w_size\n", " y2 = boxes[:, 3] * h_size\n", "\n", " width = x2 - x1\n", " height = y2 - y1\n", "\n", " boxes = torch.stack([x1, y1, width, height]).T\n", " return {\n", " 'boxes': boxes.detach().numpy(),\n", " 'labels': labels.detach().numpy(),\n", " 'scores': scores.detach().numpy(),\n", " }\n", "\n", "\n", "def 
prepare_for_coco_detection(predictions):\n", "    \"\"\"\n", "    Convert dictionary model predictions into an expected COCO dataset format.\n", "    \n", "    Args:\n", "        predictions (dict): The list of box coordinates, scores, and labels.\n", "    \n", "    Returns:\n", "        output (list[dict]): The list of bounding boxes.\n", "    \"\"\"\n", "    coco_results = []\n", "    for original_id, prediction in predictions.items():\n", "        if len(prediction) == 0:\n", "            continue\n", "\n", "        boxes = prediction[\"boxes\"].tolist()\n", "        scores = prediction[\"scores\"].tolist()\n", "        labels = prediction[\"labels\"].tolist()\n", "\n", "        coco_results.extend(\n", "            [\n", "                {\n", "                    \"image_id\": original_id,\n", "                    \"category_id\": labels[k],\n", "                    \"bbox\": box,\n", "                    \"score\": scores[k],\n", "                }\n", "                for k, box in enumerate(boxes)\n", "            ]\n", "        )\n", "    return coco_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download pretrained checkpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import requests\n", "\n", "def download_file_from_google_drive(id, destination):\n", "    response = requests.post('https://drive.google.com/uc?id='+id+'&confirm=t')\n", "    save_response_content(response, destination)\n", "\n", "def save_response_content(response, destination):\n", "    CHUNK_SIZE = 32768\n", "    with open(destination, \"wb\") as f:\n", "        for chunk in response.iter_content(CHUNK_SIZE):\n", "            if chunk: # filter out keep-alive new chunks\n", "                f.write(chunk)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "download_file_from_google_drive('1wv_LiFeCRYwtpkqREPeI13-gPELBDwuJ', './yolo_v4.pth')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 3: Build, Compile, and Save Neuron-Optimized YOLO v4 TorchScript\n", "### Construct model and load pretrained checkpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "model 
= Yolov4(yolov4conv137weight=None, n_classes=80, inference=True)\n", "weightfile = \"./yolo_v4.pth\"\n", "pretrained_dict = torch.load(weightfile, map_location=torch.device('cpu'))\n", "model.load_state_dict(pretrained_dict)\n", "model.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Execute inference for a single image and display output" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import matplotlib.patches as patches\n", "\n", "image_orig, _ = next(iter(orig_coco_val_data_loader))\n", "image, _ = next(iter(coco_val_data_loader))\n", "boxes, scores = model(image)\n", "boxes, scores, labels = postprocess(boxes, scores)\n", "result_dict = get_results_as_dict(boxes, scores, labels, image_orig)\n", "\n", "fig, ax = plt.subplots(figsize=(10, 10))\n", "ax.imshow(image_orig.numpy().squeeze(0).transpose(1, 2, 0))\n", "for xywh, _ in zip(result_dict['boxes'], result_dict['labels']):\n", "    x, y, w, h = xywh\n", "    rect = patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='g', facecolor='none')\n", "    ax.add_patch(rect)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "### Run compilation with manually specified device placement\n", "\n", "First, inspect the model without running compilation by adding the `skip_compiler=True` argument to the `torch.neuron.trace` call." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "model_neuron_for_inspection = torch.neuron.trace(model, image, skip_compiler=True)\n", "print(model_neuron_for_inspection)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inspecting the model reveals many `aten::slice` operations in the submodules called `YoloLayer`. Although the neuron-cc compiler supports these operations, they do not run efficiently on the Inferentia hardware. 
To work around this, we recommend manually placing these operators on the CPU." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To manually place `YoloLayer` on the CPU, use the `subgraph_builder_function` argument of `torch.neuron.trace`. It is a callback function that returns `True` or `False` based on information available in `node`. The typical use is a condition based on either `node.name` or `node.type_string`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "def subgraph_builder_function(node):\n", "    return 'YoloLayer' not in node.name\n", "\n", "model_neuron = torch.neuron.trace(model, image, subgraph_builder_function=subgraph_builder_function)\n", "model_neuron.save('yolo_v4_neuron.pt')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compilation is now finished and the compiled model has been saved to a local file called `yolo_v4_neuron.pt`. Saving the model avoids repeating the slow compilation step." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 4: Evaluate Accuracy on the COCO 2017 Dataset\n", "### Load compiled model and run inference\n", "To validate the accuracy of the compiled model, let's run inference on the COCO 2017 validation dataset. We start by defining a helper function `run_inference`."
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def run_inference(dataloader, dataloader_orig, model, convert=True, modelName=''):\n", "    \"\"\"\n", "    Run Yolo V4 inference on the COCO dataset.\n", "    \n", "    Args:\n", "        dataloader (iterable): Data loader of input processed images and labels.\n", "        dataloader_orig (iterable): Data loader with original images.\n", "        model (torch.nn.Module): The torch model to run inference against.\n", "        convert (bool): Set to False when using a vanilla torchvision model that \n", "            does not need to be transformed into coco format.\n", "        modelName (str): Label used in the log output and in the results filename.\n", "    \n", "    Returns: \n", "        imgIds (list): The list of images with predictions.\n", "        cocoDt (pycocotools.coco.COCO): Contains the predictions from the model \n", "            in coco format.\n", "    \"\"\"\n", "    print('\\n================ Starting Inference on {} Images using {} model ================\\n'.format(\n", "        len(dataloader), modelName))\n", "\n", "    modelName = str(modelName).replace(\" \", \"_\")\n", "\n", "    # convert prediction to cocoDt\n", "    # code from def evaluate in https://github.com/pytorch/vision/blob/master/references/detection/engine.py\n", "    imgIds = []\n", "    results = []\n", "    skippedImages = []\n", "\n", "    # time inference\n", "    inference_time = 0.0\n", "    for idx, ((image, targets), (image_orig, _)) in enumerate(zip(dataloader, dataloader_orig)):\n", "        # if target is empty, skip the image because it breaks the scripted model\n", "        if not targets:\n", "            skippedImages.append(idx)\n", "            continue\n", "\n", "        # get the predictions\n", "        start_time = time.time()\n", "        boxes, scores = model(image)\n", "        delta = time.time() - start_time\n", "        inference_time += delta\n", "        boxes, scores, labels = postprocess(boxes, scores)\n", "        outputs = get_results_as_dict(boxes, scores, labels, image_orig)\n", "\n", "        res = {target[\"image_id\"].item(): output for target,\n", "               output in zip(targets, [outputs])}\n", "\n", "        # add the image id to imgIds\n", "        image_id 
= targets[0][\"image_id\"].item()\n", "        imgIds.append(image_id)\n", "\n", "        # convert the prediction into cocoDt results\n", "        pred = prepare_for_coco_detection(res)\n", "        results.extend(pred)\n", "\n", "    print('\\n==================== Performance Measurement ====================')\n", "    print('Finished inference on {} images in {:.2f} seconds'.format(\n", "        len(dataloader), inference_time))\n", "    print('=================================================================\\n')\n", "\n", "    # create bbox detections file\n", "    # following code in https://github.com/aws/aws-neuron-sdk/blob/master/src/examples/tensorflow/yolo_v4_demo/evaluate.ipynb\n", "    resultsfile = modelName + '_bbox_detections.json'\n", "    print('Generating json file...')\n", "    with open(resultsfile, 'w') as f:\n", "        json.dump(results, f)\n", "\n", "    # return COCO api object with loadRes\n", "    cocoDt = cocoGt.loadRes(resultsfile)\n", "\n", "    return imgIds, cocoDt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next step is to load the compiled model from disk and run inference." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_neuron = torch.jit.load('yolo_v4_neuron.pt')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "imgIds, cocoDt = run_inference(coco_val_data_loader, orig_coco_val_data_loader, model_neuron)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then use the standard `pycocotools` routines to generate a report of bounding box precision/recall."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pycocotools.cocoeval import COCOeval\n", "\n", "cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')\n", "cocoEval.params.imgIds = imgIds\n", "cocoEval.evaluate()\n", "cocoEval.accumulate()\n", "cocoEval.summarize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For reference, we may perform the same evaluation on the CPU model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "imgIdsRef, cocoDtRef = run_inference(coco_val_data_loader, orig_coco_val_data_loader, model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cocoEval = COCOeval(cocoGt, cocoDtRef, 'bbox')\n", "cocoEval.params.imgIds = imgIdsRef\n", "cocoEval.evaluate()\n", "cocoEval.accumulate()\n", "cocoEval.summarize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 5: Benchmark COCO Dataset Performance of the Neuron-Optimized TorchScript\n", "The following code snippet sets up data parallel on 16 NeuronCores and runs saturated multi-threaded inference on the Inferentia accelerator. Note that the number of cores (`n_cores`) should be set to the number of available NeuronCores on the current instance." 
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.neuron\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "import torchvision.datasets as dset\n", "import multiprocessing as mp\n", "from concurrent.futures import ThreadPoolExecutor\n", "import PIL\n", "import os\n", "import time\n", "\n", "n_threads = 16\n", "\n", "def get_image_filenames(root=os.getcwd()):\n", "    \"\"\"\n", "    Generate paths to the coco dataset image files.\n", "    \n", "    Args:\n", "        root (str): The folder that contains the `val2017` image directory.\n", "    \n", "    Yields:\n", "        filename (str): The path to an image file.\n", "    \"\"\"\n", "    image_path = os.path.join(root, 'val2017')\n", "    # Use a separate name for the walked directory so `root` is not shadowed\n", "    for dirpath, _, files in os.walk(image_path):\n", "        for filename in files:\n", "            yield os.path.join(dirpath, filename)\n", "\n", "def preprocess(path):\n", "    \"\"\"\n", "    Load an image and convert to the expected Yolo V4 tensor format.\n", "    \n", "    Args:\n", "        path (str): The image file to load from disk. \n", "    \n", "    Returns:\n", "        result (torch.Tensor): The image for prediction. 
Shape: [1, 3, 608, 608]\n", "    \"\"\"\n", "    image = PIL.Image.open(path).convert('RGB')\n", "    resized = torchvision.transforms.functional.resize(image, [608, 608])\n", "    tensor = torchvision.transforms.functional.to_tensor(resized)\n", "    return tensor.unsqueeze(0).to(torch.float32)\n", "\n", "\n", "def load_model(filename='yolo_v4_neuron.pt'):\n", "    \"\"\"\n", "    Load and pre-warm the Yolo V4 model.\n", "    \n", "    Args:\n", "        filename (str): The location to load the model from.\n", "    \n", "    Returns:\n", "        model (torch.nn.Module): The torch model.\n", "    \"\"\"\n", "    \n", "    # Load model from disk\n", "    model = torch.jit.load(filename)\n", "\n", "    # Warm up model on neuron by running a single example image\n", "    filename = next(iter(get_image_filenames()))\n", "    image = preprocess(filename)\n", "    model(image)\n", "\n", "    return model\n", "\n", "\n", "def task(model, filename):\n", "    \"\"\"\n", "    The thread task to perform prediction.\n", "    \n", "    This does the full end-to-end processing of an image from loading from disk\n", "    all the way to classifying and filtering bounding boxes.\n", "    \n", "    Args:\n", "        model (torch.nn.Module): The model to run processing with\n", "        filename (str): The image file to load from disk. \n", "    \n", "    Returns:\n", "        result (tuple): The filtered (boxes, scores, labels) tensors.\n", "        delta (float): The model inference time in seconds. 
\n", "    \"\"\"\n", "    image = preprocess(filename)\n", "    begin = time.time()\n", "    boxes, scores = model(image)\n", "    delta = time.time() - begin\n", "    return postprocess(boxes, scores), delta\n", "\n", "\n", "def benchmark():\n", "    \"\"\"\n", "    Run a benchmark on the entire COCO dataset against the neuron model.\n", "    \"\"\"\n", "    \n", "    # Load a model into each NeuronCore\n", "    models = [load_model() for _ in range(n_cores)]\n", "    \n", "    # Create input/output lists\n", "    filenames = list(get_image_filenames())\n", "    results = list()\n", "    latency = list()\n", "    \n", "    # We want to keep track of average completion time per thread\n", "    sum_time = 0.0\n", "    \n", "    # Submit all tasks and wait for them to finish\n", "    with ThreadPoolExecutor(n_threads) as pool:\n", "        for i, filename in enumerate(filenames):\n", "            result = pool.submit(task, models[i % len(models)], filename)\n", "            results.append(result)\n", "        for result in results:\n", "            # Model outputs are unused for the benchmark; avoid rebinding `results`\n", "            _, times = result.result()\n", "            latency.append(times)\n", "            sum_time += times\n", "    \n", "    print('Duration: ', sum_time / n_threads)\n", "    print('Images Per Second:', len(filenames) / (sum_time / n_threads))\n", "    # Latencies are reported in milliseconds; the first 1000 samples are excluded\n", "    print(\"Latency P50: {:.1f}\".format(np.percentile(latency[1000:], 50)*1000.0))\n", "    print(\"Latency P90: {:.1f}\".format(np.percentile(latency[1000:], 90)*1000.0))\n", "    print(\"Latency P95: {:.1f}\".format(np.percentile(latency[1000:], 95)*1000.0))\n", "    print(\"Latency P99: {:.1f}\".format(np.percentile(latency[1000:], 99)*1000.0))\n", "\n", "benchmark()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.9 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.9" }, "vscode": { "interpreter": { "hash": 
"31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 4 }