diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py index 1249607f040..495b5fd2933 100644 --- a/mmdet/models/roi_heads/cascade_roi_head.py +++ b/mmdet/models/roi_heads/cascade_roi_head.py @@ -281,6 +281,10 @@ def forward_train(self, # bbox_targets is a tuple roi_labels = bbox_results['bbox_targets'][0] with torch.no_grad(): + roi_labels = torch.where( + roi_labels == self.bbox_head[i].num_classes, + bbox_results['cls_score'][:, :-1].argmax(1), + roi_labels) proposal_list = self.bbox_head[i].refine_bboxes( bbox_results['rois'], roi_labels, bbox_results['bbox_pred'], pos_is_gts, img_metas) @@ -306,7 +310,7 @@ def simple_test(self, x, proposal_list, img_metas, rescale=False): ms_scores.append(bbox_results['cls_score']) if i < self.num_stages - 1: - bbox_label = bbox_results['cls_score'].argmax(dim=1) + bbox_label = bbox_results['cls_score'][:, :-1].argmax(dim=1) rois = self.bbox_head[i].regress_by_class( rois, bbox_label, bbox_results['bbox_pred'], img_metas[0]) @@ -380,7 +384,8 @@ def aug_test(self, features, proposal_list, img_metas, rescale=False): ms_scores.append(bbox_results['cls_score']) if i < self.num_stages - 1: - bbox_label = bbox_results['cls_score'].argmax(dim=1) + bbox_label = bbox_results['cls_score'][:, :-1].argmax( + dim=1) rois = self.bbox_head[i].regress_by_class( rois, bbox_label, bbox_results['bbox_pred'], img_meta[0])