This document is relevant for: Inf1

yolo_v4_coco_saved_model.py

   1import os
   2import io
   3from functools import partial
   4import requests
   5import numpy as np
   6import torch
   7import tensorflow as tf
   8from tensorflow import keras
   9from tensorflow.keras import layers
  10
  11
  12
  13def rename_weights(checkpoint):
  14    name_mapping = {
  15        'down1.conv1.conv.0.weight': 'models.0.conv1.weight',
  16        'down1.conv1.conv.1.weight': 'models.0.bn1.weight',
  17        'down1.conv1.conv.1.bias': 'models.0.bn1.bias',
  18        'down1.conv1.conv.1.running_mean': 'models.0.bn1.running_mean',
  19        'down1.conv1.conv.1.running_var': 'models.0.bn1.running_var',
  20        'down1.conv1.conv.1.num_batches_tracked': 'models.0.bn1.num_batches_tracked',
  21        'down1.conv2.conv.0.weight': 'models.1.conv2.weight',
  22        'down1.conv2.conv.1.weight': 'models.1.bn2.weight',
  23        'down1.conv2.conv.1.bias': 'models.1.bn2.bias',
  24        'down1.conv2.conv.1.running_mean': 'models.1.bn2.running_mean',
  25        'down1.conv2.conv.1.running_var': 'models.1.bn2.running_var',
  26        'down1.conv2.conv.1.num_batches_tracked': 'models.1.bn2.num_batches_tracked',
  27        'down1.conv3.conv.0.weight': 'models.2.conv3.weight',
  28        'down1.conv3.conv.1.weight': 'models.2.bn3.weight',
  29        'down1.conv3.conv.1.bias': 'models.2.bn3.bias',
  30        'down1.conv3.conv.1.running_mean': 'models.2.bn3.running_mean',
  31        'down1.conv3.conv.1.running_var': 'models.2.bn3.running_var',
  32        'down1.conv3.conv.1.num_batches_tracked': 'models.2.bn3.num_batches_tracked',
  33        'down1.conv4.conv.0.weight': 'models.4.conv4.weight',
  34        'down1.conv4.conv.1.weight': 'models.4.bn4.weight',
  35        'down1.conv4.conv.1.bias': 'models.4.bn4.bias',
  36        'down1.conv4.conv.1.running_mean': 'models.4.bn4.running_mean',
  37        'down1.conv4.conv.1.running_var': 'models.4.bn4.running_var',
  38        'down1.conv4.conv.1.num_batches_tracked': 'models.4.bn4.num_batches_tracked',
  39        'down1.conv5.conv.0.weight': 'models.5.conv5.weight',
  40        'down1.conv5.conv.1.weight': 'models.5.bn5.weight',
  41        'down1.conv5.conv.1.bias': 'models.5.bn5.bias',
  42        'down1.conv5.conv.1.running_mean': 'models.5.bn5.running_mean',
  43        'down1.conv5.conv.1.running_var': 'models.5.bn5.running_var',
  44        'down1.conv5.conv.1.num_batches_tracked': 'models.5.bn5.num_batches_tracked',
  45        'down1.conv6.conv.0.weight': 'models.6.conv6.weight',
  46        'down1.conv6.conv.1.weight': 'models.6.bn6.weight',
  47        'down1.conv6.conv.1.bias': 'models.6.bn6.bias',
  48        'down1.conv6.conv.1.running_mean': 'models.6.bn6.running_mean',
  49        'down1.conv6.conv.1.running_var': 'models.6.bn6.running_var',
  50        'down1.conv6.conv.1.num_batches_tracked': 'models.6.bn6.num_batches_tracked',
  51        'down1.conv7.conv.0.weight': 'models.8.conv7.weight',
  52        'down1.conv7.conv.1.weight': 'models.8.bn7.weight',
  53        'down1.conv7.conv.1.bias': 'models.8.bn7.bias',
  54        'down1.conv7.conv.1.running_mean': 'models.8.bn7.running_mean',
  55        'down1.conv7.conv.1.running_var': 'models.8.bn7.running_var',
  56        'down1.conv7.conv.1.num_batches_tracked': 'models.8.bn7.num_batches_tracked',
  57        'down1.conv8.conv.0.weight': 'models.10.conv8.weight',
  58        'down1.conv8.conv.1.weight': 'models.10.bn8.weight',
  59        'down1.conv8.conv.1.bias': 'models.10.bn8.bias',
  60        'down1.conv8.conv.1.running_mean': 'models.10.bn8.running_mean',
  61        'down1.conv8.conv.1.running_var': 'models.10.bn8.running_var',
  62        'down1.conv8.conv.1.num_batches_tracked': 'models.10.bn8.num_batches_tracked',
  63        'down2.conv1.conv.0.weight': 'models.11.conv9.weight',
  64        'down2.conv1.conv.1.weight': 'models.11.bn9.weight',
  65        'down2.conv1.conv.1.bias': 'models.11.bn9.bias',
  66        'down2.conv1.conv.1.running_mean': 'models.11.bn9.running_mean',
  67        'down2.conv1.conv.1.running_var': 'models.11.bn9.running_var',
  68        'down2.conv1.conv.1.num_batches_tracked': 'models.11.bn9.num_batches_tracked',
  69        'down2.conv2.conv.0.weight': 'models.12.conv10.weight',
  70        'down2.conv2.conv.1.weight': 'models.12.bn10.weight',
  71        'down2.conv2.conv.1.bias': 'models.12.bn10.bias',
  72        'down2.conv2.conv.1.running_mean': 'models.12.bn10.running_mean',
  73        'down2.conv2.conv.1.running_var': 'models.12.bn10.running_var',
  74        'down2.conv2.conv.1.num_batches_tracked': 'models.12.bn10.num_batches_tracked',
  75        'down2.conv3.conv.0.weight': 'models.14.conv11.weight',
  76        'down2.conv3.conv.1.weight': 'models.14.bn11.weight',
  77        'down2.conv3.conv.1.bias': 'models.14.bn11.bias',
  78        'down2.conv3.conv.1.running_mean': 'models.14.bn11.running_mean',
  79        'down2.conv3.conv.1.running_var': 'models.14.bn11.running_var',
  80        'down2.conv3.conv.1.num_batches_tracked': 'models.14.bn11.num_batches_tracked',
  81        'down2.resblock.module_list.0.0.conv.0.weight': 'models.15.conv12.weight',
  82        'down2.resblock.module_list.0.0.conv.1.weight': 'models.15.bn12.weight',
  83        'down2.resblock.module_list.0.0.conv.1.bias': 'models.15.bn12.bias',
  84        'down2.resblock.module_list.0.0.conv.1.running_mean': 'models.15.bn12.running_mean',
  85        'down2.resblock.module_list.0.0.conv.1.running_var': 'models.15.bn12.running_var',
  86        'down2.resblock.module_list.0.0.conv.1.num_batches_tracked': 'models.15.bn12.num_batches_tracked',
  87        'down2.resblock.module_list.0.1.conv.0.weight': 'models.16.conv13.weight',
  88        'down2.resblock.module_list.0.1.conv.1.weight': 'models.16.bn13.weight',
  89        'down2.resblock.module_list.0.1.conv.1.bias': 'models.16.bn13.bias',
  90        'down2.resblock.module_list.0.1.conv.1.running_mean': 'models.16.bn13.running_mean',
  91        'down2.resblock.module_list.0.1.conv.1.running_var': 'models.16.bn13.running_var',
  92        'down2.resblock.module_list.0.1.conv.1.num_batches_tracked': 'models.16.bn13.num_batches_tracked',
  93        'down2.resblock.module_list.1.0.conv.0.weight': 'models.18.conv14.weight',
  94        'down2.resblock.module_list.1.0.conv.1.weight': 'models.18.bn14.weight',
  95        'down2.resblock.module_list.1.0.conv.1.bias': 'models.18.bn14.bias',
  96        'down2.resblock.module_list.1.0.conv.1.running_mean': 'models.18.bn14.running_mean',
  97        'down2.resblock.module_list.1.0.conv.1.running_var': 'models.18.bn14.running_var',
  98        'down2.resblock.module_list.1.0.conv.1.num_batches_tracked': 'models.18.bn14.num_batches_tracked',
  99        'down2.resblock.module_list.1.1.conv.0.weight': 'models.19.conv15.weight',
 100        'down2.resblock.module_list.1.1.conv.1.weight': 'models.19.bn15.weight',
 101        'down2.resblock.module_list.1.1.conv.1.bias': 'models.19.bn15.bias',
 102        'down2.resblock.module_list.1.1.conv.1.running_mean': 'models.19.bn15.running_mean',
 103        'down2.resblock.module_list.1.1.conv.1.running_var': 'models.19.bn15.running_var',
 104        'down2.resblock.module_list.1.1.conv.1.num_batches_tracked': 'models.19.bn15.num_batches_tracked',
 105        'down2.conv4.conv.0.weight': 'models.21.conv16.weight',
 106        'down2.conv4.conv.1.weight': 'models.21.bn16.weight',
 107        'down2.conv4.conv.1.bias': 'models.21.bn16.bias',
 108        'down2.conv4.conv.1.running_mean': 'models.21.bn16.running_mean',
 109        'down2.conv4.conv.1.running_var': 'models.21.bn16.running_var',
 110        'down2.conv4.conv.1.num_batches_tracked': 'models.21.bn16.num_batches_tracked',
 111        'down2.conv5.conv.0.weight': 'models.23.conv17.weight',
 112        'down2.conv5.conv.1.weight': 'models.23.bn17.weight',
 113        'down2.conv5.conv.1.bias': 'models.23.bn17.bias',
 114        'down2.conv5.conv.1.running_mean': 'models.23.bn17.running_mean',
 115        'down2.conv5.conv.1.running_var': 'models.23.bn17.running_var',
 116        'down2.conv5.conv.1.num_batches_tracked': 'models.23.bn17.num_batches_tracked',
 117        'down3.conv1.conv.0.weight': 'models.24.conv18.weight',
 118        'down3.conv1.conv.1.weight': 'models.24.bn18.weight',
 119        'down3.conv1.conv.1.bias': 'models.24.bn18.bias',
 120        'down3.conv1.conv.1.running_mean': 'models.24.bn18.running_mean',
 121        'down3.conv1.conv.1.running_var': 'models.24.bn18.running_var',
 122        'down3.conv1.conv.1.num_batches_tracked': 'models.24.bn18.num_batches_tracked',
 123        'down3.conv2.conv.0.weight': 'models.25.conv19.weight',
 124        'down3.conv2.conv.1.weight': 'models.25.bn19.weight',
 125        'down3.conv2.conv.1.bias': 'models.25.bn19.bias',
 126        'down3.conv2.conv.1.running_mean': 'models.25.bn19.running_mean',
 127        'down3.conv2.conv.1.running_var': 'models.25.bn19.running_var',
 128        'down3.conv2.conv.1.num_batches_tracked': 'models.25.bn19.num_batches_tracked',
 129        'down3.conv3.conv.0.weight': 'models.27.conv20.weight',
 130        'down3.conv3.conv.1.weight': 'models.27.bn20.weight',
 131        'down3.conv3.conv.1.bias': 'models.27.bn20.bias',
 132        'down3.conv3.conv.1.running_mean': 'models.27.bn20.running_mean',
 133        'down3.conv3.conv.1.running_var': 'models.27.bn20.running_var',
 134        'down3.conv3.conv.1.num_batches_tracked': 'models.27.bn20.num_batches_tracked',
 135        'down3.resblock.module_list.0.0.conv.0.weight': 'models.28.conv21.weight',
 136        'down3.resblock.module_list.0.0.conv.1.weight': 'models.28.bn21.weight',
 137        'down3.resblock.module_list.0.0.conv.1.bias': 'models.28.bn21.bias',
 138        'down3.resblock.module_list.0.0.conv.1.running_mean': 'models.28.bn21.running_mean',
 139        'down3.resblock.module_list.0.0.conv.1.running_var': 'models.28.bn21.running_var',
 140        'down3.resblock.module_list.0.0.conv.1.num_batches_tracked': 'models.28.bn21.num_batches_tracked',
 141        'down3.resblock.module_list.0.1.conv.0.weight': 'models.29.conv22.weight',
 142        'down3.resblock.module_list.0.1.conv.1.weight': 'models.29.bn22.weight',
 143        'down3.resblock.module_list.0.1.conv.1.bias': 'models.29.bn22.bias',
 144        'down3.resblock.module_list.0.1.conv.1.running_mean': 'models.29.bn22.running_mean',
 145        'down3.resblock.module_list.0.1.conv.1.running_var': 'models.29.bn22.running_var',
 146        'down3.resblock.module_list.0.1.conv.1.num_batches_tracked': 'models.29.bn22.num_batches_tracked',
 147        'down3.resblock.module_list.1.0.conv.0.weight': 'models.31.conv23.weight',
 148        'down3.resblock.module_list.1.0.conv.1.weight': 'models.31.bn23.weight',
 149        'down3.resblock.module_list.1.0.conv.1.bias': 'models.31.bn23.bias',
 150        'down3.resblock.module_list.1.0.conv.1.running_mean': 'models.31.bn23.running_mean',
 151        'down3.resblock.module_list.1.0.conv.1.running_var': 'models.31.bn23.running_var',
 152        'down3.resblock.module_list.1.0.conv.1.num_batches_tracked': 'models.31.bn23.num_batches_tracked',
 153        'down3.resblock.module_list.1.1.conv.0.weight': 'models.32.conv24.weight',
 154        'down3.resblock.module_list.1.1.conv.1.weight': 'models.32.bn24.weight',
 155        'down3.resblock.module_list.1.1.conv.1.bias': 'models.32.bn24.bias',
 156        'down3.resblock.module_list.1.1.conv.1.running_mean': 'models.32.bn24.running_mean',
 157        'down3.resblock.module_list.1.1.conv.1.running_var': 'models.32.bn24.running_var',
 158        'down3.resblock.module_list.1.1.conv.1.num_batches_tracked': 'models.32.bn24.num_batches_tracked',
 159        'down3.resblock.module_list.2.0.conv.0.weight': 'models.34.conv25.weight',
 160        'down3.resblock.module_list.2.0.conv.1.weight': 'models.34.bn25.weight',
 161        'down3.resblock.module_list.2.0.conv.1.bias': 'models.34.bn25.bias',
 162        'down3.resblock.module_list.2.0.conv.1.running_mean': 'models.34.bn25.running_mean',
 163        'down3.resblock.module_list.2.0.conv.1.running_var': 'models.34.bn25.running_var',
 164        'down3.resblock.module_list.2.0.conv.1.num_batches_tracked': 'models.34.bn25.num_batches_tracked',
 165        'down3.resblock.module_list.2.1.conv.0.weight': 'models.35.conv26.weight',
 166        'down3.resblock.module_list.2.1.conv.1.weight': 'models.35.bn26.weight',
 167        'down3.resblock.module_list.2.1.conv.1.bias': 'models.35.bn26.bias',
 168        'down3.resblock.module_list.2.1.conv.1.running_mean': 'models.35.bn26.running_mean',
 169        'down3.resblock.module_list.2.1.conv.1.running_var': 'models.35.bn26.running_var',
 170        'down3.resblock.module_list.2.1.conv.1.num_batches_tracked': 'models.35.bn26.num_batches_tracked',
 171        'down3.resblock.module_list.3.0.conv.0.weight': 'models.37.conv27.weight',
 172        'down3.resblock.module_list.3.0.conv.1.weight': 'models.37.bn27.weight',
 173        'down3.resblock.module_list.3.0.conv.1.bias': 'models.37.bn27.bias',
 174        'down3.resblock.module_list.3.0.conv.1.running_mean': 'models.37.bn27.running_mean',
 175        'down3.resblock.module_list.3.0.conv.1.running_var': 'models.37.bn27.running_var',
 176        'down3.resblock.module_list.3.0.conv.1.num_batches_tracked': 'models.37.bn27.num_batches_tracked',
 177        'down3.resblock.module_list.3.1.conv.0.weight': 'models.38.conv28.weight',
 178        'down3.resblock.module_list.3.1.conv.1.weight': 'models.38.bn28.weight',
 179        'down3.resblock.module_list.3.1.conv.1.bias': 'models.38.bn28.bias',
 180        'down3.resblock.module_list.3.1.conv.1.running_mean': 'models.38.bn28.running_mean',
 181        'down3.resblock.module_list.3.1.conv.1.running_var': 'models.38.bn28.running_var',
 182        'down3.resblock.module_list.3.1.conv.1.num_batches_tracked': 'models.38.bn28.num_batches_tracked',
 183        'down3.resblock.module_list.4.0.conv.0.weight': 'models.40.conv29.weight',
 184        'down3.resblock.module_list.4.0.conv.1.weight': 'models.40.bn29.weight',
 185        'down3.resblock.module_list.4.0.conv.1.bias': 'models.40.bn29.bias',
 186        'down3.resblock.module_list.4.0.conv.1.running_mean': 'models.40.bn29.running_mean',
 187        'down3.resblock.module_list.4.0.conv.1.running_var': 'models.40.bn29.running_var',
 188        'down3.resblock.module_list.4.0.conv.1.num_batches_tracked': 'models.40.bn29.num_batches_tracked',
 189        'down3.resblock.module_list.4.1.conv.0.weight': 'models.41.conv30.weight',
 190        'down3.resblock.module_list.4.1.conv.1.weight': 'models.41.bn30.weight',
 191        'down3.resblock.module_list.4.1.conv.1.bias': 'models.41.bn30.bias',
 192        'down3.resblock.module_list.4.1.conv.1.running_mean': 'models.41.bn30.running_mean',
 193        'down3.resblock.module_list.4.1.conv.1.running_var': 'models.41.bn30.running_var',
 194        'down3.resblock.module_list.4.1.conv.1.num_batches_tracked': 'models.41.bn30.num_batches_tracked',
 195        'down3.resblock.module_list.5.0.conv.0.weight': 'models.43.conv31.weight',
 196        'down3.resblock.module_list.5.0.conv.1.weight': 'models.43.bn31.weight',
 197        'down3.resblock.module_list.5.0.conv.1.bias': 'models.43.bn31.bias',
 198        'down3.resblock.module_list.5.0.conv.1.running_mean': 'models.43.bn31.running_mean',
 199        'down3.resblock.module_list.5.0.conv.1.running_var': 'models.43.bn31.running_var',
 200        'down3.resblock.module_list.5.0.conv.1.num_batches_tracked': 'models.43.bn31.num_batches_tracked',
 201        'down3.resblock.module_list.5.1.conv.0.weight': 'models.44.conv32.weight',
 202        'down3.resblock.module_list.5.1.conv.1.weight': 'models.44.bn32.weight',
 203        'down3.resblock.module_list.5.1.conv.1.bias': 'models.44.bn32.bias',
 204        'down3.resblock.module_list.5.1.conv.1.running_mean': 'models.44.bn32.running_mean',
 205        'down3.resblock.module_list.5.1.conv.1.running_var': 'models.44.bn32.running_var',
 206        'down3.resblock.module_list.5.1.conv.1.num_batches_tracked': 'models.44.bn32.num_batches_tracked',
 207        'down3.resblock.module_list.6.0.conv.0.weight': 'models.46.conv33.weight',
 208        'down3.resblock.module_list.6.0.conv.1.weight': 'models.46.bn33.weight',
 209        'down3.resblock.module_list.6.0.conv.1.bias': 'models.46.bn33.bias',
 210        'down3.resblock.module_list.6.0.conv.1.running_mean': 'models.46.bn33.running_mean',
 211        'down3.resblock.module_list.6.0.conv.1.running_var': 'models.46.bn33.running_var',
 212        'down3.resblock.module_list.6.0.conv.1.num_batches_tracked': 'models.46.bn33.num_batches_tracked',
 213        'down3.resblock.module_list.6.1.conv.0.weight': 'models.47.conv34.weight',
 214        'down3.resblock.module_list.6.1.conv.1.weight': 'models.47.bn34.weight',
 215        'down3.resblock.module_list.6.1.conv.1.bias': 'models.47.bn34.bias',
 216        'down3.resblock.module_list.6.1.conv.1.running_mean': 'models.47.bn34.running_mean',
 217        'down3.resblock.module_list.6.1.conv.1.running_var': 'models.47.bn34.running_var',
 218        'down3.resblock.module_list.6.1.conv.1.num_batches_tracked': 'models.47.bn34.num_batches_tracked',
 219        'down3.resblock.module_list.7.0.conv.0.weight': 'models.49.conv35.weight',
 220        'down3.resblock.module_list.7.0.conv.1.weight': 'models.49.bn35.weight',
 221        'down3.resblock.module_list.7.0.conv.1.bias': 'models.49.bn35.bias',
 222        'down3.resblock.module_list.7.0.conv.1.running_mean': 'models.49.bn35.running_mean',
 223        'down3.resblock.module_list.7.0.conv.1.running_var': 'models.49.bn35.running_var',
 224        'down3.resblock.module_list.7.0.conv.1.num_batches_tracked': 'models.49.bn35.num_batches_tracked',
 225        'down3.resblock.module_list.7.1.conv.0.weight': 'models.50.conv36.weight',
 226        'down3.resblock.module_list.7.1.conv.1.weight': 'models.50.bn36.weight',
 227        'down3.resblock.module_list.7.1.conv.1.bias': 'models.50.bn36.bias',
 228        'down3.resblock.module_list.7.1.conv.1.running_mean': 'models.50.bn36.running_mean',
 229        'down3.resblock.module_list.7.1.conv.1.running_var': 'models.50.bn36.running_var',
 230        'down3.resblock.module_list.7.1.conv.1.num_batches_tracked': 'models.50.bn36.num_batches_tracked',
 231        'down3.conv4.conv.0.weight': 'models.52.conv37.weight',
 232        'down3.conv4.conv.1.weight': 'models.52.bn37.weight',
 233        'down3.conv4.conv.1.bias': 'models.52.bn37.bias',
 234        'down3.conv4.conv.1.running_mean': 'models.52.bn37.running_mean',
 235        'down3.conv4.conv.1.running_var': 'models.52.bn37.running_var',
 236        'down3.conv4.conv.1.num_batches_tracked': 'models.52.bn37.num_batches_tracked',
 237        'down3.conv5.conv.0.weight': 'models.54.conv38.weight',
 238        'down3.conv5.conv.1.weight': 'models.54.bn38.weight',
 239        'down3.conv5.conv.1.bias': 'models.54.bn38.bias',
 240        'down3.conv5.conv.1.running_mean': 'models.54.bn38.running_mean',
 241        'down3.conv5.conv.1.running_var': 'models.54.bn38.running_var',
 242        'down3.conv5.conv.1.num_batches_tracked': 'models.54.bn38.num_batches_tracked',
 243        'down4.conv1.conv.0.weight': 'models.55.conv39.weight',
 244        'down4.conv1.conv.1.weight': 'models.55.bn39.weight',
 245        'down4.conv1.conv.1.bias': 'models.55.bn39.bias',
 246        'down4.conv1.conv.1.running_mean': 'models.55.bn39.running_mean',
 247        'down4.conv1.conv.1.running_var': 'models.55.bn39.running_var',
 248        'down4.conv1.conv.1.num_batches_tracked': 'models.55.bn39.num_batches_tracked',
 249        'down4.conv2.conv.0.weight': 'models.56.conv40.weight',
 250        'down4.conv2.conv.1.weight': 'models.56.bn40.weight',
 251        'down4.conv2.conv.1.bias': 'models.56.bn40.bias',
 252        'down4.conv2.conv.1.running_mean': 'models.56.bn40.running_mean',
 253        'down4.conv2.conv.1.running_var': 'models.56.bn40.running_var',
 254        'down4.conv2.conv.1.num_batches_tracked': 'models.56.bn40.num_batches_tracked',
 255        'down4.conv3.conv.0.weight': 'models.58.conv41.weight',
 256        'down4.conv3.conv.1.weight': 'models.58.bn41.weight',
 257        'down4.conv3.conv.1.bias': 'models.58.bn41.bias',
 258        'down4.conv3.conv.1.running_mean': 'models.58.bn41.running_mean',
 259        'down4.conv3.conv.1.running_var': 'models.58.bn41.running_var',
 260        'down4.conv3.conv.1.num_batches_tracked': 'models.58.bn41.num_batches_tracked',
 261        'down4.resblock.module_list.0.0.conv.0.weight': 'models.59.conv42.weight',
 262        'down4.resblock.module_list.0.0.conv.1.weight': 'models.59.bn42.weight',
 263        'down4.resblock.module_list.0.0.conv.1.bias': 'models.59.bn42.bias',
 264        'down4.resblock.module_list.0.0.conv.1.running_mean': 'models.59.bn42.running_mean',
 265        'down4.resblock.module_list.0.0.conv.1.running_var': 'models.59.bn42.running_var',
 266        'down4.resblock.module_list.0.0.conv.1.num_batches_tracked': 'models.59.bn42.num_batches_tracked',
 267        'down4.resblock.module_list.0.1.conv.0.weight': 'models.60.conv43.weight',
 268        'down4.resblock.module_list.0.1.conv.1.weight': 'models.60.bn43.weight',
 269        'down4.resblock.module_list.0.1.conv.1.bias': 'models.60.bn43.bias',
 270        'down4.resblock.module_list.0.1.conv.1.running_mean': 'models.60.bn43.running_mean',
 271        'down4.resblock.module_list.0.1.conv.1.running_var': 'models.60.bn43.running_var',
 272        'down4.resblock.module_list.0.1.conv.1.num_batches_tracked': 'models.60.bn43.num_batches_tracked',
 273        'down4.resblock.module_list.1.0.conv.0.weight': 'models.62.conv44.weight',
 274        'down4.resblock.module_list.1.0.conv.1.weight': 'models.62.bn44.weight',
 275        'down4.resblock.module_list.1.0.conv.1.bias': 'models.62.bn44.bias',
 276        'down4.resblock.module_list.1.0.conv.1.running_mean': 'models.62.bn44.running_mean',
 277        'down4.resblock.module_list.1.0.conv.1.running_var': 'models.62.bn44.running_var',
 278        'down4.resblock.module_list.1.0.conv.1.num_batches_tracked': 'models.62.bn44.num_batches_tracked',
 279        'down4.resblock.module_list.1.1.conv.0.weight': 'models.63.conv45.weight',
 280        'down4.resblock.module_list.1.1.conv.1.weight': 'models.63.bn45.weight',
 281        'down4.resblock.module_list.1.1.conv.1.bias': 'models.63.bn45.bias',
 282        'down4.resblock.module_list.1.1.conv.1.running_mean': 'models.63.bn45.running_mean',
 283        'down4.resblock.module_list.1.1.conv.1.running_var': 'models.63.bn45.running_var',
 284        'down4.resblock.module_list.1.1.conv.1.num_batches_tracked': 'models.63.bn45.num_batches_tracked',
 285        'down4.resblock.module_list.2.0.conv.0.weight': 'models.65.conv46.weight',
 286        'down4.resblock.module_list.2.0.conv.1.weight': 'models.65.bn46.weight',
 287        'down4.resblock.module_list.2.0.conv.1.bias': 'models.65.bn46.bias',
 288        'down4.resblock.module_list.2.0.conv.1.running_mean': 'models.65.bn46.running_mean',
 289        'down4.resblock.module_list.2.0.conv.1.running_var': 'models.65.bn46.running_var',
 290        'down4.resblock.module_list.2.0.conv.1.num_batches_tracked': 'models.65.bn46.num_batches_tracked',
 291        'down4.resblock.module_list.2.1.conv.0.weight': 'models.66.conv47.weight',
 292        'down4.resblock.module_list.2.1.conv.1.weight': 'models.66.bn47.weight',
 293        'down4.resblock.module_list.2.1.conv.1.bias': 'models.66.bn47.bias',
 294        'down4.resblock.module_list.2.1.conv.1.running_mean': 'models.66.bn47.running_mean',
 295        'down4.resblock.module_list.2.1.conv.1.running_var': 'models.66.bn47.running_var',
 296        'down4.resblock.module_list.2.1.conv.1.num_batches_tracked': 'models.66.bn47.num_batches_tracked',
 297        'down4.resblock.module_list.3.0.conv.0.weight': 'models.68.conv48.weight',
 298        'down4.resblock.module_list.3.0.conv.1.weight': 'models.68.bn48.weight',
 299        'down4.resblock.module_list.3.0.conv.1.bias': 'models.68.bn48.bias',
 300        'down4.resblock.module_list.3.0.conv.1.running_mean': 'models.68.bn48.running_mean',
 301        'down4.resblock.module_list.3.0.conv.1.running_var': 'models.68.bn48.running_var',
 302        'down4.resblock.module_list.3.0.conv.1.num_batches_tracked': 'models.68.bn48.num_batches_tracked',
 303        'down4.resblock.module_list.3.1.conv.0.weight': 'models.69.conv49.weight',
 304        'down4.resblock.module_list.3.1.conv.1.weight': 'models.69.bn49.weight',
 305        'down4.resblock.module_list.3.1.conv.1.bias': 'models.69.bn49.bias',
 306        'down4.resblock.module_list.3.1.conv.1.running_mean': 'models.69.bn49.running_mean',
 307        'down4.resblock.module_list.3.1.conv.1.running_var': 'models.69.bn49.running_var',
 308        'down4.resblock.module_list.3.1.conv.1.num_batches_tracked': 'models.69.bn49.num_batches_tracked',
 309        'down4.resblock.module_list.4.0.conv.0.weight': 'models.71.conv50.weight',
 310        'down4.resblock.module_list.4.0.conv.1.weight': 'models.71.bn50.weight',
 311        'down4.resblock.module_list.4.0.conv.1.bias': 'models.71.bn50.bias',
 312        'down4.resblock.module_list.4.0.conv.1.running_mean': 'models.71.bn50.running_mean',
 313        'down4.resblock.module_list.4.0.conv.1.running_var': 'models.71.bn50.running_var',
 314        'down4.resblock.module_list.4.0.conv.1.num_batches_tracked': 'models.71.bn50.num_batches_tracked',
 315        'down4.resblock.module_list.4.1.conv.0.weight': 'models.72.conv51.weight',
 316        'down4.resblock.module_list.4.1.conv.1.weight': 'models.72.bn51.weight',
 317        'down4.resblock.module_list.4.1.conv.1.bias': 'models.72.bn51.bias',
 318        'down4.resblock.module_list.4.1.conv.1.running_mean': 'models.72.bn51.running_mean',
 319        'down4.resblock.module_list.4.1.conv.1.running_var': 'models.72.bn51.running_var',
 320        'down4.resblock.module_list.4.1.conv.1.num_batches_tracked': 'models.72.bn51.num_batches_tracked',
 321        'down4.resblock.module_list.5.0.conv.0.weight': 'models.74.conv52.weight',
 322        'down4.resblock.module_list.5.0.conv.1.weight': 'models.74.bn52.weight',
 323        'down4.resblock.module_list.5.0.conv.1.bias': 'models.74.bn52.bias',
 324        'down4.resblock.module_list.5.0.conv.1.running_mean': 'models.74.bn52.running_mean',
 325        'down4.resblock.module_list.5.0.conv.1.running_var': 'models.74.bn52.running_var',
 326        'down4.resblock.module_list.5.0.conv.1.num_batches_tracked': 'models.74.bn52.num_batches_tracked',
 327        'down4.resblock.module_list.5.1.conv.0.weight': 'models.75.conv53.weight',
 328        'down4.resblock.module_list.5.1.conv.1.weight': 'models.75.bn53.weight',
 329        'down4.resblock.module_list.5.1.conv.1.bias': 'models.75.bn53.bias',
 330        'down4.resblock.module_list.5.1.conv.1.running_mean': 'models.75.bn53.running_mean',
 331        'down4.resblock.module_list.5.1.conv.1.running_var': 'models.75.bn53.running_var',
 332        'down4.resblock.module_list.5.1.conv.1.num_batches_tracked': 'models.75.bn53.num_batches_tracked',
 333        'down4.resblock.module_list.6.0.conv.0.weight': 'models.77.conv54.weight',
 334        'down4.resblock.module_list.6.0.conv.1.weight': 'models.77.bn54.weight',
 335        'down4.resblock.module_list.6.0.conv.1.bias': 'models.77.bn54.bias',
 336        'down4.resblock.module_list.6.0.conv.1.running_mean': 'models.77.bn54.running_mean',
 337        'down4.resblock.module_list.6.0.conv.1.running_var': 'models.77.bn54.running_var',
 338        'down4.resblock.module_list.6.0.conv.1.num_batches_tracked': 'models.77.bn54.num_batches_tracked',
 339        'down4.resblock.module_list.6.1.conv.0.weight': 'models.78.conv55.weight',
 340        'down4.resblock.module_list.6.1.conv.1.weight': 'models.78.bn55.weight',
 341        'down4.resblock.module_list.6.1.conv.1.bias': 'models.78.bn55.bias',
 342        'down4.resblock.module_list.6.1.conv.1.running_mean': 'models.78.bn55.running_mean',
 343        'down4.resblock.module_list.6.1.conv.1.running_var': 'models.78.bn55.running_var',
 344        'down4.resblock.module_list.6.1.conv.1.num_batches_tracked': 'models.78.bn55.num_batches_tracked',
 345        'down4.resblock.module_list.7.0.conv.0.weight': 'models.80.conv56.weight',
 346        'down4.resblock.module_list.7.0.conv.1.weight': 'models.80.bn56.weight',
 347        'down4.resblock.module_list.7.0.conv.1.bias': 'models.80.bn56.bias',
 348        'down4.resblock.module_list.7.0.conv.1.running_mean': 'models.80.bn56.running_mean',
 349        'down4.resblock.module_list.7.0.conv.1.running_var': 'models.80.bn56.running_var',
 350        'down4.resblock.module_list.7.0.conv.1.num_batches_tracked': 'models.80.bn56.num_batches_tracked',
 351        'down4.resblock.module_list.7.1.conv.0.weight': 'models.81.conv57.weight',
 352        'down4.resblock.module_list.7.1.conv.1.weight': 'models.81.bn57.weight',
 353        'down4.resblock.module_list.7.1.conv.1.bias': 'models.81.bn57.bias',
 354        'down4.resblock.module_list.7.1.conv.1.running_mean': 'models.81.bn57.running_mean',
 355        'down4.resblock.module_list.7.1.conv.1.running_var': 'models.81.bn57.running_var',
 356        'down4.resblock.module_list.7.1.conv.1.num_batches_tracked': 'models.81.bn57.num_batches_tracked',
 357        'down4.conv4.conv.0.weight': 'models.83.conv58.weight',
 358        'down4.conv4.conv.1.weight': 'models.83.bn58.weight',
 359        'down4.conv4.conv.1.bias': 'models.83.bn58.bias',
 360        'down4.conv4.conv.1.running_mean': 'models.83.bn58.running_mean',
 361        'down4.conv4.conv.1.running_var': 'models.83.bn58.running_var',
 362        'down4.conv4.conv.1.num_batches_tracked': 'models.83.bn58.num_batches_tracked',
 363        'down4.conv5.conv.0.weight': 'models.85.conv59.weight',
 364        'down4.conv5.conv.1.weight': 'models.85.bn59.weight',
 365        'down4.conv5.conv.1.bias': 'models.85.bn59.bias',
 366        'down4.conv5.conv.1.running_mean': 'models.85.bn59.running_mean',
 367        'down4.conv5.conv.1.running_var': 'models.85.bn59.running_var',
 368        'down4.conv5.conv.1.num_batches_tracked': 'models.85.bn59.num_batches_tracked',
 369        'down5.conv1.conv.0.weight': 'models.86.conv60.weight',
 370        'down5.conv1.conv.1.weight': 'models.86.bn60.weight',
 371        'down5.conv1.conv.1.bias': 'models.86.bn60.bias',
 372        'down5.conv1.conv.1.running_mean': 'models.86.bn60.running_mean',
 373        'down5.conv1.conv.1.running_var': 'models.86.bn60.running_var',
 374        'down5.conv1.conv.1.num_batches_tracked': 'models.86.bn60.num_batches_tracked',
 375        'down5.conv2.conv.0.weight': 'models.87.conv61.weight',
 376        'down5.conv2.conv.1.weight': 'models.87.bn61.weight',
 377        'down5.conv2.conv.1.bias': 'models.87.bn61.bias',
 378        'down5.conv2.conv.1.running_mean': 'models.87.bn61.running_mean',
 379        'down5.conv2.conv.1.running_var': 'models.87.bn61.running_var',
 380        'down5.conv2.conv.1.num_batches_tracked': 'models.87.bn61.num_batches_tracked',
 381        'down5.conv3.conv.0.weight': 'models.89.conv62.weight',
 382        'down5.conv3.conv.1.weight': 'models.89.bn62.weight',
 383        'down5.conv3.conv.1.bias': 'models.89.bn62.bias',
 384        'down5.conv3.conv.1.running_mean': 'models.89.bn62.running_mean',
 385        'down5.conv3.conv.1.running_var': 'models.89.bn62.running_var',
 386        'down5.conv3.conv.1.num_batches_tracked': 'models.89.bn62.num_batches_tracked',
 387        'down5.resblock.module_list.0.0.conv.0.weight': 'models.90.conv63.weight',
 388        'down5.resblock.module_list.0.0.conv.1.weight': 'models.90.bn63.weight',
 389        'down5.resblock.module_list.0.0.conv.1.bias': 'models.90.bn63.bias',
 390        'down5.resblock.module_list.0.0.conv.1.running_mean': 'models.90.bn63.running_mean',
 391        'down5.resblock.module_list.0.0.conv.1.running_var': 'models.90.bn63.running_var',
 392        'down5.resblock.module_list.0.0.conv.1.num_batches_tracked': 'models.90.bn63.num_batches_tracked',
 393        'down5.resblock.module_list.0.1.conv.0.weight': 'models.91.conv64.weight',
 394        'down5.resblock.module_list.0.1.conv.1.weight': 'models.91.bn64.weight',
 395        'down5.resblock.module_list.0.1.conv.1.bias': 'models.91.bn64.bias',
 396        'down5.resblock.module_list.0.1.conv.1.running_mean': 'models.91.bn64.running_mean',
 397        'down5.resblock.module_list.0.1.conv.1.running_var': 'models.91.bn64.running_var',
 398        'down5.resblock.module_list.0.1.conv.1.num_batches_tracked': 'models.91.bn64.num_batches_tracked',
 399        'down5.resblock.module_list.1.0.conv.0.weight': 'models.93.conv65.weight',
 400        'down5.resblock.module_list.1.0.conv.1.weight': 'models.93.bn65.weight',
 401        'down5.resblock.module_list.1.0.conv.1.bias': 'models.93.bn65.bias',
 402        'down5.resblock.module_list.1.0.conv.1.running_mean': 'models.93.bn65.running_mean',
 403        'down5.resblock.module_list.1.0.conv.1.running_var': 'models.93.bn65.running_var',
 404        'down5.resblock.module_list.1.0.conv.1.num_batches_tracked': 'models.93.bn65.num_batches_tracked',
 405        'down5.resblock.module_list.1.1.conv.0.weight': 'models.94.conv66.weight',
 406        'down5.resblock.module_list.1.1.conv.1.weight': 'models.94.bn66.weight',
 407        'down5.resblock.module_list.1.1.conv.1.bias': 'models.94.bn66.bias',
 408        'down5.resblock.module_list.1.1.conv.1.running_mean': 'models.94.bn66.running_mean',
 409        'down5.resblock.module_list.1.1.conv.1.running_var': 'models.94.bn66.running_var',
 410        'down5.resblock.module_list.1.1.conv.1.num_batches_tracked': 'models.94.bn66.num_batches_tracked',
 411        'down5.resblock.module_list.2.0.conv.0.weight': 'models.96.conv67.weight',
 412        'down5.resblock.module_list.2.0.conv.1.weight': 'models.96.bn67.weight',
 413        'down5.resblock.module_list.2.0.conv.1.bias': 'models.96.bn67.bias',
 414        'down5.resblock.module_list.2.0.conv.1.running_mean': 'models.96.bn67.running_mean',
 415        'down5.resblock.module_list.2.0.conv.1.running_var': 'models.96.bn67.running_var',
 416        'down5.resblock.module_list.2.0.conv.1.num_batches_tracked': 'models.96.bn67.num_batches_tracked',
 417        'down5.resblock.module_list.2.1.conv.0.weight': 'models.97.conv68.weight',
 418        'down5.resblock.module_list.2.1.conv.1.weight': 'models.97.bn68.weight',
 419        'down5.resblock.module_list.2.1.conv.1.bias': 'models.97.bn68.bias',
 420        'down5.resblock.module_list.2.1.conv.1.running_mean': 'models.97.bn68.running_mean',
 421        'down5.resblock.module_list.2.1.conv.1.running_var': 'models.97.bn68.running_var',
 422        'down5.resblock.module_list.2.1.conv.1.num_batches_tracked': 'models.97.bn68.num_batches_tracked',
 423        'down5.resblock.module_list.3.0.conv.0.weight': 'models.99.conv69.weight',
 424        'down5.resblock.module_list.3.0.conv.1.weight': 'models.99.bn69.weight',
 425        'down5.resblock.module_list.3.0.conv.1.bias': 'models.99.bn69.bias',
 426        'down5.resblock.module_list.3.0.conv.1.running_mean': 'models.99.bn69.running_mean',
 427        'down5.resblock.module_list.3.0.conv.1.running_var': 'models.99.bn69.running_var',
 428        'down5.resblock.module_list.3.0.conv.1.num_batches_tracked': 'models.99.bn69.num_batches_tracked',
 429        'down5.resblock.module_list.3.1.conv.0.weight': 'models.100.conv70.weight',
 430        'down5.resblock.module_list.3.1.conv.1.weight': 'models.100.bn70.weight',
 431        'down5.resblock.module_list.3.1.conv.1.bias': 'models.100.bn70.bias',
 432        'down5.resblock.module_list.3.1.conv.1.running_mean': 'models.100.bn70.running_mean',
 433        'down5.resblock.module_list.3.1.conv.1.running_var': 'models.100.bn70.running_var',
 434        'down5.resblock.module_list.3.1.conv.1.num_batches_tracked': 'models.100.bn70.num_batches_tracked',
 435        'down5.conv4.conv.0.weight': 'models.102.conv71.weight',
 436        'down5.conv4.conv.1.weight': 'models.102.bn71.weight',
 437        'down5.conv4.conv.1.bias': 'models.102.bn71.bias',
 438        'down5.conv4.conv.1.running_mean': 'models.102.bn71.running_mean',
 439        'down5.conv4.conv.1.running_var': 'models.102.bn71.running_var',
 440        'down5.conv4.conv.1.num_batches_tracked': 'models.102.bn71.num_batches_tracked',
 441        'down5.conv5.conv.0.weight': 'models.104.conv72.weight',
 442        'down5.conv5.conv.1.weight': 'models.104.bn72.weight',
 443        'down5.conv5.conv.1.bias': 'models.104.bn72.bias',
 444        'down5.conv5.conv.1.running_mean': 'models.104.bn72.running_mean',
 445        'down5.conv5.conv.1.running_var': 'models.104.bn72.running_var',
 446        'down5.conv5.conv.1.num_batches_tracked': 'models.104.bn72.num_batches_tracked',
 447        'neek.conv1.conv.0.weight': 'models.105.conv73.weight',
 448        'neek.conv1.conv.1.weight': 'models.105.bn73.weight',
 449        'neek.conv1.conv.1.bias': 'models.105.bn73.bias',
 450        'neek.conv1.conv.1.running_mean': 'models.105.bn73.running_mean',
 451        'neek.conv1.conv.1.running_var': 'models.105.bn73.running_var',
 452        'neek.conv1.conv.1.num_batches_tracked': 'models.105.bn73.num_batches_tracked',
 453        'neek.conv2.conv.0.weight': 'models.106.conv74.weight',
 454        'neek.conv2.conv.1.weight': 'models.106.bn74.weight',
 455        'neek.conv2.conv.1.bias': 'models.106.bn74.bias',
 456        'neek.conv2.conv.1.running_mean': 'models.106.bn74.running_mean',
 457        'neek.conv2.conv.1.running_var': 'models.106.bn74.running_var',
 458        'neek.conv2.conv.1.num_batches_tracked': 'models.106.bn74.num_batches_tracked',
 459        'neek.conv3.conv.0.weight': 'models.107.conv75.weight',
 460        'neek.conv3.conv.1.weight': 'models.107.bn75.weight',
 461        'neek.conv3.conv.1.bias': 'models.107.bn75.bias',
 462        'neek.conv3.conv.1.running_mean': 'models.107.bn75.running_mean',
 463        'neek.conv3.conv.1.running_var': 'models.107.bn75.running_var',
 464        'neek.conv3.conv.1.num_batches_tracked': 'models.107.bn75.num_batches_tracked',
 465        'neek.conv4.conv.0.weight': 'models.114.conv76.weight',
 466        'neek.conv4.conv.1.weight': 'models.114.bn76.weight',
 467        'neek.conv4.conv.1.bias': 'models.114.bn76.bias',
 468        'neek.conv4.conv.1.running_mean': 'models.114.bn76.running_mean',
 469        'neek.conv4.conv.1.running_var': 'models.114.bn76.running_var',
 470        'neek.conv4.conv.1.num_batches_tracked': 'models.114.bn76.num_batches_tracked',
 471        'neek.conv5.conv.0.weight': 'models.115.conv77.weight',
 472        'neek.conv5.conv.1.weight': 'models.115.bn77.weight',
 473        'neek.conv5.conv.1.bias': 'models.115.bn77.bias',
 474        'neek.conv5.conv.1.running_mean': 'models.115.bn77.running_mean',
 475        'neek.conv5.conv.1.running_var': 'models.115.bn77.running_var',
 476        'neek.conv5.conv.1.num_batches_tracked': 'models.115.bn77.num_batches_tracked',
 477        'neek.conv6.conv.0.weight': 'models.116.conv78.weight',
 478        'neek.conv6.conv.1.weight': 'models.116.bn78.weight',
 479        'neek.conv6.conv.1.bias': 'models.116.bn78.bias',
 480        'neek.conv6.conv.1.running_mean': 'models.116.bn78.running_mean',
 481        'neek.conv6.conv.1.running_var': 'models.116.bn78.running_var',
 482        'neek.conv6.conv.1.num_batches_tracked': 'models.116.bn78.num_batches_tracked',
 483        'neek.conv7.conv.0.weight': 'models.117.conv79.weight',
 484        'neek.conv7.conv.1.weight': 'models.117.bn79.weight',
 485        'neek.conv7.conv.1.bias': 'models.117.bn79.bias',
 486        'neek.conv7.conv.1.running_mean': 'models.117.bn79.running_mean',
 487        'neek.conv7.conv.1.running_var': 'models.117.bn79.running_var',
 488        'neek.conv7.conv.1.num_batches_tracked': 'models.117.bn79.num_batches_tracked',
 489        'neek.conv8.conv.0.weight': 'models.120.conv80.weight',
 490        'neek.conv8.conv.1.weight': 'models.120.bn80.weight',
 491        'neek.conv8.conv.1.bias': 'models.120.bn80.bias',
 492        'neek.conv8.conv.1.running_mean': 'models.120.bn80.running_mean',
 493        'neek.conv8.conv.1.running_var': 'models.120.bn80.running_var',
 494        'neek.conv8.conv.1.num_batches_tracked': 'models.120.bn80.num_batches_tracked',
 495        'neek.conv9.conv.0.weight': 'models.122.conv81.weight',
 496        'neek.conv9.conv.1.weight': 'models.122.bn81.weight',
 497        'neek.conv9.conv.1.bias': 'models.122.bn81.bias',
 498        'neek.conv9.conv.1.running_mean': 'models.122.bn81.running_mean',
 499        'neek.conv9.conv.1.running_var': 'models.122.bn81.running_var',
 500        'neek.conv9.conv.1.num_batches_tracked': 'models.122.bn81.num_batches_tracked',
 501        'neek.conv10.conv.0.weight': 'models.123.conv82.weight',
 502        'neek.conv10.conv.1.weight': 'models.123.bn82.weight',
 503        'neek.conv10.conv.1.bias': 'models.123.bn82.bias',
 504        'neek.conv10.conv.1.running_mean': 'models.123.bn82.running_mean',
 505        'neek.conv10.conv.1.running_var': 'models.123.bn82.running_var',
 506        'neek.conv10.conv.1.num_batches_tracked': 'models.123.bn82.num_batches_tracked',
 507        'neek.conv11.conv.0.weight': 'models.124.conv83.weight',
 508        'neek.conv11.conv.1.weight': 'models.124.bn83.weight',
 509        'neek.conv11.conv.1.bias': 'models.124.bn83.bias',
 510        'neek.conv11.conv.1.running_mean': 'models.124.bn83.running_mean',
 511        'neek.conv11.conv.1.running_var': 'models.124.bn83.running_var',
 512        'neek.conv11.conv.1.num_batches_tracked': 'models.124.bn83.num_batches_tracked',
 513        'neek.conv12.conv.0.weight': 'models.125.conv84.weight',
 514        'neek.conv12.conv.1.weight': 'models.125.bn84.weight',
 515        'neek.conv12.conv.1.bias': 'models.125.bn84.bias',
 516        'neek.conv12.conv.1.running_mean': 'models.125.bn84.running_mean',
 517        'neek.conv12.conv.1.running_var': 'models.125.bn84.running_var',
 518        'neek.conv12.conv.1.num_batches_tracked': 'models.125.bn84.num_batches_tracked',
 519        'neek.conv13.conv.0.weight': 'models.126.conv85.weight',
 520        'neek.conv13.conv.1.weight': 'models.126.bn85.weight',
 521        'neek.conv13.conv.1.bias': 'models.126.bn85.bias',
 522        'neek.conv13.conv.1.running_mean': 'models.126.bn85.running_mean',
 523        'neek.conv13.conv.1.running_var': 'models.126.bn85.running_var',
 524        'neek.conv13.conv.1.num_batches_tracked': 'models.126.bn85.num_batches_tracked',
 525        'neek.conv14.conv.0.weight': 'models.127.conv86.weight',
 526        'neek.conv14.conv.1.weight': 'models.127.bn86.weight',
 527        'neek.conv14.conv.1.bias': 'models.127.bn86.bias',
 528        'neek.conv14.conv.1.running_mean': 'models.127.bn86.running_mean',
 529        'neek.conv14.conv.1.running_var': 'models.127.bn86.running_var',
 530        'neek.conv14.conv.1.num_batches_tracked': 'models.127.bn86.num_batches_tracked',
 531        'neek.conv15.conv.0.weight': 'models.130.conv87.weight',
 532        'neek.conv15.conv.1.weight': 'models.130.bn87.weight',
 533        'neek.conv15.conv.1.bias': 'models.130.bn87.bias',
 534        'neek.conv15.conv.1.running_mean': 'models.130.bn87.running_mean',
 535        'neek.conv15.conv.1.running_var': 'models.130.bn87.running_var',
 536        'neek.conv15.conv.1.num_batches_tracked': 'models.130.bn87.num_batches_tracked',
 537        'neek.conv16.conv.0.weight': 'models.132.conv88.weight',
 538        'neek.conv16.conv.1.weight': 'models.132.bn88.weight',
 539        'neek.conv16.conv.1.bias': 'models.132.bn88.bias',
 540        'neek.conv16.conv.1.running_mean': 'models.132.bn88.running_mean',
 541        'neek.conv16.conv.1.running_var': 'models.132.bn88.running_var',
 542        'neek.conv16.conv.1.num_batches_tracked': 'models.132.bn88.num_batches_tracked',
 543        'neek.conv17.conv.0.weight': 'models.133.conv89.weight',
 544        'neek.conv17.conv.1.weight': 'models.133.bn89.weight',
 545        'neek.conv17.conv.1.bias': 'models.133.bn89.bias',
 546        'neek.conv17.conv.1.running_mean': 'models.133.bn89.running_mean',
 547        'neek.conv17.conv.1.running_var': 'models.133.bn89.running_var',
 548        'neek.conv17.conv.1.num_batches_tracked': 'models.133.bn89.num_batches_tracked',
 549        'neek.conv18.conv.0.weight': 'models.134.conv90.weight',
 550        'neek.conv18.conv.1.weight': 'models.134.bn90.weight',
 551        'neek.conv18.conv.1.bias': 'models.134.bn90.bias',
 552        'neek.conv18.conv.1.running_mean': 'models.134.bn90.running_mean',
 553        'neek.conv18.conv.1.running_var': 'models.134.bn90.running_var',
 554        'neek.conv18.conv.1.num_batches_tracked': 'models.134.bn90.num_batches_tracked',
 555        'neek.conv19.conv.0.weight': 'models.135.conv91.weight',
 556        'neek.conv19.conv.1.weight': 'models.135.bn91.weight',
 557        'neek.conv19.conv.1.bias': 'models.135.bn91.bias',
 558        'neek.conv19.conv.1.running_mean': 'models.135.bn91.running_mean',
 559        'neek.conv19.conv.1.running_var': 'models.135.bn91.running_var',
 560        'neek.conv19.conv.1.num_batches_tracked': 'models.135.bn91.num_batches_tracked',
 561        'neek.conv20.conv.0.weight': 'models.136.conv92.weight',
 562        'neek.conv20.conv.1.weight': 'models.136.bn92.weight',
 563        'neek.conv20.conv.1.bias': 'models.136.bn92.bias',
 564        'neek.conv20.conv.1.running_mean': 'models.136.bn92.running_mean',
 565        'neek.conv20.conv.1.running_var': 'models.136.bn92.running_var',
 566        'neek.conv20.conv.1.num_batches_tracked': 'models.136.bn92.num_batches_tracked',
 567        'head.conv1.conv.0.weight': 'models.137.conv93.weight',
 568        'head.conv1.conv.1.weight': 'models.137.bn93.weight',
 569        'head.conv1.conv.1.bias': 'models.137.bn93.bias',
 570        'head.conv1.conv.1.running_mean': 'models.137.bn93.running_mean',
 571        'head.conv1.conv.1.running_var': 'models.137.bn93.running_var',
 572        'head.conv1.conv.1.num_batches_tracked': 'models.137.bn93.num_batches_tracked',
 573        'head.conv2.conv.0.weight': 'models.138.conv94.weight',
 574        'head.conv2.conv.0.bias': 'models.138.conv94.bias',
 575        'head.conv3.conv.0.weight': 'models.141.conv95.weight',
 576        'head.conv3.conv.1.weight': 'models.141.bn95.weight',
 577        'head.conv3.conv.1.bias': 'models.141.bn95.bias',
 578        'head.conv3.conv.1.running_mean': 'models.141.bn95.running_mean',
 579        'head.conv3.conv.1.running_var': 'models.141.bn95.running_var',
 580        'head.conv3.conv.1.num_batches_tracked': 'models.141.bn95.num_batches_tracked',
 581        'head.conv4.conv.0.weight': 'models.143.conv96.weight',
 582        'head.conv4.conv.1.weight': 'models.143.bn96.weight',
 583        'head.conv4.conv.1.bias': 'models.143.bn96.bias',
 584        'head.conv4.conv.1.running_mean': 'models.143.bn96.running_mean',
 585        'head.conv4.conv.1.running_var': 'models.143.bn96.running_var',
 586        'head.conv4.conv.1.num_batches_tracked': 'models.143.bn96.num_batches_tracked',
 587        'head.conv5.conv.0.weight': 'models.144.conv97.weight',
 588        'head.conv5.conv.1.weight': 'models.144.bn97.weight',
 589        'head.conv5.conv.1.bias': 'models.144.bn97.bias',
 590        'head.conv5.conv.1.running_mean': 'models.144.bn97.running_mean',
 591        'head.conv5.conv.1.running_var': 'models.144.bn97.running_var',
 592        'head.conv5.conv.1.num_batches_tracked': 'models.144.bn97.num_batches_tracked',
 593        'head.conv6.conv.0.weight': 'models.145.conv98.weight',
 594        'head.conv6.conv.1.weight': 'models.145.bn98.weight',
 595        'head.conv6.conv.1.bias': 'models.145.bn98.bias',
 596        'head.conv6.conv.1.running_mean': 'models.145.bn98.running_mean',
 597        'head.conv6.conv.1.running_var': 'models.145.bn98.running_var',
 598        'head.conv6.conv.1.num_batches_tracked': 'models.145.bn98.num_batches_tracked',
 599        'head.conv7.conv.0.weight': 'models.146.conv99.weight',
 600        'head.conv7.conv.1.weight': 'models.146.bn99.weight',
 601        'head.conv7.conv.1.bias': 'models.146.bn99.bias',
 602        'head.conv7.conv.1.running_mean': 'models.146.bn99.running_mean',
 603        'head.conv7.conv.1.running_var': 'models.146.bn99.running_var',
 604        'head.conv7.conv.1.num_batches_tracked': 'models.146.bn99.num_batches_tracked',
 605        'head.conv8.conv.0.weight': 'models.147.conv100.weight',
 606        'head.conv8.conv.1.weight': 'models.147.bn100.weight',
 607        'head.conv8.conv.1.bias': 'models.147.bn100.bias',
 608        'head.conv8.conv.1.running_mean': 'models.147.bn100.running_mean',
 609        'head.conv8.conv.1.running_var': 'models.147.bn100.running_var',
 610        'head.conv8.conv.1.num_batches_tracked': 'models.147.bn100.num_batches_tracked',
 611        'head.conv9.conv.0.weight': 'models.148.conv101.weight',
 612        'head.conv9.conv.1.weight': 'models.148.bn101.weight',
 613        'head.conv9.conv.1.bias': 'models.148.bn101.bias',
 614        'head.conv9.conv.1.running_mean': 'models.148.bn101.running_mean',
 615        'head.conv9.conv.1.running_var': 'models.148.bn101.running_var',
 616        'head.conv9.conv.1.num_batches_tracked': 'models.148.bn101.num_batches_tracked',
 617        'head.conv10.conv.0.weight': 'models.149.conv102.weight',
 618        'head.conv10.conv.0.bias': 'models.149.conv102.bias',
 619        'head.conv11.conv.0.weight': 'models.152.conv103.weight',
 620        'head.conv11.conv.1.weight': 'models.152.bn103.weight',
 621        'head.conv11.conv.1.bias': 'models.152.bn103.bias',
 622        'head.conv11.conv.1.running_mean': 'models.152.bn103.running_mean',
 623        'head.conv11.conv.1.running_var': 'models.152.bn103.running_var',
 624        'head.conv11.conv.1.num_batches_tracked': 'models.152.bn103.num_batches_tracked',
 625        'head.conv12.conv.0.weight': 'models.154.conv104.weight',
 626        'head.conv12.conv.1.weight': 'models.154.bn104.weight',
 627        'head.conv12.conv.1.bias': 'models.154.bn104.bias',
 628        'head.conv12.conv.1.running_mean': 'models.154.bn104.running_mean',
 629        'head.conv12.conv.1.running_var': 'models.154.bn104.running_var',
 630        'head.conv12.conv.1.num_batches_tracked': 'models.154.bn104.num_batches_tracked',
 631        'head.conv13.conv.0.weight': 'models.155.conv105.weight',
 632        'head.conv13.conv.1.weight': 'models.155.bn105.weight',
 633        'head.conv13.conv.1.bias': 'models.155.bn105.bias',
 634        'head.conv13.conv.1.running_mean': 'models.155.bn105.running_mean',
 635        'head.conv13.conv.1.running_var': 'models.155.bn105.running_var',
 636        'head.conv13.conv.1.num_batches_tracked': 'models.155.bn105.num_batches_tracked',
 637        'head.conv14.conv.0.weight': 'models.156.conv106.weight',
 638        'head.conv14.conv.1.weight': 'models.156.bn106.weight',
 639        'head.conv14.conv.1.bias': 'models.156.bn106.bias',
 640        'head.conv14.conv.1.running_mean': 'models.156.bn106.running_mean',
 641        'head.conv14.conv.1.running_var': 'models.156.bn106.running_var',
 642        'head.conv14.conv.1.num_batches_tracked': 'models.156.bn106.num_batches_tracked',
 643        'head.conv15.conv.0.weight': 'models.157.conv107.weight',
 644        'head.conv15.conv.1.weight': 'models.157.bn107.weight',
 645        'head.conv15.conv.1.bias': 'models.157.bn107.bias',
 646        'head.conv15.conv.1.running_mean': 'models.157.bn107.running_mean',
 647        'head.conv15.conv.1.running_var': 'models.157.bn107.running_var',
 648        'head.conv15.conv.1.num_batches_tracked': 'models.157.bn107.num_batches_tracked',
 649        'head.conv16.conv.0.weight': 'models.158.conv108.weight',
 650        'head.conv16.conv.1.weight': 'models.158.bn108.weight',
 651        'head.conv16.conv.1.bias': 'models.158.bn108.bias',
 652        'head.conv16.conv.1.running_mean': 'models.158.bn108.running_mean',
 653        'head.conv16.conv.1.running_var': 'models.158.bn108.running_var',
 654        'head.conv16.conv.1.num_batches_tracked': 'models.158.bn108.num_batches_tracked',
 655        'head.conv17.conv.0.weight': 'models.159.conv109.weight',
 656        'head.conv17.conv.1.weight': 'models.159.bn109.weight',
 657        'head.conv17.conv.1.bias': 'models.159.bn109.bias',
 658        'head.conv17.conv.1.running_mean': 'models.159.bn109.running_mean',
 659        'head.conv17.conv.1.running_var': 'models.159.bn109.running_var',
 660        'head.conv17.conv.1.num_batches_tracked': 'models.159.bn109.num_batches_tracked',
 661        'head.conv18.conv.0.weight': 'models.160.conv110.weight',
 662        'head.conv18.conv.0.bias': 'models.160.conv110.bias',
 663    }
 664    pth_weights = torch.load(checkpoint)
 665    pt_weights = type(pth_weights)()
 666    for name, new_name in name_mapping.items():
 667        pt_weights[new_name] = pth_weights[name]
 668    return pt_weights
 669
 670
 671def convert_pt_checkpoint_to_keras_h5(state_dict):
 672    print('============================================================')
 673
 674    def copy1(conv, bn, idx):
 675        keyword1 = 'conv%d.weight' % idx
 676        keyword2 = 'bn%d.weight' % idx
 677        keyword3 = 'bn%d.bias' % idx
 678        keyword4 = 'bn%d.running_mean' % idx
 679        keyword5 = 'bn%d.running_var' % idx
 680        for key in state_dict:
 681            value = state_dict[key].numpy()
 682            if keyword1 in key:
 683                w = value
 684            elif keyword2 in key:
 685                y = value
 686            elif keyword3 in key:
 687                b = value
 688            elif keyword4 in key:
 689                m = value
 690            elif keyword5 in key:
 691                v = value
 692        w = w.transpose(2, 3, 1, 0)
 693        conv.set_weights([w])
 694        bn.set_weights([y, b, m, v])
 695
 696    def copy2(conv, idx):
 697        keyword1 = 'conv%d.weight' % idx
 698        keyword2 = 'conv%d.bias' % idx
 699        for key in state_dict:
 700            value = state_dict[key].numpy()
 701            if keyword1 in key:
 702                w = value
 703            elif keyword2 in key:
 704                b = value
 705        w = w.transpose(2, 3, 1, 0)
 706        conv.set_weights([w, b])
 707
 708    num_classes = 80
 709    num_anchors = 3
 710
 711    with tf.Session(graph=tf.Graph()):
 712        inputs = layers.Input(shape=[], dtype='string')
 713        model_body = YOLOv4(inputs, num_classes, num_anchors)
 714        model_body.summary()
 715        layer_name_to_idx = {layer.name: idx for idx, layer in enumerate(model_body.layers)}
 716
 717        print('\nCopying...')
 718        i1 = layer_name_to_idx['conv2d']
 719        i2 = layer_name_to_idx['batch_normalization']
 720        copy1(model_body.layers[i1], model_body.layers[i2], 1)
 721        for i in range(2, 94, 1):
 722            i1 = layer_name_to_idx['conv2d_%d' % (i - 1)]
 723            i2 = layer_name_to_idx['batch_normalization_%d' % (i - 1)]
 724            copy1(model_body.layers[i1], model_body.layers[i2], i)
 725        for i in range(95, 102, 1):
 726            i1 = layer_name_to_idx['conv2d_%d' % (i - 1)]
 727            i2 = layer_name_to_idx['batch_normalization_%d' % (i - 2,)]
 728            copy1(model_body.layers[i1], model_body.layers[i2], i)
 729        for i in range(103, 110, 1):
 730            i1 = layer_name_to_idx['conv2d_%d' % (i - 1)]
 731            i2 = layer_name_to_idx['batch_normalization_%d' % (i - 3,)]
 732            copy1(model_body.layers[i1], model_body.layers[i2], i)
 733
 734        i1 = layer_name_to_idx['conv2d_93']
 735        copy2(model_body.layers[i1], 94)
 736        i1 = layer_name_to_idx['conv2d_101']
 737        copy2(model_body.layers[i1], 102)
 738        i1 = layer_name_to_idx['conv2d_109']
 739        copy2(model_body.layers[i1], 110)
 740
 741        weights = model_body.get_weights()
 742    print('\nDone.')
 743    return weights
 744
 745
 746class Mish(layers.Layer):
 747
 748    def __init__(self):
 749        super(Mish, self).__init__()
 750
 751    def compute_output_shape(self, input_shape):
 752        return input_shape
 753
 754    def call(self, x):
 755        return x * tf.tanh(tf.math.softplus(x))
 756
 757
 758def conv2d_unit(x, filters, kernels, strides=1, padding='valid', bn=1, act='mish'):
 759    use_bias = (bn != 1)
 760    x = layers.Conv2D(filters, kernels,
 761                      padding=padding,
 762                      strides=strides,
 763                      use_bias=use_bias,
 764                      activation='linear',
 765                      kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.01))(x)
 766    if bn:
 767        x = layers.BatchNormalization(fused=False)(x)
 768    if act == 'leaky':
 769        x = keras.layers.LeakyReLU(alpha=0.1)(x)
 770    elif act == 'mish':
 771        x = Mish()(x)
 772    return x
 773
 774
 775def residual_block(inputs, filters_1, filters_2):
 776    x = conv2d_unit(inputs, filters_1, 1, strides=1, padding='valid')
 777    x = conv2d_unit(x, filters_2, 3, strides=1, padding='same')
 778    x = layers.add([inputs, x])
 779    return x
 780
 781
 782def stack_residual_block(inputs, filters_1, filters_2, n):
 783    x = residual_block(inputs, filters_1, filters_2)
 784    for i in range(n - 1):
 785        x = residual_block(x, filters_1, filters_2)
 786    return x
 787
 788
 789def spp(x):
 790    x_1 = x
 791    x_2 = layers.MaxPooling2D(pool_size=5, strides=1, padding='same')(x)
 792    x_3 = layers.MaxPooling2D(pool_size=9, strides=1, padding='same')(x)
 793    x_4 = layers.MaxPooling2D(pool_size=13, strides=1, padding='same')(x)
 794    out = layers.Concatenate()([x_4, x_3, x_2, x_1])
 795    return out
 796
 797
 798def YOLOv4(inputs, num_classes, num_anchors, input_shape=(608, 608), initial_filters=32,
 799           fast=False, anchors=None, conf_thresh=0.05, nms_thresh=0.45, keep_top_k=100, nms_top_k=100):
 800    i32 = initial_filters
 801    i64 = i32 * 2
 802    i128 = i32 * 4
 803    i256 = i32 * 8
 804    i512 = i32 * 16
 805    i1024 = i32 * 32
 806
 807    x, image_shape = layers.Lambda(lambda t: preprocessor(t, input_shape))(inputs)
 808
 809    # cspdarknet53
 810    x = conv2d_unit(x, i32, 3, strides=1, padding='same')
 811
 812    # ============================= s2 =============================
 813    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(x)
 814    x = conv2d_unit(x, i64, 3, strides=2)
 815    s2 = conv2d_unit(x, i64, 1, strides=1)
 816    x = conv2d_unit(x, i64, 1, strides=1)
 817    x = stack_residual_block(x, i32, i64, n=1)
 818    x = conv2d_unit(x, i64, 1, strides=1)
 819    x = layers.Concatenate()([x, s2])
 820    s2 = conv2d_unit(x, i64, 1, strides=1)
 821
 822    # ============================= s4 =============================
 823    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(s2)
 824    x = conv2d_unit(x, i128, 3, strides=2)
 825    s4 = conv2d_unit(x, i64, 1, strides=1)
 826    x = conv2d_unit(x, i64, 1, strides=1)
 827    x = stack_residual_block(x, i64, i64, n=2)
 828    x = conv2d_unit(x, i64, 1, strides=1)
 829    x = layers.Concatenate()([x, s4])
 830    s4 = conv2d_unit(x, i128, 1, strides=1)
 831
 832    # ============================= s8 =============================
 833    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(s4)
 834    x = conv2d_unit(x, i256, 3, strides=2)
 835    s8 = conv2d_unit(x, i128, 1, strides=1)
 836    x = conv2d_unit(x, i128, 1, strides=1)
 837    x = stack_residual_block(x, i128, i128, n=8)
 838    x = conv2d_unit(x, i128, 1, strides=1)
 839    x = layers.Concatenate()([x, s8])
 840    s8 = conv2d_unit(x, i256, 1, strides=1)
 841
 842    # ============================= s16 =============================
 843    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(s8)
 844    x = conv2d_unit(x, i512, 3, strides=2)
 845    s16 = conv2d_unit(x, i256, 1, strides=1)
 846    x = conv2d_unit(x, i256, 1, strides=1)
 847    x = stack_residual_block(x, i256, i256, n=8)
 848    x = conv2d_unit(x, i256, 1, strides=1)
 849    x = layers.Concatenate()([x, s16])
 850    s16 = conv2d_unit(x, i512, 1, strides=1)
 851
 852    # ============================= s32 =============================
 853    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(s16)
 854    x = conv2d_unit(x, i1024, 3, strides=2)
 855    s32 = conv2d_unit(x, i512, 1, strides=1)
 856    x = conv2d_unit(x, i512, 1, strides=1)
 857    x = stack_residual_block(x, i512, i512, n=4)
 858    x = conv2d_unit(x, i512, 1, strides=1)
 859    x = layers.Concatenate()([x, s32])
 860    s32 = conv2d_unit(x, i1024, 1, strides=1)
 861
 862    # fpn
 863    x = conv2d_unit(s32, i512, 1, strides=1, act='leaky')
 864    x = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
 865    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
 866    x = spp(x)
 867
 868    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
 869    x = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
 870    fpn_s32 = conv2d_unit(x, i512, 1, strides=1, act='leaky')
 871
 872    # pan01
 873    x = conv2d_unit(fpn_s32, i256, 1, strides=1, act='leaky')
 874    x = layers.UpSampling2D(2)(x)
 875    s16 = conv2d_unit(s16, i256, 1, strides=1, act='leaky')
 876    x = layers.Concatenate()([s16, x])
 877    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
 878    x = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
 879    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
 880    x = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
 881    fpn_s16 = conv2d_unit(x, i256, 1, strides=1, act='leaky')
 882
 883    # pan02
 884    x = conv2d_unit(fpn_s16, i128, 1, strides=1, act='leaky')
 885    x = layers.UpSampling2D(2)(x)
 886    s8 = conv2d_unit(s8, i128, 1, strides=1, act='leaky')
 887    x = layers.Concatenate()([s8, x])
 888    x = conv2d_unit(x, i128, 1, strides=1, act='leaky')
 889    x = conv2d_unit(x, i256, 3, strides=1, padding='same', act='leaky')
 890    x = conv2d_unit(x, i128, 1, strides=1, act='leaky')
 891    x = conv2d_unit(x, i256, 3, strides=1, padding='same', act='leaky')
 892    x = conv2d_unit(x, i128, 1, strides=1, act='leaky')
 893
 894    # output_s, doesn't need concat()
 895    output_s = conv2d_unit(x, i256, 3, strides=1, padding='same', act='leaky')
 896    output_s = conv2d_unit(output_s, num_anchors * (num_classes + 5), 1, strides=1, bn=0, act=None)
 897
 898    # output_m, need concat()
 899    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(x)
 900    x = conv2d_unit(x, i256, 3, strides=2, act='leaky')
 901    x = layers.Concatenate()([x, fpn_s16])
 902    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
 903    x = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
 904    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
 905    x = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
 906    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
 907    output_m = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
 908    output_m = conv2d_unit(output_m, num_anchors * (num_classes + 5), 1, strides=1, bn=0, act=None)
 909
 910    # output_l, need concat()
 911    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(x)
 912    x = conv2d_unit(x, i512, 3, strides=2, act='leaky')
 913    x = layers.Concatenate()([x, fpn_s32])
 914    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
 915    x = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
 916    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
 917    x = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
 918    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
 919    output_l = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
 920    output_l = conv2d_unit(output_l, num_anchors * (num_classes + 5), 1, strides=1, bn=0, act=None)
 921
 922    def cast_float32(tensor):
 923        return tf.cast(tensor, tf.float32)
 924
 925    output_l = layers.Lambda(cast_float32)(output_l)
 926    output_m = layers.Lambda(cast_float32)(output_m)
 927    output_s = layers.Lambda(cast_float32)(output_s)
 928
 929    # originally reshape in multi_thread_post
 930    output_lr = layers.Reshape((1, input_shape[0] // 32, input_shape[1] // 32, 3, 5 + num_classes))(output_l)
 931    output_mr = layers.Reshape((1, input_shape[0] // 16, input_shape[1] // 16, 3, 5 + num_classes))(output_m)
 932    output_sr = layers.Reshape((1, input_shape[0] // 8, input_shape[1] // 8, 3, 5 + num_classes))(output_s)
 933
 934    # originally _yolo_out
 935    masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
 936    anchors = [[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
 937               [72, 146], [142, 110], [192, 243], [459, 401]]
 938
 939    def batch_process_feats(out, anchors, mask):
 940        grid_h, grid_w, num_boxes = map(int, out.shape[2:5])
 941
 942        anchors = [anchors[i] for i in mask]
 943        anchors_tensor = np.array(anchors).reshape(1, 1, len(anchors), 2)
 944
 945        # Reshape to batch, height, width, num_anchors, box_params.
 946        box_xy = tf.sigmoid(out[..., :2])
 947        box_wh = tf.exp(out[..., 2:4])
 948        box_wh = box_wh * anchors_tensor
 949
 950        box_confidence = tf.sigmoid(out[..., 4])
 951        box_confidence = tf.expand_dims(box_confidence, axis=-1)
 952        box_class_probs = tf.sigmoid(out[..., 5:])
 953
 954        col = np.tile(np.arange(0, grid_w), grid_w).reshape(-1, grid_w)
 955        row = np.tile(np.arange(0, grid_h).reshape(-1, 1), grid_h)
 956
 957        col = col.reshape(grid_h, grid_w, 1, 1).repeat(3, axis=-2)
 958        row = row.reshape(grid_h, grid_w, 1, 1).repeat(3, axis=-2)
 959        grid = np.concatenate((col, row), axis=-1).astype(np.float32)
 960
 961        box_xy += grid
 962        box_xy /= (grid_w, grid_h)
 963        box_wh /= input_shape
 964        box_xy -= (box_wh / 2.)  # normalized xywh
 965        boxes = tf.concat((box_xy, box_xy + box_wh), axis=-1)
 966
 967        box_scores = box_confidence * box_class_probs
 968        num_boxes = np.prod(boxes.shape[1:-1])
 969        boxes = tf.reshape(boxes, [-1, num_boxes, boxes.shape[-1]])
 970        box_scores = tf.reshape(box_scores, [-1, num_boxes, box_scores.shape[-1]])
 971        return boxes, box_scores
 972
 973    def filter_boxes(outputs):
 974        boxes_l, boxes_m, boxes_s, box_scores_l, box_scores_m, box_scores_s, image_shape = outputs
 975        boxes_l, box_scores_l = filter_boxes_one_size(boxes_l, box_scores_l)
 976        boxes_m, box_scores_m = filter_boxes_one_size(boxes_m, box_scores_m)
 977        boxes_s, box_scores_s = filter_boxes_one_size(boxes_s, box_scores_s)
 978        boxes = tf.concat([boxes_l, boxes_m, boxes_s], axis=0)
 979        box_scores = tf.concat([box_scores_l, box_scores_m, box_scores_s], axis=0)
 980        image_shape_wh = image_shape[1::-1]
 981        image_shape_whwh = tf.concat([image_shape_wh, image_shape_wh], axis=-1)
 982        image_shape_whwh = tf.cast(image_shape_whwh, tf.float32)
 983        boxes *= image_shape_whwh
 984        boxes = tf.expand_dims(boxes, 0)
 985        box_scores = tf.expand_dims(box_scores, 0)
 986        boxes = tf.expand_dims(boxes, 2)
 987        nms_boxes, nms_scores, nms_classes, valid_detections = tf.image.combined_non_max_suppression(
 988            boxes,
 989            box_scores,
 990            max_output_size_per_class=nms_top_k,
 991            max_total_size=nms_top_k,
 992            iou_threshold=nms_thresh,
 993            score_threshold=conf_thresh,
 994            pad_per_class=False,
 995            clip_boxes=False,
 996            name='CombinedNonMaxSuppression',
 997        )
 998        return nms_boxes[0], nms_scores[0], nms_classes[0]
 999
1000    def filter_boxes_one_size(boxes, box_scores):
1001        box_class_scores = tf.reduce_max(box_scores, axis=-1)
1002        keep = box_class_scores > conf_thresh
1003        boxes = boxes[keep]
1004        box_scores = box_scores[keep]
1005        return boxes, box_scores
1006
1007    def batch_yolo_out(outputs):
1008        with tf.name_scope('yolo_out'):
1009            b_output_lr, b_output_mr, b_output_sr, b_image_shape = outputs
1010            with tf.name_scope('process_feats'):
1011                b_boxes_l, b_box_scores_l = batch_process_feats(b_output_lr, anchors, masks[0])
1012            with tf.name_scope('process_feats'):
1013                b_boxes_m, b_box_scores_m = batch_process_feats(b_output_mr, anchors, masks[1])
1014            with tf.name_scope('process_feats'):
1015                b_boxes_s, b_box_scores_s = batch_process_feats(b_output_sr, anchors, masks[2])
1016            with tf.name_scope('filter_boxes'):
1017                b_nms_boxes, b_nms_scores, b_nms_classes = tf.map_fn(
1018                    filter_boxes, [b_boxes_l, b_boxes_m, b_boxes_s, b_box_scores_l, b_box_scores_m, b_box_scores_s, b_image_shape],
1019                    dtype=(tf.float32, tf.float32, tf.float32), back_prop=False, parallel_iterations=16)
1020        return b_nms_boxes, b_nms_scores, b_nms_classes
1021
1022    boxes_scores_classes = layers.Lambda(batch_yolo_out)([output_lr, output_mr, output_sr, image_shape])
1023
1024    model_body = keras.models.Model(inputs=inputs, outputs=boxes_scores_classes)
1025    return model_body
1026
1027
def decode_jpeg_resize(input_tensor, image_size):
    """Decode one encoded image, resize it, and scale it to [0, 1].

    ``tf.image.decode_png`` is used despite the function's name. According to
    the TensorFlow documentation, the PNG decoder also accepts JPEG (and
    non-animated GIF) data, so either encoding works here.

    Args:
        input_tensor: scalar string tensor holding the encoded image bytes.
        image_size: target size handed to ``tf.image.resize``
            (``[new_height, new_width]``).

    Returns:
        Tuple ``(image, shape)``:
        - ``image`` is the resized 3-channel image divided by 255 and cast to
          float16.
        - ``shape`` is the int32 ``tf.shape`` of the decoded image, taken
          before resizing.
    """
    decoded = tf.image.decode_png(input_tensor, channels=3)
    # Capture the pre-resize shape; it is returned alongside the image.
    original_shape = tf.shape(decoded)
    resized = tf.image.resize(tf.cast(decoded, tf.float32), image_size)
    scaled = resized / 255.0
    return tf.cast(scaled, tf.float16), original_shape
1035
1036
def preprocessor(input_tensor, image_size):
    """Decode and resize every encoded image in a batch.

    Args:
        input_tensor: string tensor of encoded images. ``tf.map_fn`` iterates
            over its leading axis, applying ``decode_jpeg_resize`` to each
            element.
        image_size: target size forwarded to ``decode_jpeg_resize``.

    Returns:
        The ``tf.map_fn`` result, structured as ``(images, shapes)``:
        - ``images`` has dtype float16.
        - ``shapes`` has dtype int32 and holds each image's pre-resize shape.
    """
    decode_one = partial(decode_jpeg_resize, image_size=image_size)
    output_dtypes = (tf.float16, tf.int32)
    with tf.name_scope('Preprocessor'):
        # back_prop=False: no gradients are needed through decode/resize.
        batch = tf.map_fn(decode_one, input_tensor, dtype=output_dtypes,
                          back_prop=False, parallel_iterations=16)
    return batch
1043
1044
def main():
    """Download the PyTorch YOLOv4 checkpoint, convert it, and export a SavedModel.

    Steps:
      1. Fetch ``yolov4.pth`` from the public Neuron S3 bucket with the AWS
         CLI. This is skipped when ``./yolov4.pth`` already exists.
      2. Rename the PyTorch parameters and convert them to Keras weights.
      3. Build the Keras YOLOv4 graph on a string (encoded image) input and
         load the converted weights.
      4. Write the graph to ``./yolo_v4_coco_saved_model``.

    Raises:
        subprocess.CalledProcessError: if the AWS CLI download exits non-zero.
        FileNotFoundError: if the checkpoint is absent and the ``aws``
            executable cannot be found.
    """
    import subprocess

    weights_path = './yolov4.pth'
    if not os.path.exists(weights_path):
        # The download runs without a shell and fails loudly. The previous
        # os.system() call ignored the exit status, so a failed download only
        # surfaced later as an error while loading the checkpoint.
        subprocess.run(
            ['aws', 's3', 'cp',
             's3://neuron-s3/training_checkpoints/pytorch/yolov4/yolov4.pth',
             '.', '--no-sign-request'],
            check=True,
        )
    torch_weights = rename_weights(weights_path)
    keras_weights = convert_pt_checkpoint_to_keras_h5(torch_weights)

    # Fix Keras to inference mode (learning phase 0) before the model is built.
    keras.backend.set_learning_phase(0)
    num_anchors = 3
    num_classes = 80
    input_shape = (608, 608)
    conf_thresh = 0.001
    nms_thresh = 0.45
    image_input = layers.Input(shape=[], dtype='string')
    yolo = YOLOv4(image_input, num_classes, num_anchors, input_shape,
                  conf_thresh=conf_thresh, nms_thresh=nms_thresh)
    yolo.set_weights(keras_weights)

    sess = keras.backend.get_session()
    signature_inputs = {'image': yolo.inputs[0]}
    output_names = ['boxes', 'scores', 'classes']
    signature_outputs = dict(zip(output_names, yolo.outputs))
    # NOTE: simple_save refuses to export into an existing non-empty
    # directory; remove ./yolo_v4_coco_saved_model before re-running.
    tf.saved_model.simple_save(sess, './yolo_v4_coco_saved_model',
                               signature_inputs, signature_outputs)


if __name__ == '__main__':
    main()

This document is relevant for: Inf1