Skip to content

Commit

Permalink
[ONNX] Fix dtype for NonMaxSuppression (#7056)
Browse files Browse the repository at this point in the history
Co-authored-by: Nikita Shulga <nshulga@fb.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
  • Loading branch information
3 people authored Feb 15, 2023
1 parent f9d1883 commit f627b9d
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions torchvision/ops/_register_onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,25 @@
def _register_custom_op():
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
from torch.onnx.symbolic_opset9 import _cast_Long

@parse_args("v", "v", "f")
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
boxes = unsqueeze(g, boxes, 0)
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
nms_out = g.op(
"NonMaxSuppression",
g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
max_output_per_class,
iou_threshold,
)
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)

def _process_batch_indices_for_roi_align(g, rois):
return _cast_Long(
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
)
indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1)
return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)

def _process_rois_for_roi_align(g, rois):
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
Expand Down

0 comments on commit f627b9d

Please sign in to comment.