@@ -1,4 +1,5 @@
 import torch
+import torchvision

 import torch.nn.functional as F
 from torch import nn
@@ -73,7 +74,11 @@ def maskrcnn_inference(x, labels):
     index = torch.arange(num_masks, device=labels.device)
     mask_prob = mask_prob[index, labels][:, None]

-    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+    if len(boxes_per_image) == 1:
+        # TODO : remove when dynamic split supported in ONNX
+        mask_prob = (mask_prob,)
+    else:
+        mask_prob = mask_prob.split(boxes_per_image, dim=0)

     return mask_prob

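The single-image branch above is safe because splitting along dim 0 with a one-element size list simply returns a one-element tuple holding the whole tensor. A quick eager-mode check with made-up shapes (not part of the diff):

```python
import torch

# Hypothetical output: 5 detections from a single image, each with a 1x28x28 mask map.
mask_prob = torch.rand(5, 1, 28, 28)
boxes_per_image = [5]

# The tuple wrapper used on the traced path matches what split() returns in eager mode.
assert torch.equal((mask_prob,)[0], mask_prob.split(boxes_per_image, dim=0)[0])
```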
@@ -250,10 +255,29 @@ def keypointrcnn_inference(x, boxes):
     return kp_probs, kp_scores


+def _onnx_expand_boxes(boxes, scale):
+    w_half = (boxes[:, 2] - boxes[:, 0]) * .5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * .5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * .5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * .5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
 # the next two functions should be merged inside Masker
 # but are kept here for the moment while we need them
 # temporarily for paste_mask_in_image
 def expand_boxes(boxes, scale):
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
     w_half = (boxes[:, 2] - boxes[:, 0]) * .5
     h_half = (boxes[:, 3] - boxes[:, 1]) * .5
     x_c = (boxes[:, 2] + boxes[:, 0]) * .5
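For context, `_onnx_expand_boxes` builds the scaled boxes with `torch.stack` rather than filling a preallocated tensor column by column as the eager `expand_boxes` body does (the remainder of that function falls outside this hunk), so no in-place writes are recorded in the traced graph. A small eager-mode equivalence check with a made-up box:

```python
import torch

# One hypothetical box in (x1, y1, x2, y2) form and an arbitrary scale factor.
boxes = torch.tensor([[10., 20., 50., 80.]])
scale = 1.5

w_half = (boxes[:, 2] - boxes[:, 0]) * .5 * scale
h_half = (boxes[:, 3] - boxes[:, 1]) * .5 * scale
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5

# Traced style: stack the four coordinate columns into a new tensor.
stacked = torch.stack((x_c - w_half, y_c - h_half, x_c + w_half, y_c + h_half), 1)

# Eager style: write each column into a preallocated tensor.
filled = torch.zeros_like(boxes)
filled[:, 0] = x_c - w_half
filled[:, 1] = y_c - h_half
filled[:, 2] = x_c + w_half
filled[:, 3] = y_c + h_half

assert torch.equal(stacked, filled)
```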
@@ -272,7 +296,10 @@ def expand_boxes(boxes, scale):

 def expand_masks(mask, padding):
     M = mask.shape[-1]
-    scale = float(M + 2 * padding) / M
+    if torchvision._is_tracing():
+        scale = (M + 2 * padding).to(torch.float32) / M.to(torch.float32)
+    else:
+        scale = float(M + 2 * padding) / M
     padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)
     return padded_mask, scale

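In eager mode `M` is a plain Python int, so `float(M + 2 * padding) / M` works; on the traced path the size participates in the graph as a tensor, which is why the new branch keeps the division in tensor form instead of freezing it into a constant. An eager-mode illustration of what `expand_masks` computes, using a made-up mask size:

```python
import torch

# Hypothetical 28x28 mask with padding=1: the padded mask becomes 30x30 and the
# boxes must later be enlarged by the matching ratio 30/28.
mask = torch.rand(1, 1, 28, 28)
padding = 1

M = mask.shape[-1]
scale = float(M + 2 * padding) / M
padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)

print(padded_mask.shape, scale)  # torch.Size([1, 1, 30, 30]) 1.0714285714285714
```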
@@ -303,11 +330,69 @@ def paste_mask_in_image(mask, box, im_h, im_w):
     return im_mask


+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = (box[2] - box[0] + one)
+    h = (box[3] - box[1] + one)
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]):(y_1 - box[1]),
+                           (x_0 - box[0]):(x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0,
+                          unpaded_im_mask.to(dtype=torch.float32),
+                          zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0,
+                         concat_0,
+                         zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit.script
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
 def paste_masks_in_image(masks, boxes, img_shape, padding=1):
     masks, scale = expand_masks(masks, padding=padding)
-    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64).tolist()
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
     # im_h, im_w = img_shape.tolist()
     im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(masks, boxes,
+                                               torch.scalar_tensor(im_h, dtype=torch.int64),
+                                               torch.scalar_tensor(im_w, dtype=torch.int64))[:, None]
+
+    boxes = boxes.tolist()
     res = [
         paste_mask_in_image(m[0], b, im_h, im_w)
         for m, b in zip(masks, boxes)
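Because dynamic padding was not available for export at the time (see the TODO in `_onnx_paste_mask_in_image`), the resized mask is placed onto the full-image canvas by concatenating zero blocks around it. A toy eager-mode check, with made-up sizes, that this concatenation trick matches ordinary padding:

```python
import torch

# Hypothetical placement: a (y_1 - y_0) x (x_1 - x_0) mask pasted into an
# im_h x im_w image with its top-left corner at (y_0, x_0).
im_h, im_w = 6, 7
y_0, y_1 = 2, 5
x_0, x_1 = 1, 4
mask = torch.rand(y_1 - y_0, x_1 - x_0)

# Pad rows, then columns, by concatenating zero blocks (the export-friendly way).
rows = torch.cat((torch.zeros(y_0, mask.size(1)),
                  mask,
                  torch.zeros(im_h - y_1, mask.size(1))), 0)
canvas = torch.cat((torch.zeros(rows.size(0), x_0),
                    rows,
                    torch.zeros(rows.size(0), im_w - x_1)), 1)

# Same placement expressed with regular padding.
expected = torch.nn.functional.pad(mask, (x_0, im_w - x_1, y_0, im_h - y_1))
assert torch.equal(canvas, expected)
```

Note that all of the `_onnx_*` branches are taken only while `torchvision._is_tracing()` is true, i.e. during tracing for ONNX export; regular eager inference keeps the original code paths.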