Wrong output from DynamicScatter #1177

Closed
Abyssaledge opened this issue Jan 15, 2022 · 6 comments

@Abyssaledge

Abyssaledge commented Jan 15, 2022

Reproduction
From my perspective, the following script should not produce an empty tensor, because there are clearly two valid points.
I have not analyzed it in depth, but I hope the script helps you figure it out.

BTW, in my experience it seems better to implement dynamic voxelization with torch.unique and torch_scatter. The output of torch.unique with return_inverse=True makes it very easy to map voxel features back to points with a simple indexing op, instead of the map_voxel_center_to_point function in DynamicVFE. The integer canvas in map_voxel_center_to_point seems memory-consuming when using very tiny voxels in 3D space. I share my torch_scatter-based DynamicVFE at the end; I hope it helps.
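
A minimal sketch of the torch.unique idea described above, in plain PyTorch. To stay dependency-free it substitutes `Tensor.index_add_` for torch_scatter's mean reduction. The function name `dynamic_voxelize_mean`, the range filter, and the (z, y, x) coordinate order (chosen to match the coors layout discussed in this issue, without the batch index) are illustrative assumptions, not MMDetection3D code.

```python
import torch


def dynamic_voxelize_mean(points, voxel_size, pc_range):
    """Hypothetical canvas-free dynamic voxelization (sketch).

    points: (N, C) float tensor whose first three columns are x, y, z.
    Returns per-voxel mean features, unique (z, y, x) voxel coords,
    the per-point inverse index, and the in-range mask.
    """
    vsize = points.new_tensor(voxel_size)
    lower = points.new_tensor(pc_range[:3])
    upper = points.new_tensor(pc_range[3:])
    # Keep only points inside the range (half-open on the upper bound).
    mask = ((points[:, :3] >= lower) & (points[:, :3] < upper)).all(dim=1)
    points = points[mask]
    # Integer voxel coordinates in (x, y, z) order, flipped to (z, y, x).
    zyx = torch.floor((points[:, :3] - lower) / vsize).long().flip(1)
    # Unique rows; inverse[i] is the voxel index of point i.
    coors, inverse, counts = torch.unique(
        zyx, dim=0, return_inverse=True, return_counts=True)
    # Sum features per voxel, then divide by the point count to get the mean.
    sums = points.new_zeros(coors.size(0), points.size(1))
    sums.index_add_(0, inverse, points)
    voxel_mean = sums / counts.unsqueeze(1).to(points.dtype)
    return voxel_mean, coors, inverse, mask


pts = torch.tensor([[0.1, 0.1, 0.1, 1.0],
                    [0.2, 0.3, 0.1, 3.0],
                    [1.5, 0.1, 0.1, 5.0]])
voxel_mean, coors, inverse, mask = dynamic_voxelize_mean(
    pts, voxel_size=(1.0, 1.0, 1.0), pc_range=(0, 0, 0, 4, 4, 4))
# Mapping voxel features back to points is plain indexing, no dense canvas.
per_point = voxel_mean[inverse]
```

Memory here scales with the number of occupied voxels rather than with the size of the full voxel grid, which is the property the comment above is after.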

import torch
from mmdet3d.ops import DynamicScatter
voxel_size = (0.32, 0.32, 6)
voxel_size_t = torch.tensor(voxel_size, device='cuda:0')
point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4]
pc_range = torch.tensor(point_cloud_range, device='cuda:0')
cluster_scatter = DynamicScatter(voxel_size, point_cloud_range, average_points=True)
features = torch.tensor([[68.9309, 31.5947,  0.8600], 
        [69.0586, 31.4590,  0.8594]], device='cuda:0')
coors = torch.tensor([[  0,   0, 332, 449],[  0,   0, 332, 449]], device='cuda:0', dtype=torch.int32)
voxel_mean, mean_coors = cluster_scatter(features, coors)
print(voxel_mean, mean_coors)

# make sure voxelization is right
print((features - pc_range[None, :3]) // voxel_size_t[None,:])

'''
Output:
tensor([], device='cuda:0', size=(0, 3)) tensor([], device='cuda:0', size=(0, 4), dtype=torch.int32)
tensor([[449., 332.,   0.], 
        [449., 332.,   0.]], device='cuda:0')
'''
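
For comparison, the result the reporter expects can be computed on CPU with plain PyTorch, following the torch.unique approach suggested above (`index_add_` stands in for torch_scatter here). This is a reference sketch, not MMDetection3D's implementation. It should yield a single voxel at coors [0, 0, 332, 449] whose mean is about (68.995, 31.527, 0.860).

```python
import torch

# Same two points and voxel coordinates as in the reproduction above (CPU).
features = torch.tensor([[68.9309, 31.5947, 0.8600],
                         [69.0586, 31.4590, 0.8594]])
coors = torch.tensor([[0, 0, 332, 449], [0, 0, 332, 449]], dtype=torch.int32)

# Group identical (batch, z, y, x) rows; inverse maps each point to its voxel.
mean_coors, inverse, counts = torch.unique(
    coors, dim=0, return_inverse=True, return_counts=True)
sums = features.new_zeros(mean_coors.size(0), features.size(1))
sums.index_add_(0, inverse, features)
voxel_mean = sums / counts.unsqueeze(1).to(features.dtype)
print(voxel_mean, mean_coors)
```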

Environment

sys.platform: linux
Python: 3.6.8 (default, Jul  2 2019, 13:27:03) [GCC 5.4.0 20160609]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA GeForce RTX 2080 Ti
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 10.2, V10.2.89
GCC: gcc (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
PyTorch: 1.9.0+cu102
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.2
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70
  - CuDNN 7.6.5
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=10.2, CUDNN_VERSION=7.6.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 

TorchVision: 0.10.0+cu102
OpenCV: 4.1.0
MMCV: 1.3.8
MMCV Compiler: GCC 5.4
MMCV CUDA Compiler: 10.2
MMDetection: 2.14.0
MMSegmentation: 0.14.1
MMDetection3D: 0.15.0+8e0d418

Here is a torch_scatter-based VFE with the same interface as the one in MMDetection3D:

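import torch
import torch_scatter
from mmcv.runner import force_fp32
from mmdet3d.models.builder import VOXEL_ENCODERS
from mmdet3d.models.voxel_encoders import DynamicVFE


@VOXEL_ENCODERS.register_module()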
class DynamicScatterVFE(DynamicVFE):
    """Same as DynamicVFE, but uses torch_scatter to avoid constructing the canvas in map_voxel_center_to_point.
    The canvas is very memory-consuming when using a tiny voxel size (5cm * 5cm * 5cm) in a large 3D space.
    """

    def __init__(self,
                 in_channels=4,
                 feat_channels=[],
                 with_distance=False,
                 with_cluster_center=False,
                 with_voxel_center=False,
                 voxel_size=(0.2, 0.2, 4),
                 point_cloud_range=(0, -40, -3, 70.4, 40, 1),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 mode='max',
                 fusion_layer=None,
                 return_point_feats=False,
                 return_inv=False,
                 ):
        super(DynamicScatterVFE, self).__init__(
            in_channels,
            feat_channels,
            with_distance,
            with_cluster_center,
            with_voxel_center,
            voxel_size,
            point_cloud_range,
            norm_cfg,
            mode,
            fusion_layer,
            return_point_feats,
        )
        # overwrite
        self.scatter = None
        self.vfe_scatter = None
        self.cluster_scatter = None
        self.mode = mode
        self.return_inv = return_inv

    def map_voxel_center_to_point(self, voxel_mean, voxel2point_inds):

        return voxel_mean[voxel2point_inds]

    # With out_fp16=True, the large number of points
    # leads to overflow errors in the following layers.
    @force_fp32(out_fp16=False)
    def forward(self,
                features,
                coors,
                points=None,
                img_feats=None,
                img_metas=None):

        features_ls = [features]
        origin_point_coors = features[:, :3]
        # Find distance of x, y, and z from cluster center
        if self._with_cluster_center:
            voxel_mean, mean_coors, unq_inv = self.scatter_v2(features, coors, 'avg')
            points_mean = self.map_voxel_center_to_point(
                voxel_mean, unq_inv)
            # TODO: maybe also do cluster for reflectivity
            f_cluster = features[:, :3] - points_mean[:, :3]
            features_ls.append(f_cluster)

        # Find distance of x, y, and z from pillar center
        if self._with_voxel_center:
            f_center = features.new_zeros(size=(features.size(0), 3))
            f_center[:, 0] = features[:, 0] - (
                coors[:, 3].type_as(features) * self.vx + self.x_offset)
            f_center[:, 1] = features[:, 1] - (
                coors[:, 2].type_as(features) * self.vy + self.y_offset)
            f_center[:, 2] = features[:, 2] - (
                coors[:, 1].type_as(features) * self.vz + self.z_offset)
            features_ls.append(f_center)

        if self._with_distance:
            points_dist = torch.norm(features[:, :3], 2, 1, keepdim=True)
            features_ls.append(points_dist)


        # Combine together feature decorations
        features = torch.cat(features_ls, dim=-1)

        for i, vfe in enumerate(self.vfe_layers):
            point_feats = vfe(features)

            if (i == len(self.vfe_layers) - 1 and self.fusion_layer is not None
                    and img_feats is not None):
                point_feats = self.fusion_layer(img_feats, points, point_feats,
                                                img_metas)
            voxel_feats, voxel_coors, unq_inv = self.scatter_v2(point_feats, coors, self.mode)
            if i != len(self.vfe_layers) - 1:
                # need to concat voxel feats if it is not the last vfe
                feat_per_point = self.map_voxel_center_to_point(voxel_feats, unq_inv)
                features = torch.cat([point_feats, feat_per_point], dim=1)
        if self.return_point_feats:
            return point_feats

        if self.return_inv:
            return voxel_feats, voxel_coors, unq_inv
        else:
            return voxel_feats, voxel_coors

    def scatter_v2(self, feat, coors, mode):

        new_coors, unq_inv, unq_cnt = torch.unique(coors, return_inverse=True, return_counts=True, dim=0)

        if mode == 'max':
            new_feat, argmax = torch_scatter.scatter_max(feat, unq_inv, dim=0)
        elif mode == 'avg':
            new_feat = torch_scatter.scatter(feat, unq_inv, dim=0, reduce='mean')
        else:
            raise NotImplementedError

        return new_feat, new_coors, unq_inv

ZCMax added the usage label Jan 18, 2022
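
If adding torch_scatter as a dependency is undesirable, the two reductions in scatter_v2 above can, as a sketch, be written with PyTorch's own `Tensor.scatter_reduce` (available from PyTorch 1.12, so newer than the 1.9 environment reported here). The name `scatter_v2_torch` is hypothetical. Unlike `torch_scatter.scatter_max`, this variant does not return argmax indices, which scatter_v2 discards anyway.

```python
import torch


def scatter_v2_torch(feat, coors, mode):
    """Hypothetical torch_scatter-free variant of scatter_v2 (PyTorch >= 1.12)."""
    new_coors, unq_inv = torch.unique(coors, return_inverse=True, dim=0)
    # scatter_reduce needs an index with the same shape as the source.
    index = unq_inv.unsqueeze(1).expand_as(feat)
    out = feat.new_zeros(new_coors.size(0), feat.size(1))
    if mode == 'max':
        reduce = 'amax'
    elif mode == 'avg':
        reduce = 'mean'
    else:
        raise NotImplementedError(mode)
    # include_self=False ignores the zero initial values, so every voxel
    # (each receives at least one point) gets an exact max or mean.
    new_feat = out.scatter_reduce(0, index, feat, reduce=reduce,
                                  include_self=False)
    return new_feat, new_coors, unq_inv


feat = torch.tensor([[1.0, -2.0], [3.0, -4.0], [5.0, 6.0]])
coors = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 2]])
voxel_max, voxel_coors, unq_inv = scatter_v2_torch(feat, coors, 'max')
voxel_avg, _, _ = scatter_v2_torch(feat, coors, 'avg')
```

The negative values in the example are deliberate: with `include_self=True` the zero-initialized output would leak into the max.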
@Tai-Wang
Member

Why do the coors you provided have 4 dimensions? Could that be causing the unexpected output?

@Abyssaledge
Author

Abyssaledge commented Jan 19, 2022

I think coors should be 4-dimensional (batch_idx, z, y, x), according to https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/models/voxel_encoders/voxel_encoder.py#L230.
In mmdet3d, the detector's voxelize function prepends batch_idx to the 3-dimensional coors before passing them to the voxel encoder.
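
A small sketch of the layout described above: each sample's (z, y, x) coordinates get its batch index prepended as a leading column, giving (batch_idx, z, y, x). The coordinate values are made up, and using `torch.nn.functional.pad` for this is an assumption for illustration rather than a quote of the detector code.

```python
import torch
import torch.nn.functional as F

# Hypothetical per-sample voxel coordinates in (z, y, x) order.
coors_per_sample = [
    torch.tensor([[0, 332, 449], [0, 10, 20]], dtype=torch.int32),
    torch.tensor([[1, 5, 7]], dtype=torch.int32),
]

batched = []
for batch_idx, coors in enumerate(coors_per_sample):
    # Pad one column on the left of the last dim, filled with the batch index.
    batched.append(F.pad(coors, (1, 0), mode='constant', value=batch_idx))
coors_batch = torch.cat(batched, dim=0)  # (N_total, 4): (batch_idx, z, y, x)
```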

@Tai-Wang
Member

Yes, I see. Have you ever tried to add a breakpoint to have a look at the intermediate results when running this function for a complete point cloud? Does it run correctly? Let's see whether this problem is a corner case or not.

@Abyssaledge
Author

Abyssaledge commented Jan 19, 2022

According to my observations, this error only occurs when the number of points is very small. When I feed point clouds containing at least 500 points, no empty tensor is produced and the algorithm works fine.
Although no errors are raised in that case, I am afraid logical correctness cannot be strictly guaranteed: the only obvious symptom is the empty tensor, and whenever the output is non-empty the code appears to behave normally. I once ran into a case where the output of DynamicScatter was slightly different from the output of torch_scatter, but I cannot reproduce it now.
The good news is that I obtain almost the same performance with DynamicScatter and torch_scatter, so it might not be a big deal.
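
The observation above suggests checking DynamicScatter against an independent reference, especially on small inputs. A minimal sketch of such a check follows. `reference_voxel_mean` and `matches_reference` are hypothetical helpers (not part of mmdet3d), and they assume the op under test returns one row per voxel, possibly in a different order than the reference.

```python
import torch


def reference_voxel_mean(features, coors):
    """Reference per-voxel mean via torch.unique, sorted by voxel coordinate."""
    unique_coors, inverse, counts = torch.unique(
        coors, dim=0, return_inverse=True, return_counts=True)
    sums = features.new_zeros(unique_coors.size(0), features.size(1))
    sums.index_add_(0, inverse, features)
    return sums / counts.unsqueeze(1).to(features.dtype), unique_coors


def matches_reference(voxel_feats, voxel_coors, features, coors, atol=1e-5):
    """Hypothetical check of a scatter op's (feats, coors) output.

    The op may emit voxels in any order, so its rows are first reordered
    by coordinate (torch.unique sorts rows lexicographically).
    """
    ref_feats, ref_coors = reference_voxel_mean(features, coors)
    if voxel_coors.size(0) != ref_coors.size(0):
        return False  # e.g. an empty output for a non-empty input
    sorted_coors, rank = torch.unique(voxel_coors, dim=0, return_inverse=True)
    if sorted_coors.size(0) != voxel_coors.size(0):
        return False  # duplicate voxels in the output
    ordered = torch.empty_like(voxel_feats)
    ordered[rank] = voxel_feats
    return bool(torch.equal(sorted_coors.to(ref_coors.dtype), ref_coors)
                and torch.allclose(ordered, ref_feats, atol=atol))


features = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
coors = torch.tensor([[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
                     dtype=torch.int32)
# A correct output listed in a different order than the reference.
good = matches_reference(
    torch.tensor([[2.0, 3.0], [5.0, 6.0]]),
    torch.tensor([[0, 0, 0, 1], [0, 0, 0, 0]], dtype=torch.int32),
    features, coors)
# An empty output, like the one reported in this issue.
empty = matches_reference(torch.zeros(0, 2),
                          torch.zeros(0, 4, dtype=torch.int32),
                          features, coors)
```

On a CUDA machine one could then call `matches_reference(*cluster_scatter(features, coors), features, coors)` over many random small inputs, assuming `cluster_scatter` returns `(voxel_feats, voxel_coors)` as in the reproduction script.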

@Tai-Wang
Copy link
Member

Thanks for sharing your observations. I suspect there may be problems with corner cases, so I will keep this issue open for community discussion. We will also check this part further later on.

@zhanggefan
Copy link
Contributor

This is a known issue, #768. It was fixed by PR #915, which is included since version 0.17.1.

Tai-Wang closed this as completed Jun 7, 2022