From 13101b72983b3aa533355ffb564fcb3164dec118 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 21 Nov 2022 19:09:51 +0800 Subject: [PATCH 1/2] fix roi align symbolic for torch>1.13 --- mmcv/ops/roi_align.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/mmcv/ops/roi_align.py b/mmcv/ops/roi_align.py index 839843c7f0..e0fa374e45 100644 --- a/mmcv/ops/roi_align.py +++ b/mmcv/ops/roi_align.py @@ -20,16 +20,25 @@ class RoIAlignFunction(Function): def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio, pool_mode, aligned): from torch.onnx import TensorProtoDataType - from torch.onnx.symbolic_helper import _slice_helper - from torch.onnx.symbolic_opset9 import squeeze, sub + from torch.onnx.symbolic_opset9 import sub + + def _select(g, self, dim, index): + return g.op("Gather", self, index, axis_i=dim) # batch_indices = rois[:, 0].long() - batch_indices = _slice_helper(g, rois, axes=[1], starts=[0], ends=[1]) - batch_indices = squeeze(g, batch_indices, 1) + batch_indices = _select( + g, rois, 1, + g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))) + batch_indices = g.op('Squeeze', batch_indices, axes_i=[1]) batch_indices = g.op( 'Cast', batch_indices, to_i=TensorProtoDataType.INT64) # rois = rois[:, 1:] - rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5]) + rois = _select( + g, rois, 1, + g.op( + "Constant", + value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) + if aligned: # rois -= 0.5/spatial_scale aligned_offset = g.op( From f090bc8fbb4bf0b0ac987ed91b09597a6ee00894 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 22 Nov 2022 10:19:22 +0800 Subject: [PATCH 2/2] fix lint --- mmcv/ops/roi_align.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mmcv/ops/roi_align.py b/mmcv/ops/roi_align.py index e0fa374e45..de2bed204d 100644 --- a/mmcv/ops/roi_align.py +++ b/mmcv/ops/roi_align.py @@ -23,12 +23,12 @@ def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio, from torch.onnx.symbolic_opset9 import sub def _select(g, self, dim, index): - return g.op("Gather", self, index, axis_i=dim) + return g.op('Gather', self, index, axis_i=dim) # batch_indices = rois[:, 0].long() batch_indices = _select( g, rois, 1, - g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))) + g.op('Constant', value_t=torch.tensor([0], dtype=torch.long))) batch_indices = g.op('Squeeze', batch_indices, axes_i=[1]) batch_indices = g.op( 'Cast', batch_indices, to_i=TensorProtoDataType.INT64) @@ -36,7 +36,7 @@ def _select(g, self, dim, index): rois = _select( g, rois, 1, g.op( - "Constant", + 'Constant', value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) if aligned: