Skip to content

Commit

Permalink
Merge branch 'master' into add_initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
xiliu8006 committed Jun 7, 2021
2 parents ff56ae9 + c33d4ec commit 8830d63
Show file tree
Hide file tree
Showing 11 changed files with 1,083 additions and 6 deletions.
6 changes: 4 additions & 2 deletions mmdet3d/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .interpolate import three_interpolate, three_nn
from .knn import knn
from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
from .paconv import PAConv, PAConvCUDA, assign_score_withk
from .pointnet_modules import (PointFPModule, PointSAModule, PointSAModuleMSG,
build_sa_module)
from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_batch,
Expand All @@ -30,6 +31,7 @@
'furthest_point_sample_with_dist', 'three_interpolate', 'three_nn',
'gather_points', 'grouping_operation', 'group_points', 'GroupAll',
'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule',
'points_in_boxes_batch', 'get_compiler_version',
'get_compiling_cuda_version', 'Points_Sampler', 'build_sa_module'
'points_in_boxes_batch', 'get_compiler_version', 'assign_score_withk',
'get_compiling_cuda_version', 'Points_Sampler', 'build_sa_module',
'PAConv', 'PAConvCUDA'
]
21 changes: 17 additions & 4 deletions mmdet3d/ops/group_points/group_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Tuple

from ..ball_query import ball_query
from ..knn import knn
from . import group_points_ext


Expand All @@ -13,7 +14,8 @@ class QueryAndGroup(nn.Module):
Groups with a ball query of radius
Args:
max_radius (float): The maximum radius of the balls.
max_radius (float | None): The maximum radius of the balls.
If None is given, we will use kNN sampling instead of ball query.
sample_num (int): Maximum number of features to gather in the ball.
min_radius (float): The minimum radius of the balls.
use_xyz (bool): Whether to use xyz.
Expand Down Expand Up @@ -48,7 +50,12 @@ def __init__(self,
self.uniform_sample = uniform_sample
self.return_unique_cnt = return_unique_cnt
if self.return_unique_cnt:
assert self.uniform_sample
assert self.uniform_sample, \
'uniform_sample should be True when ' \
'returning the count of unique samples'
if self.max_radius is None:
assert not self.normalize_xyz, \
'can not normalize grouped xyz when max_radius is None'

def forward(self, points_xyz, center_xyz, features=None):
"""forward.
Expand All @@ -61,8 +68,14 @@ def forward(self, points_xyz, center_xyz, features=None):
Return:
Tensor: (B, 3 + C, npoint, sample_num) Grouped feature.
"""
idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
points_xyz, center_xyz)
# if self.max_radius is None, we will perform kNN instead of ball query
# idx is of shape [B, npoint, sample_num]
if self.max_radius is None:
idx = knn(self.sample_num, points_xyz, center_xyz, False)
idx = idx.transpose(1, 2).contiguous()
else:
idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
points_xyz, center_xyz)

if self.uniform_sample:
unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
Expand Down
4 changes: 4 additions & 0 deletions mmdet3d/ops/paconv/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .assign_score import assign_score_withk
from .paconv import PAConv, PAConvCUDA

__all__ = ['assign_score_withk', 'PAConv', 'PAConvCUDA']
101 changes: 101 additions & 0 deletions mmdet3d/ops/paconv/assign_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from torch.autograd import Function

from . import assign_score_withk_ext


class AssignScoreWithK(Function):
r"""Perform weighted sum to generate output features according to scores.
Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
scene_seg/lib/paconv_lib/src/gpu>`_.
This is a memory-efficient CUDA implementation of assign_scores operation,
which first transform all point feature with weight bank, then assemble
neighbor features with `knn_idx` and perform weighted sum of `scores`.
See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D for
more detailed descriptions.
Note:
This implementation assumes using ``neighbor`` kernel input, which is
(point_features - center_features, point_features).
See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
pointnet2/paconv.py#L128 for more details.
"""

@staticmethod
def forward(ctx,
scores,
point_features,
center_features,
knn_idx,
aggregate='sum'):
"""Forward.
Args:
scores (torch.Tensor): (B, npoint, K, M), predicted scores to
aggregate weight matrices in the weight bank.
``npoint`` is the number of sampled centers.
``K`` is the number of queried neighbors.
``M`` is the number of weight matrices in the weight bank.
point_features (torch.Tensor): (B, N, M, out_dim)
Pre-computed point features to be aggregated.
center_features (torch.Tensor): (B, N, M, out_dim)
Pre-computed center features to be aggregated.
knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
We assume the first idx in each row is the idx of the center.
aggregate (str, optional): Aggregation method.
Can be 'sum', 'avg' or 'max'. Defaults to 'sum'.
Returns:
torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
"""
agg = {'sum': 0, 'avg': 1, 'max': 2}

B, N, M, out_dim = point_features.size()
_, npoint, K, _ = scores.size()

output = point_features.new_zeros((B, out_dim, npoint, K))
assign_score_withk_ext.assign_score_withk_forward_wrapper(
B, N, npoint, M, K, out_dim, agg[aggregate],
point_features.contiguous(), center_features.contiguous(),
scores.contiguous(), knn_idx.contiguous(), output)

ctx.save_for_backward(output, point_features, center_features, scores,
knn_idx)
ctx.agg = agg[aggregate]

return output

@staticmethod
def backward(ctx, grad_out):
"""Backward.
Args:
grad_out (torch.Tensor): (B, out_dim, npoint, K)
Returns:
grad_scores (torch.Tensor): (B, npoint, K, M)
grad_point_features (torch.Tensor): (B, N, M, out_dim)
grad_center_features (torch.Tensor): (B, N, M, out_dim)
"""
_, point_features, center_features, scores, knn_idx = ctx.saved_tensors

agg = ctx.agg

B, N, M, out_dim = point_features.size()
_, npoint, K, _ = scores.size()

grad_point_features = point_features.new_zeros(point_features.shape)
grad_center_features = center_features.new_zeros(center_features.shape)
grad_scores = scores.new_zeros(scores.shape)

assign_score_withk_ext.assign_score_withk_backward_wrapper(
B, N, npoint, M, K, out_dim, agg, grad_out.contiguous(),
point_features.contiguous(), center_features.contiguous(),
scores.contiguous(), knn_idx.contiguous(), grad_point_features,
grad_center_features, grad_scores)

return grad_scores, grad_point_features, \
grad_center_features, None, None


assign_score_withk = AssignScoreWithK.apply
Loading

0 comments on commit 8830d63

Please sign in to comment.