
[Fix] fix a bug that may cause compilation failure of dynamic voxelization when using GPUs with compute capability lower than 6.x #326

Merged
merged 3 commits into master on Mar 2, 2021

Conversation

zhanggefan
Contributor

Fix a bug that may cause compilation failure of dynamic voxelization when using GPUs with compute capability lower than 6.x.

Also fix imperfect kernel code that may unintentionally discard valid points when the input point count exceeds 50000 * 512 (nearly impossible in practice, though).

…when using gpus with compute capability lower than 6.x
@codecov

codecov bot commented Feb 28, 2021

Codecov Report

Merging #326 (ce3f06a) into master (93597a5) will decrease coverage by 0.00%.
The diff coverage is 0.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #326      +/-   ##
==========================================
- Coverage   49.70%   49.69%   -0.01%     
==========================================
  Files         174      174              
  Lines       11754    11758       +4     
  Branches     1838     1838              
==========================================
+ Hits         5842     5843       +1     
- Misses       5552     5555       +3     
  Partials      360      360              
Flag        Coverage Δ
unittests   49.69% <0.00%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files                            Coverage Δ
mmdet3d/ops/voxel/scatter_points.py       38.00% <0.00%> (-0.78%) ⬇️
mmdet3d/core/visualizer/show_result.py    79.66% <0.00%> (ø)
mmdet3d/core/visualizer/open3d_vis.py     10.58% <0.00%> (+0.40%) ⬆️

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 93597a5...ce3f06a. Read the comment docs.

@zhanggefan zhanggefan changed the title [FIX] fix a bug that may cause compilation failure of dynamic voxelization when using GPUs with compute capability lower than 6.x [Fix] fix a bug that may cause compilation failure of dynamic voxelization when using GPUs with compute capability lower than 6.x Mar 1, 2021
@ZwwWayne
Collaborator

ZwwWayne commented Mar 1, 2021

We just found that the CUDA kernel cannot be compiled successfully with CUDA 9.0. Does this PR fix this issue?

@zhanggefan
Contributor Author

Hi, @ZwwWayne
Sorry, I did not test the code on CUDA 9.0; I only tested it on 10.2 and 11.0 with no errors.
Could you please share the specific compiler error? I will look into it, and it won't take long.

@ZwwWayne
Collaborator

ZwwWayne commented Mar 1, 2021

We also tested the code in a CUDA 10.1 environment and it seems to be OK. However, MMDetection3D still needs to stay compatible with CUDA 9.0 for some reasons, so it would be nice if you could also help fix that issue in this PR.

The screenshots of the errors are attached below; @Tai-Wang may provide a more detailed log if necessary.
[screenshots of the compiler errors]

Thanks in advance.

@zhanggefan
Contributor Author

It is really weird that, from the log, all the errors seem to occur when instantiating templates in the PyTorch headers.
This PR will not fix that issue, because it mainly deals with potential compilation errors on CUDA arch < 6.0, while you were using 6.1.
I don't have enough clues yet, so I am going to test it with CUDA 9.0 in Docker. It might take a couple of hours. Sorry for the delay.

@zhanggefan
Contributor Author

Hi, @Tai-Wang
Which version of PyTorch did you use with CUDA 9.0? It looks like only versions no higher than 1.1 are officially compatible with CUDA 9.0. Did you build PyTorch from source?

@Tai-Wang
Member

Tai-Wang commented Mar 1, 2021

@zhanggefan Yes, exactly. We use PyTorch 1.5 built from source. The compilation error does not exist in the previously released version, and it seems to point to the "scatter" CUDA file related to your PR.

@ZwwWayne
Collaborator

ZwwWayne commented Mar 1, 2021

LGTM. The failure might need further exploration.

@zhanggefan
Contributor Author

@Tai-Wang It is really interesting: when compiling with PyTorch 1.5 and CUDA 9.0, NVCC cannot successfully compile any code that includes "torch/extension.h". I am still trying to find the root cause, but the workaround code will be ready soon.

@Tai-Wang
Member

Tai-Wang commented Mar 1, 2021

OK. When your workaround is ready, I can help check it.

@zhanggefan
Contributor Author

This version compiles successfully with PyTorch 1.5 on CUDA 9.0 and passes the unit test on CUDA 10.2. But I have not been able to test its functionality with CUDA 9.0; installing all the dependencies is a nightmare for me...

@Tai-Wang
Member

Tai-Wang commented Mar 1, 2021

How do you test the functionality with CUDA 10? Do you just run some experiments using dynamic scatter? Could you please show an example or give me a standard benchmark, or add some simple tests so that we can validate it on CUDA 9 devices?

@zhanggefan
Contributor Author

The pytest script is here:
https://github.com/open-mmlab/mmdetection3d/blob/ccd3047a1d62048cc5707e60181b2ab586b8e479/tests/test_models/test_voxel_encoder/test_dynamic_scatter.py
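
For context, that test validates the op with torch.autograd.gradcheck, roughly as sketched below. This is only an illustrative sketch: the tensor shapes, dtypes, and the DynamicScatter arguments are placeholders rather than the values used in the actual test file.

    # Illustrative sketch of the gradcheck-based test; it assumes a CUDA build
    # of mmdet3d, and the shapes/arguments below are placeholders.
    import torch
    from torch.autograd import gradcheck
    from mmdet3d.ops import DynamicScatter

    # Random point features (floating point, differentiable) and integer voxel coordinates.
    feats = torch.rand(100, 4, dtype=torch.double, device='cuda', requires_grad=True)
    coors = torch.randint(0, 10, (100, 3), dtype=torch.int32, device='cuda')

    # Mean-reduction scatter; voxel size and point cloud range are placeholder values.
    dsmean = DynamicScatter([0.1, 0.1, 0.1], [-1, -1, -1, 1, 1, 1], average_points=True)

    # gradcheck perturbs `feats` numerically and compares the result against the
    # analytical gradient produced by the backward CUDA kernel.
    gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)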

@zhanggefan
Contributor Author

Please let me know if the test fails. I have no experience with torch extensions built against CUDA versions before 10, so I cannot guarantee that the code works as expected. For CUDA versions >= 10, torch/extension.h is the all-in-one header, but for CUDA 9 I have to fall back to the backend ATen headers, which I am not familiar with.

@Tai-Wang
Member

Tai-Wang commented Mar 2, 2021

The compilation and installation are fine, but the gradcheck in the unit test raises a RuntimeError. Here is the error message for your reference:

  File "tests/test_models/test_voxel_encoder/test_dynamic_scatter.py", line 59, in test_dynamic_scatter
    gradcheck(dsmean, (feats, coors), eps=1e-2, atol=1e-2, rtol=1e-5)
  File "/home/wangtai/anaconda3/envs/open-mmlab/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 281, in gradcheck
    analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o, nondet_tol=nondet_tol)
  File "/home/wangtai/anaconda3/envs/open-mmlab/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 156, in get_analytical_jacobian
    grads_input = torch.autograd.grad(output, diff_input_list, grad_output,
  File "/home/wangtai/anaconda3/envs/open-mmlab/lib/python3.8/site-packages/torch/autograd/__init__.py", line 156, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: Expected isFloatingType(grads[i].scalar_type()) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

@zhanggefan
Contributor Author

It looks like you switched to Python 3.8 this time. Which PyTorch and CUDA versions did you use this time?

@Tai-Wang
Member

Tai-Wang commented Mar 2, 2021

I think it should be compatible with both Python 3.7 and 3.8, so that is not a strong constraint on the environment. Apart from the Python version, we use the same CUDA version and the same PyTorch built from source.

@Tai-Wang
Member

Tai-Wang commented Mar 2, 2021

Just validated it with Python 3.7 and got the same error message.

…rk non-floating-point tensor as non-differentiable.
@zhanggefan
Contributor Author

Error reproduced. The issue is closely related to the following discussion, as well as an issue and a PR to PyTorch:
https://discuss.pytorch.org/t/custom-autograd-function-with-int-tensor-input/65907
pytorch/pytorch#37680
pytorch/pytorch#38326

Before that PR, PyTorch's autograd engine marked integer outputs as requiring grad by default (for example, marking indices tensors as requiring grad, which does not make sense in most cases). The PR deals with this and was merged into the master branch before PyTorch 1.6, so this issue cannot be reproduced with PyTorch 1.6 or later. Explicitly marking voxel_coors as non-differentiable should solve this issue.
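
For reference, the fix can be sketched roughly as below. This is a simplified illustration rather than the exact mmdet3d code; `_scatter_forward` and `_scatter_backward` are hypothetical stand-ins for the compiled CUDA ops.

    import torch
    from torch.autograd import Function


    class _DynamicScatter(Function):
        # Simplified sketch: exclude the integer output from autograd.

        @staticmethod
        def forward(ctx, feats, coors):
            # `_scatter_forward` is a stand-in for the compiled CUDA op; it returns
            # reduced floating-point voxel features plus integer voxel coordinates.
            voxel_feats, voxel_coors = _scatter_forward(feats, coors)
            ctx.save_for_backward(feats, coors, voxel_feats, voxel_coors)
            # On PyTorch < 1.6 the autograd engine may otherwise try to produce a
            # gradient for the integer `voxel_coors` output, which triggers the
            # `isFloatingType` RuntimeError seen in the gradcheck above.
            ctx.mark_non_differentiable(voxel_coors)
            return voxel_feats, voxel_coors

        @staticmethod
        def backward(ctx, grad_voxel_feats, grad_voxel_coors):
            feats, coors, voxel_feats, voxel_coors = ctx.saved_tensors
            # `_scatter_backward` is a stand-in for the compiled CUDA backward op
            # that routes each voxel gradient back to its contributing points.
            grad_feats = _scatter_backward(grad_voxel_feats, feats, coors,
                                           voxel_feats, voxel_coors)
            return grad_feats, None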

@zhanggefan
Contributor Author

@Tai-Wang
I managed to build an environment very similar to yours (a lot of labor, really... and I'm curious why your environment requires such a strange combination... hahaha):

  • 1080Ti==sm6.1
  • cuda==9.0
  • nvdriver==450.80.02
  • pytorch==1.5 built from source
  • torchvision==0.6 built from source
  • mmcv-full==1.2.5

The last commit passed the gradcheck without error.

@Tai-Wang
Member

Tai-Wang commented Mar 2, 2021

@zhanggefan Haha, it is just that many of the GPUs available to us are 1080 Tis, so we can only use CUDA 9 and PyTorch built from source. Thanks for your contribution. It also works on my side and looks good to me.

tpoisonooo pushed a commit to tpoisonooo/mmdetection3d that referenced this pull request Sep 5, 2022