Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RPN show_result error #4716

Merged
merged 3 commits into from
Mar 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions mmdet/models/dense_heads/embedding_rpn_head.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import mmcv
import torch
import torch.nn as nn
from mmcv import tensor2imgs

from mmdet.models.builder import HEADS
from ...core import bbox_cxcywh_to_xyxy
Expand Down Expand Up @@ -100,21 +98,3 @@ def forward_train(self, img, img_metas):
def simple_test_rpn(self, img, img_metas):
"""Forward function in testing stage."""
return self._decode_init_proposals(img, img_metas)

def show_result(self, data):
"""Show the init proposals in EmbeddingRPN.

Args:
data (dict): Dict contains image and
corresponding meta information.
"""
img_tensor = data['img'][0]
img_metas = data['img_metas'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
proposals, _ = self._decode_init_proposals(data['img'],
data['img_metas'])
assert len(imgs) == len(img_metas)
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
mmcv.imshow_bboxes(img_show, proposals)
11 changes: 10 additions & 1 deletion mmdet/models/detectors/cascade_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,16 @@ def __init__(self,
pretrained=pretrained)

def show_result(self, data, result, **kwargs):
"""Show prediction results of the detector."""
"""Show prediction results of the detector.

Args:
data (str or np.ndarray): Image filename or loaded image.
result (Tensor or tuple): The results to draw over `img`
bbox_result or (bbox_result, segm_result).

Returns:
np.ndarray: The image with bboxes drawn on it.
"""
if self.with_mask:
ms_bbox_result, ms_segm_result = result
if isinstance(ms_bbox_result, dict):
Expand Down
22 changes: 11 additions & 11 deletions mmdet/models/detectors/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,17 @@ def aug_test(self, imgs, img_metas, rescale=False):
flip_direction)
return [proposal.cpu().numpy() for proposal in proposal_list]

def show_result(self, data, result, dataset=None, top_k=20):
def show_result(self, data, result, top_k=20, **kwargs):
"""Show RPN proposals on the image.

Although we assume batch size is 1, this method supports arbitrary
batch size.
Args:
data (str or np.ndarray): Image filename or loaded image.
result (Tensor or tuple): The results to draw over `img`
bbox_result or (bbox_result, segm_result).
top_k (int): Plot the first k bboxes only
if set positive. Default: 20

Returns:
np.ndarray: The image with bboxes drawn on it.
"""
img_tensor = data['img'][0]
img_metas = data['img_metas'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
assert len(imgs) == len(img_metas)
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
mmcv.imshow_bboxes(img_show, result, top_k=top_k)
mmcv.imshow_bboxes(data, result, top_k=top_k)