Skip to content

Commit

Permalink
[Fix]Fix a bug in StackQueryAndGroup (#2043)
Browse files Browse the repository at this point in the history
* fix a bug

* fix a batch inference bug

* fix docs
  • Loading branch information
VVsssssk authored Nov 22, 2022
1 parent b29be44 commit 3296b4f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ def forward(self,
'new_xyz: str(new_xyz.shape), new_xyz_batch_cnt: ' \
'str(new_xyz_batch_cnt)'

# idx: (M1 + M2 ..., nsample), empty_ball_mask: (M1 + M2 ...)
idx, empty_ball_mask = ball_query(0, self.radius, self.sample_nums,
xyz, new_xyz, xyz_batch_cnt,
new_xyz_batch_cnt)
# idx: (M1 + M2 ..., nsample)
idx = ball_query(0, self.radius, self.sample_nums, xyz, new_xyz,
xyz_batch_cnt, new_xyz_batch_cnt)
empty_ball_mask = (idx[:, 0] == -1)
idx[empty_ball_mask] = 0
grouped_xyz = grouping_operation(
xyz, idx, xyz_batch_cnt,
new_xyz_batch_cnt) # (M1 + M2, 3, nsample)
Expand Down
14 changes: 7 additions & 7 deletions mmdet3d/models/roi_heads/bbox_heads/pv_rcnn_bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import torch
from mmcv.cnn import ConvModule
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import multi_apply
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import nn as nn
Expand All @@ -14,8 +16,6 @@
from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes,
rotation_3d_in_axis, xywhr2xyxyr)
from mmdet3d.utils import InstanceList
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import multi_apply


@MODELS.register_module()
Expand Down Expand Up @@ -440,21 +440,21 @@ def get_results(self,
# post processing
result_list = []
for batch_id in range(batch_size):
cls_preds = cls_preds[roi_batch_id == batch_id]
cur_cls_preds = cls_preds[roi_batch_id == batch_id]
box_preds = batch_box_preds[roi_batch_id == batch_id]
label_preds = class_labels[batch_id]

cls_preds = cls_preds.sigmoid()
cls_preds, _ = torch.max(cls_preds, dim=-1)
cur_cls_preds = cur_cls_preds.sigmoid()
cur_cls_preds, _ = torch.max(cur_cls_preds, dim=-1)
selected = self.class_agnostic_nms(
scores=cls_preds,
scores=cur_cls_preds,
bbox_preds=box_preds,
input_meta=input_metas[batch_id],
nms_cfg=test_cfg)

selected_bboxes = box_preds[selected]
selected_label_preds = label_preds[selected]
selected_scores = cls_preds[selected]
selected_scores = cur_cls_preds[selected]

results = InstanceData()
results.bboxes_3d = input_metas[batch_id]['box_type_3d'](
Expand Down

0 comments on commit 3296b4f

Please sign in to comment.