This document is relevant for: Inf1
yolo_v4_coco_saved_model.py
1import os
2import io
3from functools import partial
4import requests
5import numpy as np
6import torch
7import tensorflow as tf
8from tensorflow import keras
9from tensorflow.keras import layers
10
11
12
13def rename_weights(checkpoint):
14 name_mapping = {
15 'down1.conv1.conv.0.weight': 'models.0.conv1.weight',
16 'down1.conv1.conv.1.weight': 'models.0.bn1.weight',
17 'down1.conv1.conv.1.bias': 'models.0.bn1.bias',
18 'down1.conv1.conv.1.running_mean': 'models.0.bn1.running_mean',
19 'down1.conv1.conv.1.running_var': 'models.0.bn1.running_var',
20 'down1.conv1.conv.1.num_batches_tracked': 'models.0.bn1.num_batches_tracked',
21 'down1.conv2.conv.0.weight': 'models.1.conv2.weight',
22 'down1.conv2.conv.1.weight': 'models.1.bn2.weight',
23 'down1.conv2.conv.1.bias': 'models.1.bn2.bias',
24 'down1.conv2.conv.1.running_mean': 'models.1.bn2.running_mean',
25 'down1.conv2.conv.1.running_var': 'models.1.bn2.running_var',
26 'down1.conv2.conv.1.num_batches_tracked': 'models.1.bn2.num_batches_tracked',
27 'down1.conv3.conv.0.weight': 'models.2.conv3.weight',
28 'down1.conv3.conv.1.weight': 'models.2.bn3.weight',
29 'down1.conv3.conv.1.bias': 'models.2.bn3.bias',
30 'down1.conv3.conv.1.running_mean': 'models.2.bn3.running_mean',
31 'down1.conv3.conv.1.running_var': 'models.2.bn3.running_var',
32 'down1.conv3.conv.1.num_batches_tracked': 'models.2.bn3.num_batches_tracked',
33 'down1.conv4.conv.0.weight': 'models.4.conv4.weight',
34 'down1.conv4.conv.1.weight': 'models.4.bn4.weight',
35 'down1.conv4.conv.1.bias': 'models.4.bn4.bias',
36 'down1.conv4.conv.1.running_mean': 'models.4.bn4.running_mean',
37 'down1.conv4.conv.1.running_var': 'models.4.bn4.running_var',
38 'down1.conv4.conv.1.num_batches_tracked': 'models.4.bn4.num_batches_tracked',
39 'down1.conv5.conv.0.weight': 'models.5.conv5.weight',
40 'down1.conv5.conv.1.weight': 'models.5.bn5.weight',
41 'down1.conv5.conv.1.bias': 'models.5.bn5.bias',
42 'down1.conv5.conv.1.running_mean': 'models.5.bn5.running_mean',
43 'down1.conv5.conv.1.running_var': 'models.5.bn5.running_var',
44 'down1.conv5.conv.1.num_batches_tracked': 'models.5.bn5.num_batches_tracked',
45 'down1.conv6.conv.0.weight': 'models.6.conv6.weight',
46 'down1.conv6.conv.1.weight': 'models.6.bn6.weight',
47 'down1.conv6.conv.1.bias': 'models.6.bn6.bias',
48 'down1.conv6.conv.1.running_mean': 'models.6.bn6.running_mean',
49 'down1.conv6.conv.1.running_var': 'models.6.bn6.running_var',
50 'down1.conv6.conv.1.num_batches_tracked': 'models.6.bn6.num_batches_tracked',
51 'down1.conv7.conv.0.weight': 'models.8.conv7.weight',
52 'down1.conv7.conv.1.weight': 'models.8.bn7.weight',
53 'down1.conv7.conv.1.bias': 'models.8.bn7.bias',
54 'down1.conv7.conv.1.running_mean': 'models.8.bn7.running_mean',
55 'down1.conv7.conv.1.running_var': 'models.8.bn7.running_var',
56 'down1.conv7.conv.1.num_batches_tracked': 'models.8.bn7.num_batches_tracked',
57 'down1.conv8.conv.0.weight': 'models.10.conv8.weight',
58 'down1.conv8.conv.1.weight': 'models.10.bn8.weight',
59 'down1.conv8.conv.1.bias': 'models.10.bn8.bias',
60 'down1.conv8.conv.1.running_mean': 'models.10.bn8.running_mean',
61 'down1.conv8.conv.1.running_var': 'models.10.bn8.running_var',
62 'down1.conv8.conv.1.num_batches_tracked': 'models.10.bn8.num_batches_tracked',
63 'down2.conv1.conv.0.weight': 'models.11.conv9.weight',
64 'down2.conv1.conv.1.weight': 'models.11.bn9.weight',
65 'down2.conv1.conv.1.bias': 'models.11.bn9.bias',
66 'down2.conv1.conv.1.running_mean': 'models.11.bn9.running_mean',
67 'down2.conv1.conv.1.running_var': 'models.11.bn9.running_var',
68 'down2.conv1.conv.1.num_batches_tracked': 'models.11.bn9.num_batches_tracked',
69 'down2.conv2.conv.0.weight': 'models.12.conv10.weight',
70 'down2.conv2.conv.1.weight': 'models.12.bn10.weight',
71 'down2.conv2.conv.1.bias': 'models.12.bn10.bias',
72 'down2.conv2.conv.1.running_mean': 'models.12.bn10.running_mean',
73 'down2.conv2.conv.1.running_var': 'models.12.bn10.running_var',
74 'down2.conv2.conv.1.num_batches_tracked': 'models.12.bn10.num_batches_tracked',
75 'down2.conv3.conv.0.weight': 'models.14.conv11.weight',
76 'down2.conv3.conv.1.weight': 'models.14.bn11.weight',
77 'down2.conv3.conv.1.bias': 'models.14.bn11.bias',
78 'down2.conv3.conv.1.running_mean': 'models.14.bn11.running_mean',
79 'down2.conv3.conv.1.running_var': 'models.14.bn11.running_var',
80 'down2.conv3.conv.1.num_batches_tracked': 'models.14.bn11.num_batches_tracked',
81 'down2.resblock.module_list.0.0.conv.0.weight': 'models.15.conv12.weight',
82 'down2.resblock.module_list.0.0.conv.1.weight': 'models.15.bn12.weight',
83 'down2.resblock.module_list.0.0.conv.1.bias': 'models.15.bn12.bias',
84 'down2.resblock.module_list.0.0.conv.1.running_mean': 'models.15.bn12.running_mean',
85 'down2.resblock.module_list.0.0.conv.1.running_var': 'models.15.bn12.running_var',
86 'down2.resblock.module_list.0.0.conv.1.num_batches_tracked': 'models.15.bn12.num_batches_tracked',
87 'down2.resblock.module_list.0.1.conv.0.weight': 'models.16.conv13.weight',
88 'down2.resblock.module_list.0.1.conv.1.weight': 'models.16.bn13.weight',
89 'down2.resblock.module_list.0.1.conv.1.bias': 'models.16.bn13.bias',
90 'down2.resblock.module_list.0.1.conv.1.running_mean': 'models.16.bn13.running_mean',
91 'down2.resblock.module_list.0.1.conv.1.running_var': 'models.16.bn13.running_var',
92 'down2.resblock.module_list.0.1.conv.1.num_batches_tracked': 'models.16.bn13.num_batches_tracked',
93 'down2.resblock.module_list.1.0.conv.0.weight': 'models.18.conv14.weight',
94 'down2.resblock.module_list.1.0.conv.1.weight': 'models.18.bn14.weight',
95 'down2.resblock.module_list.1.0.conv.1.bias': 'models.18.bn14.bias',
96 'down2.resblock.module_list.1.0.conv.1.running_mean': 'models.18.bn14.running_mean',
97 'down2.resblock.module_list.1.0.conv.1.running_var': 'models.18.bn14.running_var',
98 'down2.resblock.module_list.1.0.conv.1.num_batches_tracked': 'models.18.bn14.num_batches_tracked',
99 'down2.resblock.module_list.1.1.conv.0.weight': 'models.19.conv15.weight',
100 'down2.resblock.module_list.1.1.conv.1.weight': 'models.19.bn15.weight',
101 'down2.resblock.module_list.1.1.conv.1.bias': 'models.19.bn15.bias',
102 'down2.resblock.module_list.1.1.conv.1.running_mean': 'models.19.bn15.running_mean',
103 'down2.resblock.module_list.1.1.conv.1.running_var': 'models.19.bn15.running_var',
104 'down2.resblock.module_list.1.1.conv.1.num_batches_tracked': 'models.19.bn15.num_batches_tracked',
105 'down2.conv4.conv.0.weight': 'models.21.conv16.weight',
106 'down2.conv4.conv.1.weight': 'models.21.bn16.weight',
107 'down2.conv4.conv.1.bias': 'models.21.bn16.bias',
108 'down2.conv4.conv.1.running_mean': 'models.21.bn16.running_mean',
109 'down2.conv4.conv.1.running_var': 'models.21.bn16.running_var',
110 'down2.conv4.conv.1.num_batches_tracked': 'models.21.bn16.num_batches_tracked',
111 'down2.conv5.conv.0.weight': 'models.23.conv17.weight',
112 'down2.conv5.conv.1.weight': 'models.23.bn17.weight',
113 'down2.conv5.conv.1.bias': 'models.23.bn17.bias',
114 'down2.conv5.conv.1.running_mean': 'models.23.bn17.running_mean',
115 'down2.conv5.conv.1.running_var': 'models.23.bn17.running_var',
116 'down2.conv5.conv.1.num_batches_tracked': 'models.23.bn17.num_batches_tracked',
117 'down3.conv1.conv.0.weight': 'models.24.conv18.weight',
118 'down3.conv1.conv.1.weight': 'models.24.bn18.weight',
119 'down3.conv1.conv.1.bias': 'models.24.bn18.bias',
120 'down3.conv1.conv.1.running_mean': 'models.24.bn18.running_mean',
121 'down3.conv1.conv.1.running_var': 'models.24.bn18.running_var',
122 'down3.conv1.conv.1.num_batches_tracked': 'models.24.bn18.num_batches_tracked',
123 'down3.conv2.conv.0.weight': 'models.25.conv19.weight',
124 'down3.conv2.conv.1.weight': 'models.25.bn19.weight',
125 'down3.conv2.conv.1.bias': 'models.25.bn19.bias',
126 'down3.conv2.conv.1.running_mean': 'models.25.bn19.running_mean',
127 'down3.conv2.conv.1.running_var': 'models.25.bn19.running_var',
128 'down3.conv2.conv.1.num_batches_tracked': 'models.25.bn19.num_batches_tracked',
129 'down3.conv3.conv.0.weight': 'models.27.conv20.weight',
130 'down3.conv3.conv.1.weight': 'models.27.bn20.weight',
131 'down3.conv3.conv.1.bias': 'models.27.bn20.bias',
132 'down3.conv3.conv.1.running_mean': 'models.27.bn20.running_mean',
133 'down3.conv3.conv.1.running_var': 'models.27.bn20.running_var',
134 'down3.conv3.conv.1.num_batches_tracked': 'models.27.bn20.num_batches_tracked',
135 'down3.resblock.module_list.0.0.conv.0.weight': 'models.28.conv21.weight',
136 'down3.resblock.module_list.0.0.conv.1.weight': 'models.28.bn21.weight',
137 'down3.resblock.module_list.0.0.conv.1.bias': 'models.28.bn21.bias',
138 'down3.resblock.module_list.0.0.conv.1.running_mean': 'models.28.bn21.running_mean',
139 'down3.resblock.module_list.0.0.conv.1.running_var': 'models.28.bn21.running_var',
140 'down3.resblock.module_list.0.0.conv.1.num_batches_tracked': 'models.28.bn21.num_batches_tracked',
141 'down3.resblock.module_list.0.1.conv.0.weight': 'models.29.conv22.weight',
142 'down3.resblock.module_list.0.1.conv.1.weight': 'models.29.bn22.weight',
143 'down3.resblock.module_list.0.1.conv.1.bias': 'models.29.bn22.bias',
144 'down3.resblock.module_list.0.1.conv.1.running_mean': 'models.29.bn22.running_mean',
145 'down3.resblock.module_list.0.1.conv.1.running_var': 'models.29.bn22.running_var',
146 'down3.resblock.module_list.0.1.conv.1.num_batches_tracked': 'models.29.bn22.num_batches_tracked',
147 'down3.resblock.module_list.1.0.conv.0.weight': 'models.31.conv23.weight',
148 'down3.resblock.module_list.1.0.conv.1.weight': 'models.31.bn23.weight',
149 'down3.resblock.module_list.1.0.conv.1.bias': 'models.31.bn23.bias',
150 'down3.resblock.module_list.1.0.conv.1.running_mean': 'models.31.bn23.running_mean',
151 'down3.resblock.module_list.1.0.conv.1.running_var': 'models.31.bn23.running_var',
152 'down3.resblock.module_list.1.0.conv.1.num_batches_tracked': 'models.31.bn23.num_batches_tracked',
153 'down3.resblock.module_list.1.1.conv.0.weight': 'models.32.conv24.weight',
154 'down3.resblock.module_list.1.1.conv.1.weight': 'models.32.bn24.weight',
155 'down3.resblock.module_list.1.1.conv.1.bias': 'models.32.bn24.bias',
156 'down3.resblock.module_list.1.1.conv.1.running_mean': 'models.32.bn24.running_mean',
157 'down3.resblock.module_list.1.1.conv.1.running_var': 'models.32.bn24.running_var',
158 'down3.resblock.module_list.1.1.conv.1.num_batches_tracked': 'models.32.bn24.num_batches_tracked',
159 'down3.resblock.module_list.2.0.conv.0.weight': 'models.34.conv25.weight',
160 'down3.resblock.module_list.2.0.conv.1.weight': 'models.34.bn25.weight',
161 'down3.resblock.module_list.2.0.conv.1.bias': 'models.34.bn25.bias',
162 'down3.resblock.module_list.2.0.conv.1.running_mean': 'models.34.bn25.running_mean',
163 'down3.resblock.module_list.2.0.conv.1.running_var': 'models.34.bn25.running_var',
164 'down3.resblock.module_list.2.0.conv.1.num_batches_tracked': 'models.34.bn25.num_batches_tracked',
165 'down3.resblock.module_list.2.1.conv.0.weight': 'models.35.conv26.weight',
166 'down3.resblock.module_list.2.1.conv.1.weight': 'models.35.bn26.weight',
167 'down3.resblock.module_list.2.1.conv.1.bias': 'models.35.bn26.bias',
168 'down3.resblock.module_list.2.1.conv.1.running_mean': 'models.35.bn26.running_mean',
169 'down3.resblock.module_list.2.1.conv.1.running_var': 'models.35.bn26.running_var',
170 'down3.resblock.module_list.2.1.conv.1.num_batches_tracked': 'models.35.bn26.num_batches_tracked',
171 'down3.resblock.module_list.3.0.conv.0.weight': 'models.37.conv27.weight',
172 'down3.resblock.module_list.3.0.conv.1.weight': 'models.37.bn27.weight',
173 'down3.resblock.module_list.3.0.conv.1.bias': 'models.37.bn27.bias',
174 'down3.resblock.module_list.3.0.conv.1.running_mean': 'models.37.bn27.running_mean',
175 'down3.resblock.module_list.3.0.conv.1.running_var': 'models.37.bn27.running_var',
176 'down3.resblock.module_list.3.0.conv.1.num_batches_tracked': 'models.37.bn27.num_batches_tracked',
177 'down3.resblock.module_list.3.1.conv.0.weight': 'models.38.conv28.weight',
178 'down3.resblock.module_list.3.1.conv.1.weight': 'models.38.bn28.weight',
179 'down3.resblock.module_list.3.1.conv.1.bias': 'models.38.bn28.bias',
180 'down3.resblock.module_list.3.1.conv.1.running_mean': 'models.38.bn28.running_mean',
181 'down3.resblock.module_list.3.1.conv.1.running_var': 'models.38.bn28.running_var',
182 'down3.resblock.module_list.3.1.conv.1.num_batches_tracked': 'models.38.bn28.num_batches_tracked',
183 'down3.resblock.module_list.4.0.conv.0.weight': 'models.40.conv29.weight',
184 'down3.resblock.module_list.4.0.conv.1.weight': 'models.40.bn29.weight',
185 'down3.resblock.module_list.4.0.conv.1.bias': 'models.40.bn29.bias',
186 'down3.resblock.module_list.4.0.conv.1.running_mean': 'models.40.bn29.running_mean',
187 'down3.resblock.module_list.4.0.conv.1.running_var': 'models.40.bn29.running_var',
188 'down3.resblock.module_list.4.0.conv.1.num_batches_tracked': 'models.40.bn29.num_batches_tracked',
189 'down3.resblock.module_list.4.1.conv.0.weight': 'models.41.conv30.weight',
190 'down3.resblock.module_list.4.1.conv.1.weight': 'models.41.bn30.weight',
191 'down3.resblock.module_list.4.1.conv.1.bias': 'models.41.bn30.bias',
192 'down3.resblock.module_list.4.1.conv.1.running_mean': 'models.41.bn30.running_mean',
193 'down3.resblock.module_list.4.1.conv.1.running_var': 'models.41.bn30.running_var',
194 'down3.resblock.module_list.4.1.conv.1.num_batches_tracked': 'models.41.bn30.num_batches_tracked',
195 'down3.resblock.module_list.5.0.conv.0.weight': 'models.43.conv31.weight',
196 'down3.resblock.module_list.5.0.conv.1.weight': 'models.43.bn31.weight',
197 'down3.resblock.module_list.5.0.conv.1.bias': 'models.43.bn31.bias',
198 'down3.resblock.module_list.5.0.conv.1.running_mean': 'models.43.bn31.running_mean',
199 'down3.resblock.module_list.5.0.conv.1.running_var': 'models.43.bn31.running_var',
200 'down3.resblock.module_list.5.0.conv.1.num_batches_tracked': 'models.43.bn31.num_batches_tracked',
201 'down3.resblock.module_list.5.1.conv.0.weight': 'models.44.conv32.weight',
202 'down3.resblock.module_list.5.1.conv.1.weight': 'models.44.bn32.weight',
203 'down3.resblock.module_list.5.1.conv.1.bias': 'models.44.bn32.bias',
204 'down3.resblock.module_list.5.1.conv.1.running_mean': 'models.44.bn32.running_mean',
205 'down3.resblock.module_list.5.1.conv.1.running_var': 'models.44.bn32.running_var',
206 'down3.resblock.module_list.5.1.conv.1.num_batches_tracked': 'models.44.bn32.num_batches_tracked',
207 'down3.resblock.module_list.6.0.conv.0.weight': 'models.46.conv33.weight',
208 'down3.resblock.module_list.6.0.conv.1.weight': 'models.46.bn33.weight',
209 'down3.resblock.module_list.6.0.conv.1.bias': 'models.46.bn33.bias',
210 'down3.resblock.module_list.6.0.conv.1.running_mean': 'models.46.bn33.running_mean',
211 'down3.resblock.module_list.6.0.conv.1.running_var': 'models.46.bn33.running_var',
212 'down3.resblock.module_list.6.0.conv.1.num_batches_tracked': 'models.46.bn33.num_batches_tracked',
213 'down3.resblock.module_list.6.1.conv.0.weight': 'models.47.conv34.weight',
214 'down3.resblock.module_list.6.1.conv.1.weight': 'models.47.bn34.weight',
215 'down3.resblock.module_list.6.1.conv.1.bias': 'models.47.bn34.bias',
216 'down3.resblock.module_list.6.1.conv.1.running_mean': 'models.47.bn34.running_mean',
217 'down3.resblock.module_list.6.1.conv.1.running_var': 'models.47.bn34.running_var',
218 'down3.resblock.module_list.6.1.conv.1.num_batches_tracked': 'models.47.bn34.num_batches_tracked',
219 'down3.resblock.module_list.7.0.conv.0.weight': 'models.49.conv35.weight',
220 'down3.resblock.module_list.7.0.conv.1.weight': 'models.49.bn35.weight',
221 'down3.resblock.module_list.7.0.conv.1.bias': 'models.49.bn35.bias',
222 'down3.resblock.module_list.7.0.conv.1.running_mean': 'models.49.bn35.running_mean',
223 'down3.resblock.module_list.7.0.conv.1.running_var': 'models.49.bn35.running_var',
224 'down3.resblock.module_list.7.0.conv.1.num_batches_tracked': 'models.49.bn35.num_batches_tracked',
225 'down3.resblock.module_list.7.1.conv.0.weight': 'models.50.conv36.weight',
226 'down3.resblock.module_list.7.1.conv.1.weight': 'models.50.bn36.weight',
227 'down3.resblock.module_list.7.1.conv.1.bias': 'models.50.bn36.bias',
228 'down3.resblock.module_list.7.1.conv.1.running_mean': 'models.50.bn36.running_mean',
229 'down3.resblock.module_list.7.1.conv.1.running_var': 'models.50.bn36.running_var',
230 'down3.resblock.module_list.7.1.conv.1.num_batches_tracked': 'models.50.bn36.num_batches_tracked',
231 'down3.conv4.conv.0.weight': 'models.52.conv37.weight',
232 'down3.conv4.conv.1.weight': 'models.52.bn37.weight',
233 'down3.conv4.conv.1.bias': 'models.52.bn37.bias',
234 'down3.conv4.conv.1.running_mean': 'models.52.bn37.running_mean',
235 'down3.conv4.conv.1.running_var': 'models.52.bn37.running_var',
236 'down3.conv4.conv.1.num_batches_tracked': 'models.52.bn37.num_batches_tracked',
237 'down3.conv5.conv.0.weight': 'models.54.conv38.weight',
238 'down3.conv5.conv.1.weight': 'models.54.bn38.weight',
239 'down3.conv5.conv.1.bias': 'models.54.bn38.bias',
240 'down3.conv5.conv.1.running_mean': 'models.54.bn38.running_mean',
241 'down3.conv5.conv.1.running_var': 'models.54.bn38.running_var',
242 'down3.conv5.conv.1.num_batches_tracked': 'models.54.bn38.num_batches_tracked',
243 'down4.conv1.conv.0.weight': 'models.55.conv39.weight',
244 'down4.conv1.conv.1.weight': 'models.55.bn39.weight',
245 'down4.conv1.conv.1.bias': 'models.55.bn39.bias',
246 'down4.conv1.conv.1.running_mean': 'models.55.bn39.running_mean',
247 'down4.conv1.conv.1.running_var': 'models.55.bn39.running_var',
248 'down4.conv1.conv.1.num_batches_tracked': 'models.55.bn39.num_batches_tracked',
249 'down4.conv2.conv.0.weight': 'models.56.conv40.weight',
250 'down4.conv2.conv.1.weight': 'models.56.bn40.weight',
251 'down4.conv2.conv.1.bias': 'models.56.bn40.bias',
252 'down4.conv2.conv.1.running_mean': 'models.56.bn40.running_mean',
253 'down4.conv2.conv.1.running_var': 'models.56.bn40.running_var',
254 'down4.conv2.conv.1.num_batches_tracked': 'models.56.bn40.num_batches_tracked',
255 'down4.conv3.conv.0.weight': 'models.58.conv41.weight',
256 'down4.conv3.conv.1.weight': 'models.58.bn41.weight',
257 'down4.conv3.conv.1.bias': 'models.58.bn41.bias',
258 'down4.conv3.conv.1.running_mean': 'models.58.bn41.running_mean',
259 'down4.conv3.conv.1.running_var': 'models.58.bn41.running_var',
260 'down4.conv3.conv.1.num_batches_tracked': 'models.58.bn41.num_batches_tracked',
261 'down4.resblock.module_list.0.0.conv.0.weight': 'models.59.conv42.weight',
262 'down4.resblock.module_list.0.0.conv.1.weight': 'models.59.bn42.weight',
263 'down4.resblock.module_list.0.0.conv.1.bias': 'models.59.bn42.bias',
264 'down4.resblock.module_list.0.0.conv.1.running_mean': 'models.59.bn42.running_mean',
265 'down4.resblock.module_list.0.0.conv.1.running_var': 'models.59.bn42.running_var',
266 'down4.resblock.module_list.0.0.conv.1.num_batches_tracked': 'models.59.bn42.num_batches_tracked',
267 'down4.resblock.module_list.0.1.conv.0.weight': 'models.60.conv43.weight',
268 'down4.resblock.module_list.0.1.conv.1.weight': 'models.60.bn43.weight',
269 'down4.resblock.module_list.0.1.conv.1.bias': 'models.60.bn43.bias',
270 'down4.resblock.module_list.0.1.conv.1.running_mean': 'models.60.bn43.running_mean',
271 'down4.resblock.module_list.0.1.conv.1.running_var': 'models.60.bn43.running_var',
272 'down4.resblock.module_list.0.1.conv.1.num_batches_tracked': 'models.60.bn43.num_batches_tracked',
273 'down4.resblock.module_list.1.0.conv.0.weight': 'models.62.conv44.weight',
274 'down4.resblock.module_list.1.0.conv.1.weight': 'models.62.bn44.weight',
275 'down4.resblock.module_list.1.0.conv.1.bias': 'models.62.bn44.bias',
276 'down4.resblock.module_list.1.0.conv.1.running_mean': 'models.62.bn44.running_mean',
277 'down4.resblock.module_list.1.0.conv.1.running_var': 'models.62.bn44.running_var',
278 'down4.resblock.module_list.1.0.conv.1.num_batches_tracked': 'models.62.bn44.num_batches_tracked',
279 'down4.resblock.module_list.1.1.conv.0.weight': 'models.63.conv45.weight',
280 'down4.resblock.module_list.1.1.conv.1.weight': 'models.63.bn45.weight',
281 'down4.resblock.module_list.1.1.conv.1.bias': 'models.63.bn45.bias',
282 'down4.resblock.module_list.1.1.conv.1.running_mean': 'models.63.bn45.running_mean',
283 'down4.resblock.module_list.1.1.conv.1.running_var': 'models.63.bn45.running_var',
284 'down4.resblock.module_list.1.1.conv.1.num_batches_tracked': 'models.63.bn45.num_batches_tracked',
285 'down4.resblock.module_list.2.0.conv.0.weight': 'models.65.conv46.weight',
286 'down4.resblock.module_list.2.0.conv.1.weight': 'models.65.bn46.weight',
287 'down4.resblock.module_list.2.0.conv.1.bias': 'models.65.bn46.bias',
288 'down4.resblock.module_list.2.0.conv.1.running_mean': 'models.65.bn46.running_mean',
289 'down4.resblock.module_list.2.0.conv.1.running_var': 'models.65.bn46.running_var',
290 'down4.resblock.module_list.2.0.conv.1.num_batches_tracked': 'models.65.bn46.num_batches_tracked',
291 'down4.resblock.module_list.2.1.conv.0.weight': 'models.66.conv47.weight',
292 'down4.resblock.module_list.2.1.conv.1.weight': 'models.66.bn47.weight',
293 'down4.resblock.module_list.2.1.conv.1.bias': 'models.66.bn47.bias',
294 'down4.resblock.module_list.2.1.conv.1.running_mean': 'models.66.bn47.running_mean',
295 'down4.resblock.module_list.2.1.conv.1.running_var': 'models.66.bn47.running_var',
296 'down4.resblock.module_list.2.1.conv.1.num_batches_tracked': 'models.66.bn47.num_batches_tracked',
297 'down4.resblock.module_list.3.0.conv.0.weight': 'models.68.conv48.weight',
298 'down4.resblock.module_list.3.0.conv.1.weight': 'models.68.bn48.weight',
299 'down4.resblock.module_list.3.0.conv.1.bias': 'models.68.bn48.bias',
300 'down4.resblock.module_list.3.0.conv.1.running_mean': 'models.68.bn48.running_mean',
301 'down4.resblock.module_list.3.0.conv.1.running_var': 'models.68.bn48.running_var',
302 'down4.resblock.module_list.3.0.conv.1.num_batches_tracked': 'models.68.bn48.num_batches_tracked',
303 'down4.resblock.module_list.3.1.conv.0.weight': 'models.69.conv49.weight',
304 'down4.resblock.module_list.3.1.conv.1.weight': 'models.69.bn49.weight',
305 'down4.resblock.module_list.3.1.conv.1.bias': 'models.69.bn49.bias',
306 'down4.resblock.module_list.3.1.conv.1.running_mean': 'models.69.bn49.running_mean',
307 'down4.resblock.module_list.3.1.conv.1.running_var': 'models.69.bn49.running_var',
308 'down4.resblock.module_list.3.1.conv.1.num_batches_tracked': 'models.69.bn49.num_batches_tracked',
309 'down4.resblock.module_list.4.0.conv.0.weight': 'models.71.conv50.weight',
310 'down4.resblock.module_list.4.0.conv.1.weight': 'models.71.bn50.weight',
311 'down4.resblock.module_list.4.0.conv.1.bias': 'models.71.bn50.bias',
312 'down4.resblock.module_list.4.0.conv.1.running_mean': 'models.71.bn50.running_mean',
313 'down4.resblock.module_list.4.0.conv.1.running_var': 'models.71.bn50.running_var',
314 'down4.resblock.module_list.4.0.conv.1.num_batches_tracked': 'models.71.bn50.num_batches_tracked',
315 'down4.resblock.module_list.4.1.conv.0.weight': 'models.72.conv51.weight',
316 'down4.resblock.module_list.4.1.conv.1.weight': 'models.72.bn51.weight',
317 'down4.resblock.module_list.4.1.conv.1.bias': 'models.72.bn51.bias',
318 'down4.resblock.module_list.4.1.conv.1.running_mean': 'models.72.bn51.running_mean',
319 'down4.resblock.module_list.4.1.conv.1.running_var': 'models.72.bn51.running_var',
320 'down4.resblock.module_list.4.1.conv.1.num_batches_tracked': 'models.72.bn51.num_batches_tracked',
321 'down4.resblock.module_list.5.0.conv.0.weight': 'models.74.conv52.weight',
322 'down4.resblock.module_list.5.0.conv.1.weight': 'models.74.bn52.weight',
323 'down4.resblock.module_list.5.0.conv.1.bias': 'models.74.bn52.bias',
324 'down4.resblock.module_list.5.0.conv.1.running_mean': 'models.74.bn52.running_mean',
325 'down4.resblock.module_list.5.0.conv.1.running_var': 'models.74.bn52.running_var',
326 'down4.resblock.module_list.5.0.conv.1.num_batches_tracked': 'models.74.bn52.num_batches_tracked',
327 'down4.resblock.module_list.5.1.conv.0.weight': 'models.75.conv53.weight',
328 'down4.resblock.module_list.5.1.conv.1.weight': 'models.75.bn53.weight',
329 'down4.resblock.module_list.5.1.conv.1.bias': 'models.75.bn53.bias',
330 'down4.resblock.module_list.5.1.conv.1.running_mean': 'models.75.bn53.running_mean',
331 'down4.resblock.module_list.5.1.conv.1.running_var': 'models.75.bn53.running_var',
332 'down4.resblock.module_list.5.1.conv.1.num_batches_tracked': 'models.75.bn53.num_batches_tracked',
333 'down4.resblock.module_list.6.0.conv.0.weight': 'models.77.conv54.weight',
334 'down4.resblock.module_list.6.0.conv.1.weight': 'models.77.bn54.weight',
335 'down4.resblock.module_list.6.0.conv.1.bias': 'models.77.bn54.bias',
336 'down4.resblock.module_list.6.0.conv.1.running_mean': 'models.77.bn54.running_mean',
337 'down4.resblock.module_list.6.0.conv.1.running_var': 'models.77.bn54.running_var',
338 'down4.resblock.module_list.6.0.conv.1.num_batches_tracked': 'models.77.bn54.num_batches_tracked',
339 'down4.resblock.module_list.6.1.conv.0.weight': 'models.78.conv55.weight',
340 'down4.resblock.module_list.6.1.conv.1.weight': 'models.78.bn55.weight',
341 'down4.resblock.module_list.6.1.conv.1.bias': 'models.78.bn55.bias',
342 'down4.resblock.module_list.6.1.conv.1.running_mean': 'models.78.bn55.running_mean',
343 'down4.resblock.module_list.6.1.conv.1.running_var': 'models.78.bn55.running_var',
344 'down4.resblock.module_list.6.1.conv.1.num_batches_tracked': 'models.78.bn55.num_batches_tracked',
345 'down4.resblock.module_list.7.0.conv.0.weight': 'models.80.conv56.weight',
346 'down4.resblock.module_list.7.0.conv.1.weight': 'models.80.bn56.weight',
347 'down4.resblock.module_list.7.0.conv.1.bias': 'models.80.bn56.bias',
348 'down4.resblock.module_list.7.0.conv.1.running_mean': 'models.80.bn56.running_mean',
349 'down4.resblock.module_list.7.0.conv.1.running_var': 'models.80.bn56.running_var',
350 'down4.resblock.module_list.7.0.conv.1.num_batches_tracked': 'models.80.bn56.num_batches_tracked',
351 'down4.resblock.module_list.7.1.conv.0.weight': 'models.81.conv57.weight',
352 'down4.resblock.module_list.7.1.conv.1.weight': 'models.81.bn57.weight',
353 'down4.resblock.module_list.7.1.conv.1.bias': 'models.81.bn57.bias',
354 'down4.resblock.module_list.7.1.conv.1.running_mean': 'models.81.bn57.running_mean',
355 'down4.resblock.module_list.7.1.conv.1.running_var': 'models.81.bn57.running_var',
356 'down4.resblock.module_list.7.1.conv.1.num_batches_tracked': 'models.81.bn57.num_batches_tracked',
357 'down4.conv4.conv.0.weight': 'models.83.conv58.weight',
358 'down4.conv4.conv.1.weight': 'models.83.bn58.weight',
359 'down4.conv4.conv.1.bias': 'models.83.bn58.bias',
360 'down4.conv4.conv.1.running_mean': 'models.83.bn58.running_mean',
361 'down4.conv4.conv.1.running_var': 'models.83.bn58.running_var',
362 'down4.conv4.conv.1.num_batches_tracked': 'models.83.bn58.num_batches_tracked',
363 'down4.conv5.conv.0.weight': 'models.85.conv59.weight',
364 'down4.conv5.conv.1.weight': 'models.85.bn59.weight',
365 'down4.conv5.conv.1.bias': 'models.85.bn59.bias',
366 'down4.conv5.conv.1.running_mean': 'models.85.bn59.running_mean',
367 'down4.conv5.conv.1.running_var': 'models.85.bn59.running_var',
368 'down4.conv5.conv.1.num_batches_tracked': 'models.85.bn59.num_batches_tracked',
369 'down5.conv1.conv.0.weight': 'models.86.conv60.weight',
370 'down5.conv1.conv.1.weight': 'models.86.bn60.weight',
371 'down5.conv1.conv.1.bias': 'models.86.bn60.bias',
372 'down5.conv1.conv.1.running_mean': 'models.86.bn60.running_mean',
373 'down5.conv1.conv.1.running_var': 'models.86.bn60.running_var',
374 'down5.conv1.conv.1.num_batches_tracked': 'models.86.bn60.num_batches_tracked',
375 'down5.conv2.conv.0.weight': 'models.87.conv61.weight',
376 'down5.conv2.conv.1.weight': 'models.87.bn61.weight',
377 'down5.conv2.conv.1.bias': 'models.87.bn61.bias',
378 'down5.conv2.conv.1.running_mean': 'models.87.bn61.running_mean',
379 'down5.conv2.conv.1.running_var': 'models.87.bn61.running_var',
380 'down5.conv2.conv.1.num_batches_tracked': 'models.87.bn61.num_batches_tracked',
381 'down5.conv3.conv.0.weight': 'models.89.conv62.weight',
382 'down5.conv3.conv.1.weight': 'models.89.bn62.weight',
383 'down5.conv3.conv.1.bias': 'models.89.bn62.bias',
384 'down5.conv3.conv.1.running_mean': 'models.89.bn62.running_mean',
385 'down5.conv3.conv.1.running_var': 'models.89.bn62.running_var',
386 'down5.conv3.conv.1.num_batches_tracked': 'models.89.bn62.num_batches_tracked',
387 'down5.resblock.module_list.0.0.conv.0.weight': 'models.90.conv63.weight',
388 'down5.resblock.module_list.0.0.conv.1.weight': 'models.90.bn63.weight',
389 'down5.resblock.module_list.0.0.conv.1.bias': 'models.90.bn63.bias',
390 'down5.resblock.module_list.0.0.conv.1.running_mean': 'models.90.bn63.running_mean',
391 'down5.resblock.module_list.0.0.conv.1.running_var': 'models.90.bn63.running_var',
392 'down5.resblock.module_list.0.0.conv.1.num_batches_tracked': 'models.90.bn63.num_batches_tracked',
393 'down5.resblock.module_list.0.1.conv.0.weight': 'models.91.conv64.weight',
394 'down5.resblock.module_list.0.1.conv.1.weight': 'models.91.bn64.weight',
395 'down5.resblock.module_list.0.1.conv.1.bias': 'models.91.bn64.bias',
396 'down5.resblock.module_list.0.1.conv.1.running_mean': 'models.91.bn64.running_mean',
397 'down5.resblock.module_list.0.1.conv.1.running_var': 'models.91.bn64.running_var',
398 'down5.resblock.module_list.0.1.conv.1.num_batches_tracked': 'models.91.bn64.num_batches_tracked',
399 'down5.resblock.module_list.1.0.conv.0.weight': 'models.93.conv65.weight',
400 'down5.resblock.module_list.1.0.conv.1.weight': 'models.93.bn65.weight',
401 'down5.resblock.module_list.1.0.conv.1.bias': 'models.93.bn65.bias',
402 'down5.resblock.module_list.1.0.conv.1.running_mean': 'models.93.bn65.running_mean',
403 'down5.resblock.module_list.1.0.conv.1.running_var': 'models.93.bn65.running_var',
404 'down5.resblock.module_list.1.0.conv.1.num_batches_tracked': 'models.93.bn65.num_batches_tracked',
405 'down5.resblock.module_list.1.1.conv.0.weight': 'models.94.conv66.weight',
406 'down5.resblock.module_list.1.1.conv.1.weight': 'models.94.bn66.weight',
407 'down5.resblock.module_list.1.1.conv.1.bias': 'models.94.bn66.bias',
408 'down5.resblock.module_list.1.1.conv.1.running_mean': 'models.94.bn66.running_mean',
409 'down5.resblock.module_list.1.1.conv.1.running_var': 'models.94.bn66.running_var',
410 'down5.resblock.module_list.1.1.conv.1.num_batches_tracked': 'models.94.bn66.num_batches_tracked',
411 'down5.resblock.module_list.2.0.conv.0.weight': 'models.96.conv67.weight',
412 'down5.resblock.module_list.2.0.conv.1.weight': 'models.96.bn67.weight',
413 'down5.resblock.module_list.2.0.conv.1.bias': 'models.96.bn67.bias',
414 'down5.resblock.module_list.2.0.conv.1.running_mean': 'models.96.bn67.running_mean',
415 'down5.resblock.module_list.2.0.conv.1.running_var': 'models.96.bn67.running_var',
416 'down5.resblock.module_list.2.0.conv.1.num_batches_tracked': 'models.96.bn67.num_batches_tracked',
417 'down5.resblock.module_list.2.1.conv.0.weight': 'models.97.conv68.weight',
418 'down5.resblock.module_list.2.1.conv.1.weight': 'models.97.bn68.weight',
419 'down5.resblock.module_list.2.1.conv.1.bias': 'models.97.bn68.bias',
420 'down5.resblock.module_list.2.1.conv.1.running_mean': 'models.97.bn68.running_mean',
421 'down5.resblock.module_list.2.1.conv.1.running_var': 'models.97.bn68.running_var',
422 'down5.resblock.module_list.2.1.conv.1.num_batches_tracked': 'models.97.bn68.num_batches_tracked',
423 'down5.resblock.module_list.3.0.conv.0.weight': 'models.99.conv69.weight',
424 'down5.resblock.module_list.3.0.conv.1.weight': 'models.99.bn69.weight',
425 'down5.resblock.module_list.3.0.conv.1.bias': 'models.99.bn69.bias',
426 'down5.resblock.module_list.3.0.conv.1.running_mean': 'models.99.bn69.running_mean',
427 'down5.resblock.module_list.3.0.conv.1.running_var': 'models.99.bn69.running_var',
428 'down5.resblock.module_list.3.0.conv.1.num_batches_tracked': 'models.99.bn69.num_batches_tracked',
429 'down5.resblock.module_list.3.1.conv.0.weight': 'models.100.conv70.weight',
430 'down5.resblock.module_list.3.1.conv.1.weight': 'models.100.bn70.weight',
431 'down5.resblock.module_list.3.1.conv.1.bias': 'models.100.bn70.bias',
432 'down5.resblock.module_list.3.1.conv.1.running_mean': 'models.100.bn70.running_mean',
433 'down5.resblock.module_list.3.1.conv.1.running_var': 'models.100.bn70.running_var',
434 'down5.resblock.module_list.3.1.conv.1.num_batches_tracked': 'models.100.bn70.num_batches_tracked',
435 'down5.conv4.conv.0.weight': 'models.102.conv71.weight',
436 'down5.conv4.conv.1.weight': 'models.102.bn71.weight',
437 'down5.conv4.conv.1.bias': 'models.102.bn71.bias',
438 'down5.conv4.conv.1.running_mean': 'models.102.bn71.running_mean',
439 'down5.conv4.conv.1.running_var': 'models.102.bn71.running_var',
440 'down5.conv4.conv.1.num_batches_tracked': 'models.102.bn71.num_batches_tracked',
441 'down5.conv5.conv.0.weight': 'models.104.conv72.weight',
442 'down5.conv5.conv.1.weight': 'models.104.bn72.weight',
443 'down5.conv5.conv.1.bias': 'models.104.bn72.bias',
444 'down5.conv5.conv.1.running_mean': 'models.104.bn72.running_mean',
445 'down5.conv5.conv.1.running_var': 'models.104.bn72.running_var',
446 'down5.conv5.conv.1.num_batches_tracked': 'models.104.bn72.num_batches_tracked',
447 'neek.conv1.conv.0.weight': 'models.105.conv73.weight',
448 'neek.conv1.conv.1.weight': 'models.105.bn73.weight',
449 'neek.conv1.conv.1.bias': 'models.105.bn73.bias',
450 'neek.conv1.conv.1.running_mean': 'models.105.bn73.running_mean',
451 'neek.conv1.conv.1.running_var': 'models.105.bn73.running_var',
452 'neek.conv1.conv.1.num_batches_tracked': 'models.105.bn73.num_batches_tracked',
453 'neek.conv2.conv.0.weight': 'models.106.conv74.weight',
454 'neek.conv2.conv.1.weight': 'models.106.bn74.weight',
455 'neek.conv2.conv.1.bias': 'models.106.bn74.bias',
456 'neek.conv2.conv.1.running_mean': 'models.106.bn74.running_mean',
457 'neek.conv2.conv.1.running_var': 'models.106.bn74.running_var',
458 'neek.conv2.conv.1.num_batches_tracked': 'models.106.bn74.num_batches_tracked',
459 'neek.conv3.conv.0.weight': 'models.107.conv75.weight',
460 'neek.conv3.conv.1.weight': 'models.107.bn75.weight',
461 'neek.conv3.conv.1.bias': 'models.107.bn75.bias',
462 'neek.conv3.conv.1.running_mean': 'models.107.bn75.running_mean',
463 'neek.conv3.conv.1.running_var': 'models.107.bn75.running_var',
464 'neek.conv3.conv.1.num_batches_tracked': 'models.107.bn75.num_batches_tracked',
465 'neek.conv4.conv.0.weight': 'models.114.conv76.weight',
466 'neek.conv4.conv.1.weight': 'models.114.bn76.weight',
467 'neek.conv4.conv.1.bias': 'models.114.bn76.bias',
468 'neek.conv4.conv.1.running_mean': 'models.114.bn76.running_mean',
469 'neek.conv4.conv.1.running_var': 'models.114.bn76.running_var',
470 'neek.conv4.conv.1.num_batches_tracked': 'models.114.bn76.num_batches_tracked',
471 'neek.conv5.conv.0.weight': 'models.115.conv77.weight',
472 'neek.conv5.conv.1.weight': 'models.115.bn77.weight',
473 'neek.conv5.conv.1.bias': 'models.115.bn77.bias',
474 'neek.conv5.conv.1.running_mean': 'models.115.bn77.running_mean',
475 'neek.conv5.conv.1.running_var': 'models.115.bn77.running_var',
476 'neek.conv5.conv.1.num_batches_tracked': 'models.115.bn77.num_batches_tracked',
477 'neek.conv6.conv.0.weight': 'models.116.conv78.weight',
478 'neek.conv6.conv.1.weight': 'models.116.bn78.weight',
479 'neek.conv6.conv.1.bias': 'models.116.bn78.bias',
480 'neek.conv6.conv.1.running_mean': 'models.116.bn78.running_mean',
481 'neek.conv6.conv.1.running_var': 'models.116.bn78.running_var',
482 'neek.conv6.conv.1.num_batches_tracked': 'models.116.bn78.num_batches_tracked',
483 'neek.conv7.conv.0.weight': 'models.117.conv79.weight',
484 'neek.conv7.conv.1.weight': 'models.117.bn79.weight',
485 'neek.conv7.conv.1.bias': 'models.117.bn79.bias',
486 'neek.conv7.conv.1.running_mean': 'models.117.bn79.running_mean',
487 'neek.conv7.conv.1.running_var': 'models.117.bn79.running_var',
488 'neek.conv7.conv.1.num_batches_tracked': 'models.117.bn79.num_batches_tracked',
489 'neek.conv8.conv.0.weight': 'models.120.conv80.weight',
490 'neek.conv8.conv.1.weight': 'models.120.bn80.weight',
491 'neek.conv8.conv.1.bias': 'models.120.bn80.bias',
492 'neek.conv8.conv.1.running_mean': 'models.120.bn80.running_mean',
493 'neek.conv8.conv.1.running_var': 'models.120.bn80.running_var',
494 'neek.conv8.conv.1.num_batches_tracked': 'models.120.bn80.num_batches_tracked',
495 'neek.conv9.conv.0.weight': 'models.122.conv81.weight',
496 'neek.conv9.conv.1.weight': 'models.122.bn81.weight',
497 'neek.conv9.conv.1.bias': 'models.122.bn81.bias',
498 'neek.conv9.conv.1.running_mean': 'models.122.bn81.running_mean',
499 'neek.conv9.conv.1.running_var': 'models.122.bn81.running_var',
500 'neek.conv9.conv.1.num_batches_tracked': 'models.122.bn81.num_batches_tracked',
501 'neek.conv10.conv.0.weight': 'models.123.conv82.weight',
502 'neek.conv10.conv.1.weight': 'models.123.bn82.weight',
503 'neek.conv10.conv.1.bias': 'models.123.bn82.bias',
504 'neek.conv10.conv.1.running_mean': 'models.123.bn82.running_mean',
505 'neek.conv10.conv.1.running_var': 'models.123.bn82.running_var',
506 'neek.conv10.conv.1.num_batches_tracked': 'models.123.bn82.num_batches_tracked',
507 'neek.conv11.conv.0.weight': 'models.124.conv83.weight',
508 'neek.conv11.conv.1.weight': 'models.124.bn83.weight',
509 'neek.conv11.conv.1.bias': 'models.124.bn83.bias',
510 'neek.conv11.conv.1.running_mean': 'models.124.bn83.running_mean',
511 'neek.conv11.conv.1.running_var': 'models.124.bn83.running_var',
512 'neek.conv11.conv.1.num_batches_tracked': 'models.124.bn83.num_batches_tracked',
513 'neek.conv12.conv.0.weight': 'models.125.conv84.weight',
514 'neek.conv12.conv.1.weight': 'models.125.bn84.weight',
515 'neek.conv12.conv.1.bias': 'models.125.bn84.bias',
516 'neek.conv12.conv.1.running_mean': 'models.125.bn84.running_mean',
517 'neek.conv12.conv.1.running_var': 'models.125.bn84.running_var',
518 'neek.conv12.conv.1.num_batches_tracked': 'models.125.bn84.num_batches_tracked',
519 'neek.conv13.conv.0.weight': 'models.126.conv85.weight',
520 'neek.conv13.conv.1.weight': 'models.126.bn85.weight',
521 'neek.conv13.conv.1.bias': 'models.126.bn85.bias',
522 'neek.conv13.conv.1.running_mean': 'models.126.bn85.running_mean',
523 'neek.conv13.conv.1.running_var': 'models.126.bn85.running_var',
524 'neek.conv13.conv.1.num_batches_tracked': 'models.126.bn85.num_batches_tracked',
525 'neek.conv14.conv.0.weight': 'models.127.conv86.weight',
526 'neek.conv14.conv.1.weight': 'models.127.bn86.weight',
527 'neek.conv14.conv.1.bias': 'models.127.bn86.bias',
528 'neek.conv14.conv.1.running_mean': 'models.127.bn86.running_mean',
529 'neek.conv14.conv.1.running_var': 'models.127.bn86.running_var',
530 'neek.conv14.conv.1.num_batches_tracked': 'models.127.bn86.num_batches_tracked',
531 'neek.conv15.conv.0.weight': 'models.130.conv87.weight',
532 'neek.conv15.conv.1.weight': 'models.130.bn87.weight',
533 'neek.conv15.conv.1.bias': 'models.130.bn87.bias',
534 'neek.conv15.conv.1.running_mean': 'models.130.bn87.running_mean',
535 'neek.conv15.conv.1.running_var': 'models.130.bn87.running_var',
536 'neek.conv15.conv.1.num_batches_tracked': 'models.130.bn87.num_batches_tracked',
537 'neek.conv16.conv.0.weight': 'models.132.conv88.weight',
538 'neek.conv16.conv.1.weight': 'models.132.bn88.weight',
539 'neek.conv16.conv.1.bias': 'models.132.bn88.bias',
540 'neek.conv16.conv.1.running_mean': 'models.132.bn88.running_mean',
541 'neek.conv16.conv.1.running_var': 'models.132.bn88.running_var',
542 'neek.conv16.conv.1.num_batches_tracked': 'models.132.bn88.num_batches_tracked',
543 'neek.conv17.conv.0.weight': 'models.133.conv89.weight',
544 'neek.conv17.conv.1.weight': 'models.133.bn89.weight',
545 'neek.conv17.conv.1.bias': 'models.133.bn89.bias',
546 'neek.conv17.conv.1.running_mean': 'models.133.bn89.running_mean',
547 'neek.conv17.conv.1.running_var': 'models.133.bn89.running_var',
548 'neek.conv17.conv.1.num_batches_tracked': 'models.133.bn89.num_batches_tracked',
549 'neek.conv18.conv.0.weight': 'models.134.conv90.weight',
550 'neek.conv18.conv.1.weight': 'models.134.bn90.weight',
551 'neek.conv18.conv.1.bias': 'models.134.bn90.bias',
552 'neek.conv18.conv.1.running_mean': 'models.134.bn90.running_mean',
553 'neek.conv18.conv.1.running_var': 'models.134.bn90.running_var',
554 'neek.conv18.conv.1.num_batches_tracked': 'models.134.bn90.num_batches_tracked',
555 'neek.conv19.conv.0.weight': 'models.135.conv91.weight',
556 'neek.conv19.conv.1.weight': 'models.135.bn91.weight',
557 'neek.conv19.conv.1.bias': 'models.135.bn91.bias',
558 'neek.conv19.conv.1.running_mean': 'models.135.bn91.running_mean',
559 'neek.conv19.conv.1.running_var': 'models.135.bn91.running_var',
560 'neek.conv19.conv.1.num_batches_tracked': 'models.135.bn91.num_batches_tracked',
561 'neek.conv20.conv.0.weight': 'models.136.conv92.weight',
562 'neek.conv20.conv.1.weight': 'models.136.bn92.weight',
563 'neek.conv20.conv.1.bias': 'models.136.bn92.bias',
564 'neek.conv20.conv.1.running_mean': 'models.136.bn92.running_mean',
565 'neek.conv20.conv.1.running_var': 'models.136.bn92.running_var',
566 'neek.conv20.conv.1.num_batches_tracked': 'models.136.bn92.num_batches_tracked',
567 'head.conv1.conv.0.weight': 'models.137.conv93.weight',
568 'head.conv1.conv.1.weight': 'models.137.bn93.weight',
569 'head.conv1.conv.1.bias': 'models.137.bn93.bias',
570 'head.conv1.conv.1.running_mean': 'models.137.bn93.running_mean',
571 'head.conv1.conv.1.running_var': 'models.137.bn93.running_var',
572 'head.conv1.conv.1.num_batches_tracked': 'models.137.bn93.num_batches_tracked',
573 'head.conv2.conv.0.weight': 'models.138.conv94.weight',
574 'head.conv2.conv.0.bias': 'models.138.conv94.bias',
575 'head.conv3.conv.0.weight': 'models.141.conv95.weight',
576 'head.conv3.conv.1.weight': 'models.141.bn95.weight',
577 'head.conv3.conv.1.bias': 'models.141.bn95.bias',
578 'head.conv3.conv.1.running_mean': 'models.141.bn95.running_mean',
579 'head.conv3.conv.1.running_var': 'models.141.bn95.running_var',
580 'head.conv3.conv.1.num_batches_tracked': 'models.141.bn95.num_batches_tracked',
581 'head.conv4.conv.0.weight': 'models.143.conv96.weight',
582 'head.conv4.conv.1.weight': 'models.143.bn96.weight',
583 'head.conv4.conv.1.bias': 'models.143.bn96.bias',
584 'head.conv4.conv.1.running_mean': 'models.143.bn96.running_mean',
585 'head.conv4.conv.1.running_var': 'models.143.bn96.running_var',
586 'head.conv4.conv.1.num_batches_tracked': 'models.143.bn96.num_batches_tracked',
587 'head.conv5.conv.0.weight': 'models.144.conv97.weight',
588 'head.conv5.conv.1.weight': 'models.144.bn97.weight',
589 'head.conv5.conv.1.bias': 'models.144.bn97.bias',
590 'head.conv5.conv.1.running_mean': 'models.144.bn97.running_mean',
591 'head.conv5.conv.1.running_var': 'models.144.bn97.running_var',
592 'head.conv5.conv.1.num_batches_tracked': 'models.144.bn97.num_batches_tracked',
593 'head.conv6.conv.0.weight': 'models.145.conv98.weight',
594 'head.conv6.conv.1.weight': 'models.145.bn98.weight',
595 'head.conv6.conv.1.bias': 'models.145.bn98.bias',
596 'head.conv6.conv.1.running_mean': 'models.145.bn98.running_mean',
597 'head.conv6.conv.1.running_var': 'models.145.bn98.running_var',
598 'head.conv6.conv.1.num_batches_tracked': 'models.145.bn98.num_batches_tracked',
599 'head.conv7.conv.0.weight': 'models.146.conv99.weight',
600 'head.conv7.conv.1.weight': 'models.146.bn99.weight',
601 'head.conv7.conv.1.bias': 'models.146.bn99.bias',
602 'head.conv7.conv.1.running_mean': 'models.146.bn99.running_mean',
603 'head.conv7.conv.1.running_var': 'models.146.bn99.running_var',
604 'head.conv7.conv.1.num_batches_tracked': 'models.146.bn99.num_batches_tracked',
605 'head.conv8.conv.0.weight': 'models.147.conv100.weight',
606 'head.conv8.conv.1.weight': 'models.147.bn100.weight',
607 'head.conv8.conv.1.bias': 'models.147.bn100.bias',
608 'head.conv8.conv.1.running_mean': 'models.147.bn100.running_mean',
609 'head.conv8.conv.1.running_var': 'models.147.bn100.running_var',
610 'head.conv8.conv.1.num_batches_tracked': 'models.147.bn100.num_batches_tracked',
611 'head.conv9.conv.0.weight': 'models.148.conv101.weight',
612 'head.conv9.conv.1.weight': 'models.148.bn101.weight',
613 'head.conv9.conv.1.bias': 'models.148.bn101.bias',
614 'head.conv9.conv.1.running_mean': 'models.148.bn101.running_mean',
615 'head.conv9.conv.1.running_var': 'models.148.bn101.running_var',
616 'head.conv9.conv.1.num_batches_tracked': 'models.148.bn101.num_batches_tracked',
617 'head.conv10.conv.0.weight': 'models.149.conv102.weight',
618 'head.conv10.conv.0.bias': 'models.149.conv102.bias',
619 'head.conv11.conv.0.weight': 'models.152.conv103.weight',
620 'head.conv11.conv.1.weight': 'models.152.bn103.weight',
621 'head.conv11.conv.1.bias': 'models.152.bn103.bias',
622 'head.conv11.conv.1.running_mean': 'models.152.bn103.running_mean',
623 'head.conv11.conv.1.running_var': 'models.152.bn103.running_var',
624 'head.conv11.conv.1.num_batches_tracked': 'models.152.bn103.num_batches_tracked',
625 'head.conv12.conv.0.weight': 'models.154.conv104.weight',
626 'head.conv12.conv.1.weight': 'models.154.bn104.weight',
627 'head.conv12.conv.1.bias': 'models.154.bn104.bias',
628 'head.conv12.conv.1.running_mean': 'models.154.bn104.running_mean',
629 'head.conv12.conv.1.running_var': 'models.154.bn104.running_var',
630 'head.conv12.conv.1.num_batches_tracked': 'models.154.bn104.num_batches_tracked',
631 'head.conv13.conv.0.weight': 'models.155.conv105.weight',
632 'head.conv13.conv.1.weight': 'models.155.bn105.weight',
633 'head.conv13.conv.1.bias': 'models.155.bn105.bias',
634 'head.conv13.conv.1.running_mean': 'models.155.bn105.running_mean',
635 'head.conv13.conv.1.running_var': 'models.155.bn105.running_var',
636 'head.conv13.conv.1.num_batches_tracked': 'models.155.bn105.num_batches_tracked',
637 'head.conv14.conv.0.weight': 'models.156.conv106.weight',
638 'head.conv14.conv.1.weight': 'models.156.bn106.weight',
639 'head.conv14.conv.1.bias': 'models.156.bn106.bias',
640 'head.conv14.conv.1.running_mean': 'models.156.bn106.running_mean',
641 'head.conv14.conv.1.running_var': 'models.156.bn106.running_var',
642 'head.conv14.conv.1.num_batches_tracked': 'models.156.bn106.num_batches_tracked',
643 'head.conv15.conv.0.weight': 'models.157.conv107.weight',
644 'head.conv15.conv.1.weight': 'models.157.bn107.weight',
645 'head.conv15.conv.1.bias': 'models.157.bn107.bias',
646 'head.conv15.conv.1.running_mean': 'models.157.bn107.running_mean',
647 'head.conv15.conv.1.running_var': 'models.157.bn107.running_var',
648 'head.conv15.conv.1.num_batches_tracked': 'models.157.bn107.num_batches_tracked',
649 'head.conv16.conv.0.weight': 'models.158.conv108.weight',
650 'head.conv16.conv.1.weight': 'models.158.bn108.weight',
651 'head.conv16.conv.1.bias': 'models.158.bn108.bias',
652 'head.conv16.conv.1.running_mean': 'models.158.bn108.running_mean',
653 'head.conv16.conv.1.running_var': 'models.158.bn108.running_var',
654 'head.conv16.conv.1.num_batches_tracked': 'models.158.bn108.num_batches_tracked',
655 'head.conv17.conv.0.weight': 'models.159.conv109.weight',
656 'head.conv17.conv.1.weight': 'models.159.bn109.weight',
657 'head.conv17.conv.1.bias': 'models.159.bn109.bias',
658 'head.conv17.conv.1.running_mean': 'models.159.bn109.running_mean',
659 'head.conv17.conv.1.running_var': 'models.159.bn109.running_var',
660 'head.conv17.conv.1.num_batches_tracked': 'models.159.bn109.num_batches_tracked',
661 'head.conv18.conv.0.weight': 'models.160.conv110.weight',
662 'head.conv18.conv.0.bias': 'models.160.conv110.bias',
663 }
664 pth_weights = torch.load(checkpoint)
665 pt_weights = type(pth_weights)()
666 for name, new_name in name_mapping.items():
667 pt_weights[new_name] = pth_weights[name]
668 return pt_weights
669
670
def convert_pt_checkpoint_to_keras_h5(state_dict):
    """Load renamed PyTorch YOLOv4 weights into a freshly built Keras model.

    The Keras model is built inside its own ``tf.Session`` on a new
    ``tf.Graph`` and its layers are addressed by their auto-generated names
    (``conv2d``, ``conv2d_1``, ..., ``batch_normalization``, ...).  Each layer
    is paired with the ``conv<N>`` / ``bn<N>`` entries of ``state_dict``.

    Args:
        state_dict: Mapping from parameter name to tensor.  Entries are
            matched by substring against ``conv<N>.weight``, ``conv<N>.bias``
            and ``bn<N>.{weight,bias,running_mean,running_var}``; every value
            must provide ``.numpy()``.

    Returns:
        The list of NumPy arrays returned by ``model.get_weights()`` once all
        convolution and batch-normalization layers have been populated.

    Raises:
        KeyError: If no entry of ``state_dict`` matches a required parameter,
            or an expected Keras layer name is absent from the model.
    """
    print('============================================================')

    # Convert every tensor exactly once.  The lookups below run for ~110
    # layers and must not redo ``.numpy()`` over the whole checkpoint each time.
    numpy_weights = {key: tensor.numpy() for key, tensor in state_dict.items()}

    def find_weight(keyword):
        """Return the array whose key contains ``keyword`` (last match wins)."""
        match = None
        for key, value in numpy_weights.items():
            if keyword in key:
                match = value
        if match is None:
            raise KeyError('no entry matching %r in state_dict' % keyword)
        return match

    def to_keras_kernel(kernel):
        # PyTorch Conv2d kernels are (out, in, kh, kw); Keras Conv2D expects
        # (kh, kw, in, out).
        return kernel.transpose(2, 3, 1, 0)

    def copy1(conv, bn, idx):
        """Copy bias-free conv ``idx`` together with its batch-norm parameters."""
        kernel = find_weight('conv%d.weight' % idx)
        gamma = find_weight('bn%d.weight' % idx)
        beta = find_weight('bn%d.bias' % idx)
        mean = find_weight('bn%d.running_mean' % idx)
        var = find_weight('bn%d.running_var' % idx)
        conv.set_weights([to_keras_kernel(kernel)])
        bn.set_weights([gamma, beta, mean, var])

    def copy2(conv, idx):
        """Copy conv ``idx`` that carries a bias and has no batch norm."""
        kernel = find_weight('conv%d.weight' % idx)
        bias = find_weight('conv%d.bias' % idx)
        conv.set_weights([to_keras_kernel(kernel), bias])

    num_classes = 80
    num_anchors = 3

    with tf.Session(graph=tf.Graph()):
        inputs = layers.Input(shape=[], dtype='string')
        model_body = YOLOv4(inputs, num_classes, num_anchors)
        model_body.summary()
        layers_by_name = {layer.name: layer for layer in model_body.layers}

        def nth(prefix, n):
            # The first auto-named layer is ``prefix``; later ones are ``prefix_<n>``.
            name = prefix if n == 0 else '%s_%d' % (prefix, n)
            return layers_by_name[name]

        print('\nCopying...')
        # Convs 94, 102 and 110 have a bias and no batch norm, so after each of
        # them the batch-norm numbering falls one further behind the conv
        # numbering.  Tuples are (first conv idx, last conv idx, lag).
        for first, last, lag in ((1, 93, 0), (95, 101, 1), (103, 109, 2)):
            for idx in range(first, last + 1):
                copy1(nth('conv2d', idx - 1),
                      nth('batch_normalization', idx - 1 - lag),
                      idx)

        for idx in (94, 102, 110):
            copy2(nth('conv2d', idx - 1), idx)

        # Must be read while the session that owns the variables is active.
        weights = model_body.get_weights()

    print('\nDone.')
    return weights
