Skip to content

Commit

Permalink
[Enhance] refactor iou_neg_piecewise_sampler.py (#842)
Browse files Browse the repository at this point in the history
* [Refactor] Main code modification for coordinate system refactor (#677)

* [Enhance] Add script for data update (#774)

* Fixed wrong config paths and fixed a bug in test

* Fixed metafile

* Coord sys refactor (main code)

* Update test_waymo_dataset.py

* Manually resolve conflict

* Removed unused lines and fixed imports

* remove coord2box and box2coord

* update dir_limit_offset

* Some minor improvements

* Removed some \s in comments

* Revert a change

* Change Box3DMode to Coord3DMode where points are converted

* Fix points_in_bbox function

* Fix Imvoxelnet config

* Revert adding a line

* Fix rotation bug when batch size is 0

* Keep sign of dir_scores as before

* Fix several comments

* Add a comment

* Fix docstring

* Add data update scripts

* Fix comments

* fix import

* refactor iou_neg_piecewise_sampler.py

* add docstring

* modify docstring

Co-authored-by: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com>
Co-authored-by: THU17cyz <congyezhen71@hotmail.com>
  • Loading branch information
3 people authored Aug 9, 2021
1 parent 7331fd0 commit b4ea160
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions mmdet3d/core/bbox/samplers/iou_neg_piecewise_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ def _sample_neg(self, assign_result, num_expected, **kwargs):
neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected:
return neg_inds
if len(neg_inds) <= 0:
raise NotImplementedError(
'Not support sampling the negative samples when the length '
'of negative samples is 0')
else:
neg_inds_choice = neg_inds.new_zeros([0])
extend_num = 0
Expand Down Expand Up @@ -87,12 +89,38 @@ def _sample_neg(self, assign_result, num_expected, **kwargs):
neg_inds_choice = torch.cat(
[neg_inds_choice, neg_inds[piece_neg_inds]], dim=0)
extend_num += piece_expected_num - len(piece_neg_inds)

# for the last piece
if piece_inds == self.neg_piece_num - 1:
extend_neg_num = num_expected - len(neg_inds_choice)
# if the numbers of nagetive samples > 0, we will
# randomly select num_expected samples in last piece
if piece_neg_inds.numel() > 0:
rand_idx = torch.randint(
low=0,
high=piece_neg_inds.numel(),
size=(extend_neg_num, )).long()
neg_inds_choice = torch.cat(
[neg_inds_choice, piece_neg_inds[rand_idx]],
dim=0)
# if the numbers of nagetive samples == 0, we will
# randomly select num_expected samples in all
# previous pieces
else:
rand_idx = torch.randint(
low=0,
high=neg_inds_choice.numel(),
size=(extend_neg_num, )).long()
neg_inds_choice = torch.cat(
[neg_inds_choice, neg_inds_choice[rand_idx]],
dim=0)
else:
piece_choice = self.random_choice(piece_neg_inds,
piece_expected_num)
neg_inds_choice = torch.cat(
[neg_inds_choice, neg_inds[piece_choice]], dim=0)
extend_num = 0
assert len(neg_inds_choice) == num_expected
return neg_inds_choice

def sample(self,
Expand Down Expand Up @@ -144,7 +172,6 @@ def sample(self,
num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg(
assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique()

sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
Expand Down

0 comments on commit b4ea160

Please sign in to comment.