Skip to content

Commit f612182

Browse files
lara-hdr authored and fmassa committed
Support Exporting Mask Rcnn to ONNX (#1461)
* Support Exporting Mask Rcnn to ONNX * update tetst * add control flow test * fix * update test and fix img_shape
1 parent 30cb4e1 commit f612182

File tree

2 files changed

+129
-4
lines changed

2 files changed

+129
-4
lines changed

test/test_onnx.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
99
from torchvision.models.detection.roi_heads import RoIHeads
1010
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
11+
from torchvision.models.detection.mask_rcnn import MaskRCNNHeads, MaskRCNNPredictor
1112

1213
from collections import OrderedDict
1314

@@ -259,7 +260,7 @@ def forward(self_module, features):
259260
model = RoiHeadsModule(images)
260261
model.eval()
261262
model(features)
262-
self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
263+
self.run_model(model, [(features,), (test_features,)])
263264

264265
def get_image_from_url(self, url):
265266
import requests
@@ -294,6 +295,45 @@ def test_faster_rcnn(self):
294295
model(images)
295296
self.run_model(model, [(images,), (test_images,)])
296297

298+
# Verify that paste_masks_in_image behaves the same in tracing.
299+
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image_loop
300+
# (since jit_trace will call _onnx_paste_masks_in_image_loop).
301+
def test_paste_mask_in_image(self):
    """Check that tracing ``paste_masks_in_image`` does not change its output.

    The function is first run eagerly and then through ``torch.jit.trace``
    on the same inputs. The traced callable is then re-used on inputs with
    a different number of masks and a different image size, and its output
    is again compared with the eager result.
    """
    from torchvision.models.detection.roi_heads import paste_masks_in_image

    def shape_as_tensors(shape):
        # The traced callable is fed the image shape as a list of two
        # 0-dim tensors instead of a tuple of Python ints.
        return [torch.tensor(shape[0]), torch.tensor(shape[1])]

    # First input set: 10 masks pasted into a 100x100 image.
    masks = torch.rand(10, 1, 26, 26)
    boxes = torch.rand(10, 4)
    # Add a non-negative offset to columns 2 and 3 so they are >= columns 0 and 1.
    boxes[:, 2:] += torch.rand(10, 2)
    boxes *= 50
    o_im_s = (100, 100)

    out = paste_masks_in_image(masks, boxes, o_im_s)
    jit_trace = torch.jit.trace(paste_masks_in_image,
                                (masks, boxes, shape_as_tensors(o_im_s)))
    out_trace = jit_trace(masks, boxes, shape_as_tensors(o_im_s))

    self.assertTrue(torch.all(out.eq(out_trace)))

    # Second input set: 20 masks pasted into a 200x200 image, run through
    # the SAME trace to make sure nothing from the first call was baked in.
    masks2 = torch.rand(20, 1, 26, 26)
    boxes2 = torch.rand(20, 4)
    boxes2[:, 2:] += torch.rand(20, 2)
    boxes2 *= 100
    o_im_s2 = (200, 200)

    out2 = paste_masks_in_image(masks2, boxes2, o_im_s2)
    out_trace2 = jit_trace(masks2, boxes2, shape_as_tensors(o_im_s2))

    self.assertTrue(torch.all(out2.eq(out_trace2)))
327+
328+
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
def test_mask_rcnn(self):
    """Run the pretrained Mask R-CNN model through ``self.run_model``.

    Currently skipped; see the reason given in the ``unittest.skip`` decorator.
    """
    images, test_images = self.get_test_images()

    model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    # Run the model once eagerly before handing it to run_model.
    model(images)
    self.run_model(model, [(images,), (test_images,)])
336+
297337

298338
if __name__ == '__main__':
299339
unittest.main()

torchvision/models/detection/roi_heads.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torchvision
23

34
import torch.nn.functional as F
45
from torch import nn
@@ -73,7 +74,11 @@ def maskrcnn_inference(x, labels):
7374
index = torch.arange(num_masks, device=labels.device)
7475
mask_prob = mask_prob[index, labels][:, None]
7576

76-
mask_prob = mask_prob.split(boxes_per_image, dim=0)
77+
if len(boxes_per_image) == 1:
78+
# TODO : remove when dynamic split supported in ONNX
79+
mask_prob = (mask_prob,)
80+
else:
81+
mask_prob = mask_prob.split(boxes_per_image, dim=0)
7782

7883
return mask_prob
7984

@@ -250,10 +255,29 @@ def keypointrcnn_inference(x, boxes):
250255
return kp_probs, kp_scores
251256

252257

258+
def _onnx_expand_boxes(boxes, scale):
    """Scale every box about its centre; used by ``expand_boxes`` when tracing.

    Args:
        boxes: 2-D tensor whose columns 0..3 are read as ``(x0, y0, x1, y1)``.
        scale: factor applied to each box's half-width and half-height.

    Returns:
        Tensor with one row per input box and four columns holding the
        expanded ``(x0, y0, x1, y1)``.
    """
    x0 = boxes[:, 0]
    y0 = boxes[:, 1]
    x1 = boxes[:, 2]
    y1 = boxes[:, 3]

    # Box centres stay fixed; only the half-extents are scaled.
    x_c = (x1 + x0) * .5
    y_c = (y1 + y0) * .5

    # Half-extents are cast to float32 before being multiplied by ``scale``.
    w_half = ((x1 - x0) * .5).to(dtype=torch.float32) * scale
    h_half = ((y1 - y0) * .5).to(dtype=torch.float32) * scale

    # Assemble the four coordinates with torch.stack along dim 1.
    return torch.stack((x_c - w_half, y_c - h_half, x_c + w_half, y_c + h_half), 1)
273+
274+
253275
# the next two functions should be merged inside Masker
254276
# but are kept here for the moment while we need them
255277
# temporarily for paste_mask_in_image
256278
def expand_boxes(boxes, scale):
279+
if torchvision._is_tracing():
280+
return _onnx_expand_boxes(boxes, scale)
257281
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
258282
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
259283
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
@@ -272,7 +296,10 @@ def expand_boxes(boxes, scale):
272296

273297
def expand_masks(mask, padding):
274298
M = mask.shape[-1]
275-
scale = float(M + 2 * padding) / M
299+
if torchvision._is_tracing():
300+
scale = (M + 2 * padding).to(torch.float32) / M.to(torch.float32)
301+
else:
302+
scale = float(M + 2 * padding) / M
276303
padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)
277304
return padded_mask, scale
278305