744
745
class Mish(layers.Layer):
    """Keras layer computing the Mish activation ``x * tanh(softplus(x))``."""

    def __init__(self, **kwargs):
        # Forward the standard Layer keyword arguments (``name``, ``dtype``,
        # ...) so the layer can be named and rebuilt from its config.
        # ``Mish()`` with no arguments behaves as before.
        super(Mish, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged; the activation is element-wise."""
        return input_shape

    def call(self, x):
        """Apply Mish element-wise to ``x``."""
        return x * tf.tanh(tf.math.softplus(x))
756
757
def conv2d_unit(x, filters, kernels, strides=1, padding='valid', bn=1, act='mish'):
    """Apply a Conv2D, then optional batch normalization and activation.

    Args:
        x: Input tensor.
        filters: Number of convolution output filters.
        kernels: Kernel size forwarded to ``layers.Conv2D``.
        strides: Convolution stride.
        padding: Convolution padding mode.
        bn: Batch-norm flag.  A ``BatchNormalization(fused=False)`` layer is
            added when it is truthy; the convolution bias is enabled whenever
            it is not equal to 1.
        act: ``'mish'`` appends a ``Mish`` layer, ``'leaky'`` appends
            ``LeakyReLU(alpha=0.1)``; any other value leaves the output linear.

    Returns:
        The output tensor of the last layer applied.
    """
    convolution = layers.Conv2D(
        filters,
        kernels,
        strides=strides,
        padding=padding,
        use_bias=(bn != 1),
        activation='linear',
        kernel_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.01),
    )
    out = convolution(x)
    if bn:
        out = layers.BatchNormalization(fused=False)(out)

    if act == 'mish':
        return Mish()(out)
    if act == 'leaky':
        return layers.LeakyReLU(alpha=0.1)(out)
    return out
773
774
def residual_block(inputs, filters_1, filters_2):
    """Return ``inputs`` plus a 1x1 -> 3x3 convolution branch.

    Args:
        inputs: Tensor feeding both the branch and the skip connection.
        filters_1: Filter count of the 1x1 convolution.
        filters_2: Filter count of the 3x3 ``'same'``-padded convolution.

    Returns:
        Element-wise sum of ``inputs`` and the branch output.
    """
    squeezed = conv2d_unit(inputs, filters_1, 1, strides=1, padding='valid')
    branch = conv2d_unit(squeezed, filters_2, 3, strides=1, padding='same')
    return layers.Add()([inputs, branch])
780
781
def stack_residual_block(inputs, filters_1, filters_2, n):
    """Chain residual blocks, each feeding the next.

    One block is always created; ``n - 1`` further blocks follow when
    ``n`` is greater than 1.

    Args:
        inputs: Tensor fed to the first block.
        filters_1: ``filters_1`` forwarded to every ``residual_block``.
        filters_2: ``filters_2`` forwarded to every ``residual_block``.
        n: Requested number of blocks.

    Returns:
        Output tensor of the last block.
    """
    out = residual_block(inputs, filters_1, filters_2)
    remaining = n - 1
    while remaining > 0:
        out = residual_block(out, filters_1, filters_2)
        remaining -= 1
    return out
787
788
def spp(x):
    """Concatenate ``x`` with three stride-1 max-pooled copies of itself.

    The pools use window sizes 5, 9 and 13 with ``'same'`` padding and are
    concatenated in the order 13, 9, 5, followed by ``x`` itself.
    """
    # Pools are created in ascending window order, which fixes their
    # auto-generated layer names.
    pooled = [
        layers.MaxPooling2D(pool_size=size, strides=1, padding='same')(x)
        for size in (5, 9, 13)
    ]
    return layers.Concatenate()(pooled[::-1] + [x])
796
797
def YOLOv4(inputs, num_classes, num_anchors, input_shape=(608, 608), initial_filters=32,
           fast=False, anchors=None, conf_thresh=0.05, nms_thresh=0.45, keep_top_k=100, nms_top_k=100):
    """Build the end-to-end YOLOv4 Keras model: decode -> network -> NMS.

    The model takes a batch of encoded image strings, preprocesses them with
    ``preprocessor``, runs the CSPDarknet53 backbone with SPP/PAN neck and
    three detection heads, decodes the head outputs into boxes and runs
    ``tf.image.combined_non_max_suppression`` per image.

    The order of the ``conv2d_unit`` calls below must not change:
    ``convert_pt_checkpoint_to_keras_h5`` addresses layers through their
    auto-generated names (``conv2d_<n>``, ``batch_normalization_<n>``).

    Args:
        inputs: Keras string input tensor holding encoded images.
        num_classes: Number of object classes.
        num_anchors: Number of anchors per detection head.  Each anchor mask
            below selects three anchors, so this is expected to be 3.
        input_shape: ``(height, width)`` the images are resized to.
            NOTE(review): the head reshapes assume both are multiples of 32.
        initial_filters: Filter count of the first convolution; all other
            widths are multiples of it.
        fast: Accepted but not used by this function.
        anchors: Optional list of nine ``[width, height]`` anchors in input
            pixels, ordered small to large.  Defaults to the anchors
            hard-coded below.
        conf_thresh: Score threshold used for pre-filtering and for NMS.
        nms_thresh: IoU threshold for NMS.
        keep_top_k: Accepted but not used by this function.
        nms_top_k: Used as both the per-class and the total NMS output size.

    Returns:
        A ``keras.models.Model`` mapping ``inputs`` to
        ``(boxes, scores, classes)``.
    """
    i32 = initial_filters
    i64 = i32 * 2
    i128 = i32 * 4
    i256 = i32 * 8
    i512 = i32 * 16
    i1024 = i32 * 32

    x, image_shape = layers.Lambda(lambda t: preprocessor(t, input_shape))(inputs)

    # cspdarknet53
    x = conv2d_unit(x, i32, 3, strides=1, padding='same')

    # ============================= s2 =============================
    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(x)
    x = conv2d_unit(x, i64, 3, strides=2)
    s2 = conv2d_unit(x, i64, 1, strides=1)
    x = conv2d_unit(x, i64, 1, strides=1)
    x = stack_residual_block(x, i32, i64, n=1)
    x = conv2d_unit(x, i64, 1, strides=1)
    x = layers.Concatenate()([x, s2])
    s2 = conv2d_unit(x, i64, 1, strides=1)

    # ============================= s4 =============================
    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(s2)
    x = conv2d_unit(x, i128, 3, strides=2)
    s4 = conv2d_unit(x, i64, 1, strides=1)
    x = conv2d_unit(x, i64, 1, strides=1)
    x = stack_residual_block(x, i64, i64, n=2)
    x = conv2d_unit(x, i64, 1, strides=1)
    x = layers.Concatenate()([x, s4])
    s4 = conv2d_unit(x, i128, 1, strides=1)

    # ============================= s8 =============================
    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(s4)
    x = conv2d_unit(x, i256, 3, strides=2)
    s8 = conv2d_unit(x, i128, 1, strides=1)
    x = conv2d_unit(x, i128, 1, strides=1)
    x = stack_residual_block(x, i128, i128, n=8)
    x = conv2d_unit(x, i128, 1, strides=1)
    x = layers.Concatenate()([x, s8])
    s8 = conv2d_unit(x, i256, 1, strides=1)

    # ============================= s16 =============================
    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(s8)
    x = conv2d_unit(x, i512, 3, strides=2)
    s16 = conv2d_unit(x, i256, 1, strides=1)
    x = conv2d_unit(x, i256, 1, strides=1)
    x = stack_residual_block(x, i256, i256, n=8)
    x = conv2d_unit(x, i256, 1, strides=1)
    x = layers.Concatenate()([x, s16])
    s16 = conv2d_unit(x, i512, 1, strides=1)

    # ============================= s32 =============================
    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(s16)
    x = conv2d_unit(x, i1024, 3, strides=2)
    s32 = conv2d_unit(x, i512, 1, strides=1)
    x = conv2d_unit(x, i512, 1, strides=1)
    x = stack_residual_block(x, i512, i512, n=4)
    x = conv2d_unit(x, i512, 1, strides=1)
    x = layers.Concatenate()([x, s32])
    s32 = conv2d_unit(x, i1024, 1, strides=1)

    # fpn
    x = conv2d_unit(s32, i512, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
    x = spp(x)

    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
    fpn_s32 = conv2d_unit(x, i512, 1, strides=1, act='leaky')

    # pan01
    x = conv2d_unit(fpn_s32, i256, 1, strides=1, act='leaky')
    x = layers.UpSampling2D(2)(x)
    s16 = conv2d_unit(s16, i256, 1, strides=1, act='leaky')
    x = layers.Concatenate()([s16, x])
    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
    fpn_s16 = conv2d_unit(x, i256, 1, strides=1, act='leaky')

    # pan02
    x = conv2d_unit(fpn_s16, i128, 1, strides=1, act='leaky')
    x = layers.UpSampling2D(2)(x)
    s8 = conv2d_unit(s8, i128, 1, strides=1, act='leaky')
    x = layers.Concatenate()([s8, x])
    x = conv2d_unit(x, i128, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i256, 3, strides=1, padding='same', act='leaky')
    x = conv2d_unit(x, i128, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i256, 3, strides=1, padding='same', act='leaky')
    x = conv2d_unit(x, i128, 1, strides=1, act='leaky')

    # output_s, doesn't need concat()
    output_s = conv2d_unit(x, i256, 3, strides=1, padding='same', act='leaky')
    output_s = conv2d_unit(output_s, num_anchors * (num_classes + 5), 1, strides=1, bn=0, act=None)

    # output_m, need concat()
    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(x)
    x = conv2d_unit(x, i256, 3, strides=2, act='leaky')
    x = layers.Concatenate()([x, fpn_s16])
    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
    x = conv2d_unit(x, i256, 1, strides=1, act='leaky')
    output_m = conv2d_unit(x, i512, 3, strides=1, padding='same', act='leaky')
    output_m = conv2d_unit(output_m, num_anchors * (num_classes + 5), 1, strides=1, bn=0, act=None)

    # output_l, need concat()
    x = layers.ZeroPadding2D(padding=((1, 0), (1, 0)))(x)
    x = conv2d_unit(x, i512, 3, strides=2, act='leaky')
    x = layers.Concatenate()([x, fpn_s32])
    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
    x = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
    x = conv2d_unit(x, i512, 1, strides=1, act='leaky')
    output_l = conv2d_unit(x, i1024, 3, strides=1, padding='same', act='leaky')
    output_l = conv2d_unit(output_l, num_anchors * (num_classes + 5), 1, strides=1, bn=0, act=None)

    def cast_float32(tensor):
        return tf.cast(tensor, tf.float32)

    output_l = layers.Lambda(cast_float32)(output_l)
    output_m = layers.Lambda(cast_float32)(output_m)
    output_s = layers.Lambda(cast_float32)(output_s)

    # originally reshape in multi_thread_post
    # The anchor axis uses ``num_anchors`` so it always matches the head's
    # ``num_anchors * (num_classes + 5)`` output channels.
    output_lr = layers.Reshape(
        (1, input_shape[0] // 32, input_shape[1] // 32, num_anchors, 5 + num_classes))(output_l)
    output_mr = layers.Reshape(
        (1, input_shape[0] // 16, input_shape[1] // 16, num_anchors, 5 + num_classes))(output_m)
    output_sr = layers.Reshape(
        (1, input_shape[0] // 8, input_shape[1] // 8, num_anchors, 5 + num_classes))(output_s)

    # originally _yolo_out
    masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    if anchors is None:
        # Default anchors; a caller-supplied list is no longer overwritten.
        anchors = [[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
                   [72, 146], [142, 110], [192, 243], [459, 401]]

    def batch_process_feats(out, anchors, mask):
        """Decode one head's raw output into normalized corner boxes and scores."""
        grid_h, grid_w, boxes_per_cell = map(int, out.shape[2:5])

        anchors = [anchors[i] for i in mask]
        anchors_tensor = np.array(anchors).reshape(1, 1, len(anchors), 2)

        # Reshape to batch, height, width, num_anchors, box_params.
        box_xy = tf.sigmoid(out[..., :2])
        box_wh = tf.exp(out[..., 2:4])
        box_wh = box_wh * anchors_tensor

        box_confidence = tf.sigmoid(out[..., 4])
        box_confidence = tf.expand_dims(box_confidence, axis=-1)
        box_class_probs = tf.sigmoid(out[..., 5:])

        # col[i, j] == j and row[i, j] == i, both of shape (grid_h, grid_w).
        # The tile counts are the *other* dimension so that non-square grids
        # work as well.
        col = np.tile(np.arange(0, grid_w), grid_h).reshape(grid_h, grid_w)
        row = np.tile(np.arange(0, grid_h).reshape(-1, 1), grid_w)

        col = col.reshape(grid_h, grid_w, 1, 1).repeat(boxes_per_cell, axis=-2)
        row = row.reshape(grid_h, grid_w, 1, 1).repeat(boxes_per_cell, axis=-2)
        grid = np.concatenate((col, row), axis=-1).astype(np.float32)

        box_xy += grid
        box_xy /= (grid_w, grid_h)
        # box_wh is (width, height) while input_shape is (height, width).
        box_wh /= (input_shape[1], input_shape[0])
        box_xy -= (box_wh / 2.)  # normalized xywh
        boxes = tf.concat((box_xy, box_xy + box_wh), axis=-1)

        box_scores = box_confidence * box_class_probs
        num_boxes = np.prod(boxes.shape[1:-1])
        boxes = tf.reshape(boxes, [-1, num_boxes, boxes.shape[-1]])
        box_scores = tf.reshape(box_scores, [-1, num_boxes, box_scores.shape[-1]])
        return boxes, box_scores

    def filter_boxes(outputs):
        """Threshold, rescale and run NMS on the boxes of a single image."""
        boxes_l, boxes_m, boxes_s, box_scores_l, box_scores_m, box_scores_s, image_shape = outputs
        boxes_l, box_scores_l = filter_boxes_one_size(boxes_l, box_scores_l)
        boxes_m, box_scores_m = filter_boxes_one_size(boxes_m, box_scores_m)
        boxes_s, box_scores_s = filter_boxes_one_size(boxes_s, box_scores_s)
        boxes = tf.concat([boxes_l, boxes_m, boxes_s], axis=0)
        box_scores = tf.concat([box_scores_l, box_scores_m, box_scores_s], axis=0)
        # Takes the first two entries of image_shape in reverse order;
        # assumes image_shape is (height, width, channels) -- TODO confirm.
        image_shape_wh = image_shape[1::-1]
        image_shape_whwh = tf.concat([image_shape_wh, image_shape_wh], axis=-1)
        image_shape_whwh = tf.cast(image_shape_whwh, tf.float32)
        boxes *= image_shape_whwh
        # combined_non_max_suppression expects a batch axis and a per-class
        # box axis; a size-1 class axis shares each box across all classes.
        boxes = tf.expand_dims(boxes, 0)
        box_scores = tf.expand_dims(box_scores, 0)
        boxes = tf.expand_dims(boxes, 2)
        nms_boxes, nms_scores, nms_classes, valid_detections = tf.image.combined_non_max_suppression(
            boxes,
            box_scores,
            max_output_size_per_class=nms_top_k,
            max_total_size=nms_top_k,
            iou_threshold=nms_thresh,
            score_threshold=conf_thresh,
            pad_per_class=False,
            clip_boxes=False,
            name='CombinedNonMaxSuppression',
        )
        return nms_boxes[0], nms_scores[0], nms_classes[0]

    def filter_boxes_one_size(boxes, box_scores):
        """Keep only boxes whose best class score exceeds ``conf_thresh``."""
        box_class_scores = tf.reduce_max(box_scores, axis=-1)
        keep = box_class_scores > conf_thresh
        boxes = boxes[keep]
        box_scores = box_scores[keep]
        return boxes, box_scores

    def batch_yolo_out(outputs):
        """Decode all three heads and apply ``filter_boxes`` to every image."""
        with tf.name_scope('yolo_out'):
            b_output_lr, b_output_mr, b_output_sr, b_image_shape = outputs
            with tf.name_scope('process_feats'):
                b_boxes_l, b_box_scores_l = batch_process_feats(b_output_lr, anchors, masks[0])
            with tf.name_scope('process_feats'):
                b_boxes_m, b_box_scores_m = batch_process_feats(b_output_mr, anchors, masks[1])
            with tf.name_scope('process_feats'):
                b_boxes_s, b_box_scores_s = batch_process_feats(b_output_sr, anchors, masks[2])
            with tf.name_scope('filter_boxes'):
                b_nms_boxes, b_nms_scores, b_nms_classes = tf.map_fn(
                    filter_boxes,
                    [b_boxes_l, b_boxes_m, b_boxes_s,
                     b_box_scores_l, b_box_scores_m, b_box_scores_s, b_image_shape],
                    dtype=(tf.float32, tf.float32, tf.float32), back_prop=False, parallel_iterations=16)
        return b_nms_boxes, b_nms_scores, b_nms_classes

    boxes_scores_classes = layers.Lambda(batch_yolo_out)([output_lr, output_mr, output_sr, image_shape])

    model_body = keras.models.Model(inputs=inputs, outputs=boxes_scores_classes)
    return model_body
1026
1027
def decode_jpeg_resize(input_tensor, image_size):
    """Decode one encoded image, resize it and divide pixel values by 255.

    NOTE(review): despite the function name, decoding goes through
    ``tf.image.decode_png`` -- confirm this is intended.

    Args:
        input_tensor: Scalar string tensor holding the encoded image.
        image_size: Target size forwarded to ``tf.image.resize``.

    Returns:
        A pair ``(image, shape)``: the resized image cast to ``float16``, and
        ``tf.shape`` of the decoded image taken before resizing.
    """
    decoded = tf.image.decode_png(input_tensor, channels=3)
    original_shape = tf.shape(decoded)
    resized = tf.image.resize(tf.cast(decoded, tf.float32), image_size)
    scaled = resized / 255.0
    return tf.cast(scaled, tf.float16), original_shape
1035
1036
def preprocessor(input_tensor, image_size):
    """Run ``decode_jpeg_resize`` over every element of a string batch.

    Args:
        input_tensor: Batch of encoded image strings.
        image_size: Target size forwarded to ``decode_jpeg_resize``.

    Returns:
        The ``tf.map_fn`` result: a ``(float16 images, int32 shapes)`` pair.
    """
    def decode_one(encoded):
        return decode_jpeg_resize(encoded, image_size=image_size)

    with tf.name_scope('Preprocessor'):
        return tf.map_fn(
            decode_one,
            input_tensor,
            dtype=(tf.float16, tf.int32),
            back_prop=False,
            parallel_iterations=16,
        )
1043
1044
def main():
    """Download the YOLOv4 checkpoint, convert it and export a SavedModel.

    Steps: fetch ``yolov4.pth`` with the AWS CLI, rename and convert the
    PyTorch weights to Keras weights, rebuild the Keras model in inference
    mode, load the weights and write ``./yolo_v4_coco_saved_model``.

    Raises:
        RuntimeError: If the download command fails and no local
            ``./yolov4.pth`` is available to fall back on.
    """
    checkpoint_path = './yolov4.pth'
    status = os.system('aws s3 cp s3://neuron-s3/training_checkpoints/pytorch/yolov4/yolov4.pth . --no-sign-request')
    # A failed download is tolerated only if a checkpoint is already present;
    # otherwise fail here with a clear message rather than inside torch.load.
    if status != 0 and not os.path.exists(checkpoint_path):
        raise RuntimeError(
            'downloading yolov4.pth failed (exit status %r) and %s does not exist'
            % (status, checkpoint_path))

    torch_weights = rename_weights(checkpoint_path)
    keras_weights = convert_pt_checkpoint_to_keras_h5(torch_weights)

    # Build the export graph in inference mode.
    keras.backend.set_learning_phase(0)
    num_anchors = 3
    num_classes = 80
    input_shape = (608, 608)
    conf_thresh = 0.001
    nms_thresh = 0.45
    image_input = layers.Input(shape=[], dtype='string')
    yolo = YOLOv4(image_input, num_classes, num_anchors, input_shape,
                  conf_thresh=conf_thresh, nms_thresh=nms_thresh)
    yolo.set_weights(keras_weights)

    sess = keras.backend.get_session()
    signature_inputs = {'image': yolo.inputs[0]}
    output_names = ['boxes', 'scores', 'classes']
    signature_outputs = dict(zip(output_names, yolo.outputs))
    tf.saved_model.simple_save(sess, './yolo_v4_coco_saved_model', signature_inputs, signature_outputs)
1063
1064
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
This document is relevant for: Inf1