Skip to content

Commit

Permalink
Fix RPN show_result error (open-mmlab#4716)
Browse files Browse the repository at this point in the history
* Fix RPN show_result error

* Remove EmbeddingRPN show_result

* Add docstr
  • Loading branch information
hhaAndroid authored Mar 12, 2021
1 parent dbc6b67 commit 895a8d5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 32 deletions.
20 changes: 0 additions & 20 deletions mmdet/models/dense_heads/embedding_rpn_head.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import mmcv
import torch
import torch.nn as nn
from mmcv import tensor2imgs

from mmdet.models.builder import HEADS
from ...core import bbox_cxcywh_to_xyxy
Expand Down Expand Up @@ -100,21 +98,3 @@ def forward_train(self, img, img_metas):
def simple_test_rpn(self, img, img_metas):
"""Forward function in testing stage."""
return self._decode_init_proposals(img, img_metas)

def show_result(self, data):
"""Show the init proposals in EmbeddingRPN.
Args:
data (dict): Dict contains image and
corresponding meta information.
"""
img_tensor = data['img'][0]
img_metas = data['img_metas'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
proposals, _ = self._decode_init_proposals(data['img'],
data['img_metas'])
assert len(imgs) == len(img_metas)
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
mmcv.imshow_bboxes(img_show, proposals)
11 changes: 10 additions & 1 deletion mmdet/models/detectors/cascade_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,16 @@ def __init__(self,
pretrained=pretrained)

def show_result(self, data, result, **kwargs):
"""Show prediction results of the detector."""
"""Show prediction results of the detector.
Args:
data (str or np.ndarray): Image filename or loaded image.
result (Tensor or tuple): The results to draw over `img`
bbox_result or (bbox_result, segm_result).
Returns:
np.ndarray: The image with bboxes drawn on it.
"""
if self.with_mask:
ms_bbox_result, ms_segm_result = result
if isinstance(ms_bbox_result, dict):
Expand Down
22 changes: 11 additions & 11 deletions mmdet/models/detectors/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,17 @@ def aug_test(self, imgs, img_metas, rescale=False):
flip_direction)
return [proposal.cpu().numpy() for proposal in proposal_list]

def show_result(self, data, result, dataset=None, top_k=20):
def show_result(self, data, result, top_k=20, **kwargs):
"""Show RPN proposals on the image.
Although we assume batch size is 1, this method supports arbitrary
batch size.
Args:
data (str or np.ndarray): Image filename or loaded image.
result (Tensor or tuple): The results to draw over `img`
bbox_result or (bbox_result, segm_result).
top_k (int): Plot the first k bboxes only
if set positive. Default: 20
Returns:
np.ndarray: The image with bboxes drawn on it.
"""
img_tensor = data['img'][0]
img_metas = data['img_metas'][0].data[0]
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
assert len(imgs) == len(img_metas)
for img, img_meta in zip(imgs, img_metas):
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
mmcv.imshow_bboxes(img_show, result, top_k=top_k)
mmcv.imshow_bboxes(data, result, top_k=top_k)

0 comments on commit 895a8d5

Please sign in to comment.