@@ -303,11 +330,69 @@ def paste_mask_in_image(mask, box, im_h, im_w):
303330
return im_mask
304331

305332

333+
def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    """Paste a single mask into an ``im_h`` x ``im_w`` canvas with tensor ops only.

    The mask is resized to the extent of ``box``, cropped to the part of the
    box that lies inside the image, and then zero-padded to the full image
    size by concatenating zero blocks around it.

    Args:
        mask: 2-D tensor (its ``size(0)`` and ``size(1)`` are used).
        box: 1-D integer tensor read as ``(x0, y0, x1, y1)``; both corners are
            treated as inclusive, so the extent is ``x1 - x0 + 1``.
        im_h: 0-dim integer tensor, height of the output.
        im_w: 0-dim integer tensor, width of the output.

    Returns:
        2-D tensor of size ``im_h`` x ``im_w`` holding the resized mask inside
        the box area and zeros elsewhere.

    NOTE(review): the helper tensors below are created without an explicit
    device, so this presumably expects CPU inputs -- confirm before using it
    outside the export path.
    """
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    # Box extent with inclusive corners, clamped so it is at least 1.
    w = (box[2] - box[0] + one)
    h = (box[3] - box[1] + one)
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batch x C x H x W], as required by interpolate
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize mask to the box extent
    mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
    mask = mask[0][0]

    # Intersection of the box with the image: [x_0, x_1) x [y_0, y_1).
    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    # Crop the resized mask to that intersection (indices relative to the box).
    unpadded_im_mask = mask[(y_0 - box[1]):(y_1 - box[1]),
                            (x_0 - box[0]):(x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX

    # pad y: zero rows above and below the cropped mask
    zeros_y0 = torch.zeros(y_0, unpadded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpadded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0,
                          unpadded_im_mask.to(dtype=torch.float32),
                          zeros_y1), 0)[0:im_h, :]
    # pad x: zero columns left and right of the row-padded mask
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0,
                         concat_0,
                         zeros_x1), 1)[:, :im_w]
    return im_mask
372+
373+
374+
@torch.jit.script
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    """Paste every mask into its own ``im_h`` x ``im_w`` canvas and stack them.

    Compiled with ``torch.jit.script``; presumably so that the loop over
    masks stays a loop when the caller is traced -- confirm.

    Args:
        masks: tensor indexed as ``masks[i][0]`` to obtain the i-th 2-D mask.
        boxes: tensor whose i-th row is the box for the i-th mask.
        im_h: 0-dim integer tensor, height of each output canvas.
        im_w: 0-dim integer tensor, width of each output canvas.

    Returns:
        Tensor with one ``im_h`` x ``im_w`` slice per mask along dim 0
        (an empty first dimension when there are no masks).
    """
    # Start from an empty stack and grow it by concatenation on every step.
    res_append = torch.zeros(0, im_h, im_w)
    for i in range(masks.size(0)):
        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
        # Add a leading dimension so the result can be concatenated along dim 0.
        mask_res = mask_res.unsqueeze(0)
        res_append = torch.cat((res_append, mask_res))
    return res_append
382+
383+
306384
def paste_masks_in_image(masks, boxes, img_shape, padding=1):
307385
masks, scale = expand_masks(masks, padding=padding)
308-
boxes = expand_boxes(boxes, scale).to(dtype=torch.int64).tolist()
386+
boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
309387
# im_h, im_w = img_shape.tolist()
310388
im_h, im_w = img_shape
389+
390+
if torchvision._is_tracing():
391+
return _onnx_paste_masks_in_image_loop(masks, boxes,
392+
torch.scalar_tensor(im_h, dtype=torch.int64),
393+
torch.scalar_tensor(im_w, dtype=torch.int64))[:, None]
394+
395+
boxes = boxes.tolist()
311396
res = [
312397
paste_mask_in_image(m[0], b, im_h, im_w)
313398
for m, b in zip(masks, boxes)

0 commit comments

Comments
 (0)