Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature]: support dynamic shape for grid_anchor #4684

Merged
merged 3 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions mmdet/core/anchor/anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,9 @@ def _meshgrid(self, x, y, row_major=True):
Returns:
tuple[torch.Tensor]: The mesh grids of x and y.
"""
xx = x.repeat(len(y))
yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
# use shape instead of len to keep tracing while exporting to onnx
xx = x.repeat(y.shape[0])
yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
if row_major:
return xx, yy
else:
Expand Down Expand Up @@ -250,10 +251,8 @@ def single_level_grid_anchors(self,
Returns:
torch.Tensor: Anchors in the overall feature maps.
"""
# keep as Tensor, so that we can covert to ONNX correctly
feat_h, feat_w = featmap_size
# convert Tensor to int, so that we can covert to ONNX correctlly
feat_h = int(feat_h)
feat_w = int(feat_w)
shift_x = torch.arange(0, feat_w, device=device) * stride[0]
shift_y = torch.arange(0, feat_h, device=device) * stride[1]

Expand Down
15 changes: 13 additions & 2 deletions mmdet/models/dense_heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,11 @@ def _get_bboxes_single(self,
"""
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
# convert to tensor to keep tracing
nms_pre_tensor = torch.tensor(
cfg.get('nms_pre', -1),
device=cls_score_list[0].device,
dtype=torch.long)
mlvl_bboxes = []
mlvl_scores = []
for cls_score, bbox_pred, anchors in zip(cls_score_list,
Expand All @@ -632,8 +637,14 @@ def _get_bboxes_single(self,
else:
scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
# Always keep topk op for dynamic input in onnx
if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
or scores.shape[-2] > nms_pre_tensor):
from torch import _shape_as_tensor
# keep shape as tensor and get k
num_anchor = _shape_as_tensor(scores)[-2]
nms_pre = torch.where(nms_pre_tensor < num_anchor,
nms_pre_tensor, num_anchor)
# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1)
Expand Down