From 573946a8f66c1f4a2fe1faad22e352a66bbb0092 Mon Sep 17 00:00:00 2001 From: junhaozhang98 <37370309+junhaozhang98@users.noreply.github.com> Date: Wed, 17 Mar 2021 17:56:55 +0800 Subject: [PATCH 1/3] [Feature] Support semantic seg metrics (#332) * add ini * add semantickitti_dataset * add test semantickitti_dataset * delete last line in test_semmaticdataset * add test data * change_names * load_labels dytype * change_name * numpy * name * dtype string * minor issue-string * seg_3d_dtype * add_fast hist * add per_class iou * add seg_eval * add unitest * minor error * minor error2 * minor error3 --- mmdet3d/core/evaluation/__init__.py | 6 +- mmdet3d/core/evaluation/seg_eval.py | 121 ++++++++++++++++++++++++++++ tests/test_metrics/test_seg_eval.py | 35 ++++++++ 3 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 mmdet3d/core/evaluation/seg_eval.py create mode 100644 tests/test_metrics/test_seg_eval.py diff --git a/mmdet3d/core/evaluation/__init__.py b/mmdet3d/core/evaluation/__init__.py index 0d472ca01b..f8cd210a33 100644 --- a/mmdet3d/core/evaluation/__init__.py +++ b/mmdet3d/core/evaluation/__init__.py @@ -1,5 +1,9 @@ from .indoor_eval import indoor_eval from .kitti_utils import kitti_eval, kitti_eval_coco_style from .lyft_eval import lyft_eval +from .seg_eval import seg_eval -__all__ = ['kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval'] +__all__ = [ + 'kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval', + 'seg_eval' +] diff --git a/mmdet3d/core/evaluation/seg_eval.py b/mmdet3d/core/evaluation/seg_eval.py new file mode 100644 index 0000000000..ad60e8e350 --- /dev/null +++ b/mmdet3d/core/evaluation/seg_eval.py @@ -0,0 +1,121 @@ +import numpy as np +from mmcv.utils import print_log +from terminaltables import AsciiTable + + +def fast_hist(preds, labels, num_classes): + """Compute the confusion matrix for every batch. + + Args: + preds (np.ndarray): Prediction labels of points with shape of + (num_points, ). + labels (np.ndarray): Ground truth labels of points with shape of + (num_points, ). + num_classes (int): number of classes + + Returns: + np.ndarray: Calculated confusion matrix. + """ + + k = (labels >= 0) & (labels < num_classes) + bin_count = np.bincount( + num_classes * labels[k].astype(int) + preds[k], + minlength=num_classes**2) + return bin_count[:num_classes**2].reshape(num_classes, num_classes) + + +def per_class_iou(hist): + """Compute the per class iou. + + Args: + hist(np.ndarray): Overall confusion martix + (num_classes, num_classes ). + + Returns: + np.ndarray: Calculated per class iou + """ + + return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) + + +def get_acc(hist): + """Compute the overall accuracy. + + Args: + hist(np.ndarray): Overall confusion martix + (num_classes, num_classes ). + + Returns: + float: Calculated overall acc + """ + + return np.diag(hist).sum() / hist.sum() + + +def get_acc_cls(hist): + """Compute the class average accuracy. + + Args: + hist(np.ndarray): Overall confusion martix + (num_classes, num_classes ). + + Returns: + float: Calculated class average acc + """ + + return np.nanmean(np.diag(hist) / hist.sum(axis=1)) + + +def seg_eval(gt_labels, seg_preds, label2cat, logger=None): + """Semantic Segmentation Evaluation. + + Evaluate the result of the Semantic Segmentation. + + Args: + gt_labels (list[torch.Tensor]): Ground truth labels. + seg_preds (list[torch.Tensor]): Predtictions + label2cat (dict): Map from label to category. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `mmdet.utils.print_log()` for details. Default: None. + + Return: + dict[str, float]: Dict of results. + """ + assert len(seg_preds) == len(gt_labels) + + hist_list = [] + for i in range(len(seg_preds)): + hist_list.append( + fast_hist(seg_preds[i].numpy().astype(int), + gt_labels[i].numpy().astype(int), len(label2cat))) + iou = per_class_iou(sum(hist_list)) + miou = np.nanmean(iou) + acc = get_acc(sum(hist_list)) + acc_cls = get_acc_cls(sum(hist_list)) + + header = ['classes'] + for i in range(len(label2cat)): + header.append(label2cat[i]) + header.extend(['miou', 'acc', 'acc_cls']) + + ret_dict = dict() + table_columns = [['results']] + for i in range(len(label2cat)): + ret_dict[label2cat[i]] = float(iou[i]) + table_columns.append([f'{iou[i]:.4f}']) + ret_dict['miou'] = float(miou) + ret_dict['acc'] = float(acc) + ret_dict['acc_cls'] = float(acc_cls) + + table_columns.append([f'{miou:.4f}']) + table_columns.append([f'{acc:.4f}']) + table_columns.append([f'{acc_cls:.4f}']) + + table_data = [header] + table_rows = list(zip(*table_columns)) + table_data += table_rows + table = AsciiTable(table_data) + table.inner_footing_row_border = True + print_log('\n' + table.table, logger=logger) + + return ret_dict diff --git a/tests/test_metrics/test_seg_eval.py b/tests/test_metrics/test_seg_eval.py new file mode 100644 index 0000000000..d8850775ad --- /dev/null +++ b/tests/test_metrics/test_seg_eval.py @@ -0,0 +1,35 @@ +import numpy as np +import pytest +import torch + +from mmdet3d.core.evaluation.seg_eval import seg_eval + + +def test_indoor_eval(): + if not torch.cuda.is_available(): + pytest.skip() + seg_preds = [ + torch.Tensor( + [0, 0, 1, 0, 2, 1, 3, 1, 1, 0, 2, 2, 2, 2, 1, 3, 0, 3, 3, 3]) + ] + gt_labels = [ + torch.Tensor( + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) + ] + + label2cat = { + 0: 'car', + 1: 'bicycle', + 2: 'motorcycle', + 3: 'truck', + } + ret_value = seg_eval(gt_labels, seg_preds, label2cat) + + assert np.isclose(ret_value['car'], 0.428571429) + assert np.isclose(ret_value['bicycle'], 0.428571429) + assert np.isclose(ret_value['motorcycle'], 0.6666667) + assert np.isclose(ret_value['truck'], 0.6666667) + + assert np.isclose(ret_value['acc'], 0.7) + assert np.isclose(ret_value['acc_cls'], 0.7) + assert np.isclose(ret_value['miou'], 0.547619048) From 5f99b50efd4c3fcded47a050e6b5a29d740cdd2e Mon Sep 17 00:00:00 2001 From: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com> Date: Wed, 17 Mar 2021 18:13:52 +0800 Subject: [PATCH 2/3] [Feature] Support knn gpu op (#360) * support knn gpu op * made it more robust and fixed comments --- mmdet3d/ops/__init__.py | 3 +- mmdet3d/ops/gather_points/gather_points.py | 14 +- mmdet3d/ops/knn/__init__.py | 3 + mmdet3d/ops/knn/knn.py | 68 +++++ mmdet3d/ops/knn/src/knn.cpp | 62 ++++ mmdet3d/ops/knn/src/knn_cuda.cu | 268 ++++++++++++++++++ setup.py | 5 + .../test_common_modules/test_pointnet_ops.py | 47 ++- 8 files changed, 461 insertions(+), 9 deletions(-) create mode 100644 mmdet3d/ops/knn/__init__.py create mode 100644 mmdet3d/ops/knn/knn.py create mode 100644 mmdet3d/ops/knn/src/knn.cpp create mode 100644 mmdet3d/ops/knn/src/knn_cuda.cu diff --git a/mmdet3d/ops/__init__.py b/mmdet3d/ops/__init__.py index 2bd6ce5dfd..1acda9b19c 100644 --- a/mmdet3d/ops/__init__.py +++ b/mmdet3d/ops/__init__.py @@ -9,6 +9,7 @@ from .group_points import (GroupAll, QueryAndGroup, group_points, grouping_operation) from .interpolate import three_interpolate, three_nn +from .knn import knn from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d from .pointnet_modules import (PointFPModule, PointSAModule, PointSAModuleMSG, build_sa_module) @@ -25,7 +26,7 @@ 'dynamic_scatter', 'DynamicScatter', 'sigmoid_focal_loss', 'SigmoidFocalLoss', 'SparseBasicBlock', 'SparseBottleneck', 'RoIAwarePool3d', 'points_in_boxes_gpu', 'points_in_boxes_cpu', - 'make_sparse_convmodule', 'ball_query', 'furthest_point_sample', + 'make_sparse_convmodule', 'ball_query', 'knn', 'furthest_point_sample', 'furthest_point_sample_with_dist', 'three_interpolate', 'three_nn', 'gather_points', 'grouping_operation', 'group_points', 'GroupAll', 'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule', diff --git a/mmdet3d/ops/gather_points/gather_points.py b/mmdet3d/ops/gather_points/gather_points.py index a6d240dae4..8f109a0cbd 100644 --- a/mmdet3d/ops/gather_points/gather_points.py +++ b/mmdet3d/ops/gather_points/gather_points.py @@ -12,28 +12,28 @@ class GatherPoints(Function): @staticmethod def forward(ctx, features: torch.Tensor, - indicies: torch.Tensor) -> torch.Tensor: + indices: torch.Tensor) -> torch.Tensor: """forward. Args: features (Tensor): (B, C, N) features to gather. - indicies (Tensor): (B, M) where M is the number of points. + indices (Tensor): (B, M) where M is the number of points. Returns: Tensor: (B, C, M) where M is the number of points. """ assert features.is_contiguous() - assert indicies.is_contiguous() + assert indices.is_contiguous() - B, npoint = indicies.size() + B, npoint = indices.size() _, C, N = features.size() output = torch.cuda.FloatTensor(B, C, npoint) gather_points_ext.gather_points_wrapper(B, C, N, npoint, features, - indicies, output) + indices, output) - ctx.for_backwards = (indicies, C, N) - ctx.mark_non_differentiable(indicies) + ctx.for_backwards = (indices, C, N) + ctx.mark_non_differentiable(indices) return output @staticmethod diff --git a/mmdet3d/ops/knn/__init__.py b/mmdet3d/ops/knn/__init__.py new file mode 100644 index 0000000000..c8cb712b09 --- /dev/null +++ b/mmdet3d/ops/knn/__init__.py @@ -0,0 +1,3 @@ +from .knn import knn + +__all__ = ['knn'] diff --git a/mmdet3d/ops/knn/knn.py b/mmdet3d/ops/knn/knn.py new file mode 100644 index 0000000000..170ac9a32b --- /dev/null +++ b/mmdet3d/ops/knn/knn.py @@ -0,0 +1,68 @@ +import torch +from torch.autograd import Function + +from . import knn_ext + + +class KNN(Function): + """KNN (CUDA). + + Find k-nearest points. + """ + + @staticmethod + def forward(ctx, + k: int, + xyz: torch.Tensor, + center_xyz: torch.Tensor, + transposed: bool = False) -> torch.Tensor: + """forward. + + Args: + k (int): number of nearest neighbors. + xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N). + xyz coordinates of the features. + center_xyz (Tensor): (B, npoint, 3) if transposed == False, + else (B, 3, npoint). centers of the knn query. + transposed (bool): whether the input tensors are transposed. + defaults to False. + + Returns: + Tensor: (B, k, npoint) tensor with the indicies of + the features that form k-nearest neighbours. + """ + assert k > 0 + + B, npoint = center_xyz.shape[:2] + N = xyz.shape[1] + + if not transposed: + xyz = xyz.transpose(2, 1).contiguous() + center_xyz = center_xyz.transpose(2, 1).contiguous() + + assert center_xyz.is_contiguous() + assert xyz.is_contiguous() + + center_xyz_device = center_xyz.get_device() + assert center_xyz_device == xyz.get_device(), \ + 'center_xyz and xyz should be put on the same device' + if torch.cuda.current_device() != center_xyz_device: + torch.cuda.set_device(center_xyz_device) + + idx = center_xyz.new_zeros((B, k, npoint)).long() + + for bi in range(B): + knn_ext.knn_wrapper(xyz[bi], N, center_xyz[bi], npoint, idx[bi], k) + + ctx.mark_non_differentiable(idx) + + idx -= 1 + + return idx + + @staticmethod + def backward(ctx, a=None): + return None, None + + +knn = KNN.apply diff --git a/mmdet3d/ops/knn/src/knn.cpp b/mmdet3d/ops/knn/src/knn.cpp new file mode 100644 index 0000000000..a86e13d0af --- /dev/null +++ b/mmdet3d/ops/knn/src/knn.cpp @@ -0,0 +1,62 @@ +// Modified from https://github.com/unlimblue/KNN_CUDA + +#include +#include +#include + +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_TYPE(x, t) AT_ASSERTM(x.dtype() == t, #x " must be " #t) +#define CHECK_CUDA(x) AT_ASSERTM(x.device().type() == at::Device::Type::CUDA, #x " must be on CUDA") +#define CHECK_INPUT(x, t) CHECK_CONTIGUOUS(x); CHECK_TYPE(x, t); CHECK_CUDA(x) + + +void knn_kernels_launcher( + const float* ref_dev, + int ref_nb, + const float* query_dev, + int query_nb, + int dim, + int k, + float* dist_dev, + long* ind_dev, + cudaStream_t stream + ); + +// std::vector knn_wrapper( +void knn_wrapper( + at::Tensor & ref, + int ref_nb, + at::Tensor & query, + int query_nb, + at::Tensor & ind, + const int k + ) { + + CHECK_INPUT(ref, at::kFloat); + CHECK_INPUT(query, at::kFloat); + const float * ref_dev = ref.data_ptr(); + const float * query_dev = query.data_ptr(); + int dim = query.size(0); + auto dist = at::empty({ref_nb, query_nb}, query.options().dtype(at::kFloat)); + float * dist_dev = dist.data_ptr(); + long * ind_dev = ind.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + knn_kernels_launcher( + ref_dev, + ref_nb, + query_dev, + query_nb, + dim, + k, + dist_dev, + ind_dev, + stream + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("knn_wrapper", &knn_wrapper, "knn_wrapper"); +} diff --git a/mmdet3d/ops/knn/src/knn_cuda.cu b/mmdet3d/ops/knn/src/knn_cuda.cu new file mode 100644 index 0000000000..074ce0dfb5 --- /dev/null +++ b/mmdet3d/ops/knn/src/knn_cuda.cu @@ -0,0 +1,268 @@ +/** Modified from https://github.com/unlimblue/KNN_CUDA + * which is the modified version of knn-CUDA + * from https://github.com/vincentfpgarcia/kNN-CUDA + * Last modified by Christopher B. Choy 12/23/2016 + * vincentfpgarcia wrote the original cuda code, Christopher modified it and + * set it up for pytorch 0.4, and unlimblue updated it to pytorch >= 1.0 + */ + +// Includes +#include +#include "cuda.h" + +// Constants used by the program +#define BLOCK_DIM 16 +#define DEBUG 0 + +/** + * Computes the distance between two matrix A (reference points) and + * B (query points) containing respectively wA and wB points. + * + * @param A pointer on the matrix A + * @param wA width of the matrix A = number of points in A + * @param B pointer on the matrix B + * @param wB width of the matrix B = number of points in B + * @param dim dimension of points = height of matrices A and B + * @param AB pointer on the matrix containing the wA*wB distances computed + */ +__global__ void cuComputeDistanceGlobal(const float* A, int wA, + const float* B, int wB, int dim, float* AB){ + + // Declaration of the shared memory arrays As and Bs used to store the sub-matrix of A and B + __shared__ float shared_A[BLOCK_DIM][BLOCK_DIM]; + __shared__ float shared_B[BLOCK_DIM][BLOCK_DIM]; + + // Sub-matrix of A (begin, step, end) and Sub-matrix of B (begin, step) + __shared__ int begin_A; + __shared__ int begin_B; + __shared__ int step_A; + __shared__ int step_B; + __shared__ int end_A; + + // Thread index + int tx = threadIdx.x; + int ty = threadIdx.y; + + // Other variables + float tmp; + float ssd = 0; + + // Loop parameters + begin_A = BLOCK_DIM * blockIdx.y; + begin_B = BLOCK_DIM * blockIdx.x; + step_A = BLOCK_DIM * wA; + step_B = BLOCK_DIM * wB; + end_A = begin_A + (dim-1) * wA; + + // Conditions + int cond0 = (begin_A + tx < wA); // used to write in shared memory + int cond1 = (begin_B + tx < wB); // used to write in shared memory & to computations and to write in output matrix + int cond2 = (begin_A + ty < wA); // used to computations and to write in output matrix + + // Loop over all the sub-matrices of A and B required to compute the block sub-matrix + for (int a = begin_A, b = begin_B; a <= end_A; a += step_A, b += step_B) { + // Load the matrices from device memory to shared memory; each thread loads one element of each matrix + if (a/wA + ty < dim){ + shared_A[ty][tx] = (cond0)? A[a + wA * ty + tx] : 0; + shared_B[ty][tx] = (cond1)? B[b + wB * ty + tx] : 0; + } + else{ + shared_A[ty][tx] = 0; + shared_B[ty][tx] = 0; + } + + // Synchronize to make sure the matrices are loaded + __syncthreads(); + + // Compute the difference between the two matrixes; each thread computes one element of the block sub-matrix + if (cond2 && cond1){ + for (int k = 0; k < BLOCK_DIM; ++k){ + tmp = shared_A[k][ty] - shared_B[k][tx]; + ssd += tmp*tmp; + } + } + + // Synchronize to make sure that the preceding computation is done before loading two new sub-matrices of A and B in the next iteration + __syncthreads(); + } + + // Write the block sub-matrix to device memory; each thread writes one element + if (cond2 && cond1) + AB[(begin_A + ty) * wB + begin_B + tx] = ssd; +} + + +/** + * Gathers k-th smallest distances for each column of the distance matrix in the top. + * + * @param dist distance matrix + * @param ind index matrix + * @param width width of the distance matrix and of the index matrix + * @param height height of the distance matrix and of the index matrix + * @param k number of neighbors to consider + */ +__global__ void cuInsertionSort(float *dist, long *ind, int width, int height, int k){ + + // Variables + int l, i, j; + float *p_dist; + long *p_ind; + float curr_dist, max_dist; + long curr_row, max_row; + unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; + if (xIndexcurr_dist){ + i=a; + break; + } + } + for (j=l; j>i; j--){ + p_dist[j*width] = p_dist[(j-1)*width]; + p_ind[j*width] = p_ind[(j-1)*width]; + } + p_dist[i*width] = curr_dist; + p_ind[i*width] = l + 1; + } else { + p_ind[l*width] = l + 1; + } + max_dist = p_dist[curr_row]; + } + + // Part 2 : insert element in the k-th first lines + max_row = (k-1)*width; + for (l=k; lcurr_dist){ + i=a; + break; + } + } + for (j=k-1; j>i; j--){ + p_dist[j*width] = p_dist[(j-1)*width]; + p_ind[j*width] = p_ind[(j-1)*width]; + } + p_dist[i*width] = curr_dist; + p_ind[i*width] = l + 1; + max_dist = p_dist[max_row]; + } + } + } +} + + +/** + * Computes the square root of the first line (width-th first element) + * of the distance matrix. + * + * @param dist distance matrix + * @param width width of the distance matrix + * @param k number of neighbors to consider + */ +__global__ void cuParallelSqrt(float *dist, int width, int k){ + unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; + unsigned int yIndex = blockIdx.y * blockDim.y + threadIdx.y; + if (xIndex>>(ref_dev, ref_nb, + query_dev, query_nb, dim, dist_dev); + +#if DEBUG + printf("Pre insertionSort\n"); + debug(dist_dev, ind_dev, query_nb, k); +#endif + + // Kernel 2: Sort each column + cuInsertionSort<<>>(dist_dev, ind_dev, query_nb, ref_nb, k); + +#if DEBUG + printf("Post insertionSort\n"); + debug(dist_dev, ind_dev, query_nb, k); +#endif + + // Kernel 3: Compute square root of k first elements + cuParallelSqrt<<>>(dist_dev, query_nb, k); +} diff --git a/setup.py b/setup.py index 4a4a5a2b14..7f123e1160 100644 --- a/setup.py +++ b/setup.py @@ -221,6 +221,11 @@ def gen_packages_items(): module='mmdet3d.ops.ball_query', sources=['src/ball_query.cpp'], sources_cuda=['src/ball_query_cuda.cu']), + make_cuda_ext( + name='knn_ext', + module='mmdet3d.ops.knn', + sources=['src/knn.cpp'], + sources_cuda=['src/knn_cuda.cu']), make_cuda_ext( name='group_points_ext', module='mmdet3d.ops.group_points', diff --git a/tests/test_models/test_common_modules/test_pointnet_ops.py b/tests/test_models/test_common_modules/test_pointnet_ops.py index e7335017cb..c5565e4214 100644 --- a/tests/test_models/test_common_modules/test_pointnet_ops.py +++ b/tests/test_models/test_common_modules/test_pointnet_ops.py @@ -3,7 +3,7 @@ from mmdet3d.ops import (ball_query, furthest_point_sample, furthest_point_sample_with_dist, gather_points, - grouping_operation, three_interpolate, three_nn) + grouping_operation, knn, three_interpolate, three_nn) def test_fps(): @@ -73,6 +73,51 @@ def test_ball_query(): assert torch.all(idx == expected_idx) +def test_knn(): + if not torch.cuda.is_available(): + pytest.skip() + new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], + [-2.2769, 2.7817, -0.2334], + [-0.4003, 2.4666, -0.5116], + [-0.0740, 1.3147, -1.3625], + [-0.0740, 1.3147, -1.3625]], + [[-2.0289, 2.4952, -0.1708], + [-2.0668, 6.0278, -0.4875], + [0.4066, 1.4211, -0.2947], + [-2.0289, 2.4952, -0.1708], + [-2.0289, 2.4952, -0.1708]]]).cuda() + + xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634], + [-0.4003, 2.4666, + -0.5116], [-0.5251, 2.4379, -0.8466], + [-0.9691, 1.1418, + -1.3733], [-0.2232, 0.9561, -1.3626], + [-2.2769, 2.7817, -0.2334], + [-0.2822, 1.3192, -1.3645], [0.1533, 1.5024, -1.0432], + [0.4917, 1.1529, -1.3496]], + [[-2.0289, 2.4952, + -0.1708], [-0.7188, 0.9956, -0.5096], + [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610], + [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791], + [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947], + [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, + -1.2000]]]).cuda() + + idx = knn(5, xyz, new_xyz) + new_xyz_ = new_xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1) + xyz_ = xyz.unsqueeze(1).repeat(1, new_xyz.shape[1], 1, 1) + dist = ((new_xyz_ - xyz_) * (new_xyz_ - xyz_)).sum(-1) + expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1) + assert torch.all(idx == expected_idx) + + idx = knn(5, xyz, xyz) + xyz_ = xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1) + xyz__ = xyz.unsqueeze(1).repeat(1, xyz.shape[1], 1, 1) + dist = ((xyz_ - xyz__) * (xyz_ - xyz__)).sum(-1) + expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1) + assert torch.all(idx == expected_idx) + + def test_grouping_points(): if not torch.cuda.is_available(): pytest.skip() From f0ba0ce2903a880f5de46afd6c7559f801965bcf Mon Sep 17 00:00:00 2001 From: Tianwei Yin Date: Sat, 20 Mar 2021 02:22:53 -0500 Subject: [PATCH 3/3] Update Bibtex (#368) * Update README.md * Update README_zh-CN.md * Update README.md --- README.md | 2 +- README_zh-CN.md | 2 +- configs/centerpoint/README.md | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f93acb0590..9ed761295e 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ Support methods - [x] [3DSSD (CVPR'2020)](configs/3dssd/README.md) - [x] [Part-A2 (TPAMI'2020)](configs/parta2/README.md) - [x] [MVXNet (ICRA'2019)](configs/mvxnet/README.md) -- [x] [CenterPoint (Arxiv'2020)](configs/centerpoint/README.md) +- [x] [CenterPoint (CVPR'2021)](configs/centerpoint/README.md) - [x] [SSN (ECCV'2020)](configs/ssn/README.md) | | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net | diff --git a/README_zh-CN.md b/README_zh-CN.md index b0c4c4936e..641a9c8baf 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -83,7 +83,7 @@ MMDetection3D 是一个基于 PyTorch 的目标检测开源工具箱, 下一代 - [x] [3DSSD (CVPR'2020)](configs/3dssd/README.md) - [x] [Part-A2 (TPAMI'2020)](configs/parta2/README.md) - [x] [MVXNet (ICRA'2019)](configs/mvxnet/README.md) -- [x] [CenterPoint (Arxiv'2020)](configs/centerpoint/README.md) +- [x] [CenterPoint (CVPR'2021)](configs/centerpoint/README.md) - [x] [SSN (ECCV'2020)](configs/ssn/README.md) | | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net | diff --git a/configs/centerpoint/README.md b/configs/centerpoint/README.md index 03372ce610..ffab49d7ac 100644 --- a/configs/centerpoint/README.md +++ b/configs/centerpoint/README.md @@ -27,11 +27,11 @@ We follow the below style to name config files. Contributors are advised to foll `{dataset}`: dataset like nus-3d, kitti-3d, lyft-3d, scannet-3d, sunrgbd-3d. We also indicate the number of classes we are using if there exist multiple settings, e.g., kitti-3d-3class and kitti-3d-car means training on KITTI dataset with 3 classes and single class, respectively. ``` -@article{yin2020center, - title={Center-based 3d object detection and tracking}, +@article{yin2021center, + title={Center-based 3D Object Detection and Tracking}, author={Yin, Tianwei and Zhou, Xingyi and Kr{\"a}henb{\"u}hl, Philipp}, - journal={arXiv preprint arXiv:2006.11275}, - year={2020} + journal={CVPR}, + year={2021}, } ```