
CUDA out of memory for CenterPoint even with batch size 1. #1398

Closed
Divadi opened this issue Apr 14, 2022 · 15 comments

Divadi (Contributor) commented Apr 14, 2022

I tried to train CenterPoint with the default config and reduced the batch size when I ran out of CUDA memory. However, even with batch size 1 on a 10 GB GPU, I still run out of CUDA memory, which I think is strange.

Describe the bug
I encounter a CUDA out-of-memory error when training CenterPoint on RTX 3080 GPUs (10 GB), even with batch size 1.

Reproduction

  1. What command or script did you run?
./tools/dist_train.sh configs/centerpoint/centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py 2 --autoscale-lr --gpus 2
  2. Did you make any modifications on the code or config? Did you understand what you have modified?
    No changes.

  3. What dataset did you use?
    nuScenes, no changes.

Environment

  1. Please run python mmdet3d/utils/collect_env.py to collect necessary environment information and paste it here.
sys.platform: linux
Python: 3.7.13 (default, Mar 29 2022, 02:18:16) [GCC 7.5.0]
CUDA available: True
GPU 0,1: GeForce RTX 3080
CUDA_HOME: /usr/local/cuda-11.1
NVCC: Build cuda_11.1.TC455_06.29190527_0
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.10.0
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.2.3 (Git Hash 7336ca9f055cf1bfa13efb658fe15dc9b41f0740)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.3
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.2
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.2.0, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.10.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 

TorchVision: 0.11.1
OpenCV: 4.5.5
MMCV: 1.4.8
MMCV Compiler: GCC 7.3
MMCV CUDA Compiler: 11.3
MMDetection: 2.19.0
MMSegmentation: 0.20.1
MMDetection3D: 1.0.0rc1+333536f
  2. You may add additional information that may be helpful for locating the problem, such as
    • How you installed PyTorch [e.g., pip, conda, source]: conda
    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

Error traceback
If applicable, paste the error traceback here.

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 9.78 GiB total capacity; 1.43 GiB already allocated; 6.44 MiB free; 1.52 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Divadi (Contributor, Author) commented Apr 14, 2022

I was able to run a few iterations using 1 batch size, 1 GPU. However, the loss seems unreasonably high:

2022-04-14 08:30:04,262 - mmdet - INFO - Epoch [1][50/123580]	lr: 3.125e-06, eta: 12 days, 13:00:14, time: 0.438, data_time: 0.106, memory: 1394, task0.loss_heatmap: 876.6298, task0.loss_bbox: 1.9027, task1.loss_heatmap: 763.6281, task1.loss_bbox: 2.0514, task2.loss_heatmap: 1303.4208, task2.loss_bbox: 2.4338, task3.loss_heatmap: 1390.9180, task3.loss_bbox: 1.5618, task4.loss_heatmap: 505.8650, task4.loss_bbox: 1.7147, task5.loss_heatmap: 964.0964, task5.loss_bbox: 1.9008, loss: 5816.1234, grad_norm: 143364.3639
2022-04-14 08:30:20,176 - mmdet - INFO - Epoch [1][100/123580]	lr: 3.125e-06, eta: 10 days, 19:44:54, time: 0.318, data_time: 0.002, memory: 1394, task0.loss_heatmap: 1001.6220, task0.loss_bbox: 1.9754, task1.loss_heatmap: 774.8115, task1.loss_bbox: 1.9916, task2.loss_heatmap: 870.7324, task2.loss_bbox: 2.4756, task3.loss_heatmap: 1200.6614, task3.loss_bbox: 1.6422, task4.loss_heatmap: 472.1978, task4.loss_bbox: 1.5901, task5.loss_heatmap: 1012.5033, task5.loss_bbox: 1.7780, loss: 5343.9811, grad_norm: 123658.1527

I am not sure what the cause of this is - was CenterPoint affected by the v1 update?

Regarding the GPU memory (GRAM) issue, it was pointed out previously in #983 that the 4096 parameter in get_indice_pairs causes a huge increase in the amount of GRAM consumed. Previously, I would simply modify this value in the mmdet3d ops and rebuild, reducing GRAM usage by a few GB. Now that everything lives in mmcv, however, this is more difficult.

I also tried running the above setup with spconv v2. With it, I can comfortably fit a batch size of 4 in 10 GB (8 GB used). With FP16 on top of that, I can fit a batch size of 8 in 10 GB (9 GB used). Are there plans to integrate spconv v2 into mmdetection3d?
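
For reference, FP16 in mmdet3d is normally enabled through the standard mmdet mixed-precision hook; a minimal config override of that kind looks roughly like the sketch below (the loss_scale value is only an illustrative choice, not something tuned for CenterPoint):

```python
# fp16 override sketch -- inherits the stock CenterPoint config and adds the
# fp16 dict that mmdet's training APIs turn into an Fp16OptimizerHook.
_base_ = ['./centerpoint_01voxel_second_secfpn_4x8_cyclic_20e_nus.py']

# A fixed loss scale (e.g. 512.) or 'dynamic' also works; 32. is just an example.
fp16 = dict(loss_scale=32.)
```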

ammaryasirnaich commented:

@Divadi Can you share the steps you took to modify the mmdet3d ops and rebuild them with spconv v2? It would be a big help for me.

Divadi (Contributor, Author) commented Apr 14, 2022

Of course - it's a bit hacky and I have not confirmed it yet, but here it is.
Note that this is with mmcv 1.4.8 and the newest version of mmdetection3d.

  1. Installed spconv v2 via pip.
  2. Changed mmdet3d/ops/sparse_block.py, mmdet3d/models/middle_encoders/sparse_encoder.py, and mmdet3d/models/middle_encoders/sparse_unet.py to use spconv v2 when the environment variable USE_SPCONV_v2=1 is set.
  3. The changes can be found in the zip:
changed_files.zip

The changed files also include an "overwrite_spconv" module - just put it anywhere and import it in your train/test script. It hackily removes the mmcv spconv modules from the registry and replaces them with the spconv v2 versions, so they can still be built through the registry functions in mmdet3d; a rough sketch of the idea is below.

This basically does not use the spconv port in mmcv at all and instead uses spconv v2.
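
To make the idea concrete, here is a minimal sketch of what such an "overwrite_spconv" module can look like (an illustration of the approach only, not the exact file in the zip; it assumes spconv 2.x is installed, e.g. via pip install spconv-cu111, alongside mmcv-full 1.4.x):

```python
# overwrite_spconv.py -- sketch of re-registering spconv v2 layers over the mmcv ones.
from mmcv.cnn.bricks.registry import CONV_LAYERS
from spconv.pytorch import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
                            SparseConvTranspose3d, SparseInverseConv2d,
                            SparseInverseConv3d, SubMConv2d, SubMConv3d)


def register_spconv2():
    """Register spconv v2 layers under the names mmdet3d looks up in CONV_LAYERS.

    force=True replaces the mmcv.ops sparse conv classes registered under the
    same names, so the build-from-registry calls (e.g. mmcv's build_conv_layer)
    in the sparse blocks and encoders return spconv v2 modules instead.
    """
    for module in (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
                   SparseConvTranspose3d, SparseInverseConv2d,
                   SparseInverseConv3d, SubMConv2d, SubMConv3d):
        CONV_LAYERS.register_module(name=module.__name__, force=True, module=module)


# Import and call this once at the top of your train/test script:
# from overwrite_spconv import register_spconv2; register_spconv2()
```

On top of this, the middle encoders have to construct spconv.pytorch.SparseConvTensor instead of the mmcv one (that is what the edited sparse_encoder.py / sparse_unet.py handle), and checkpoints trained with the mmcv implementation may need their sparse conv weights converted to spconv v2's layout.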

ammaryasirnaich commented:

Much appreciated @Divadi, I will give it a try. Many thanks for sharing the files 👍

Tai-Wang (Member) commented:

Great solution! We won't consider having v2 directly replace v1 because we still want to support CUDA 9, so the temporary plan is to sort out a relatively elegant way to use spconv v2.0 and add that solution to our documentation for reference. Many thanks for sharing.

Tai-Wang (Member) commented:

If anyone from the community has a better solution, you are welcome to share it in this issue. We will also give it a try and update the documentation accordingly as soon as possible.

Tai-Wang (Member) commented:

BTW, @Divadi, have you ever successfully trained this CenterPoint config before the release of mmdet3d v1.0.0rc? We need to check whether our refactoring affected it.

Divadi (Contributor, Author) commented Apr 16, 2022

I have not tried; I can try pulling an older version later and training it.

ammaryasirnaich commented:

> Much appreciated @Divadi, I will give it a try. Many thanks for sharing the files 👍

@Divadi, I have managed to run the workflow with the files that you shared, and it's working. 💯

ammaryasirnaich commented:

My development environment is as follows:

GPU 0: NVIDIA GeForce RTX 3080
CUDA_HOME: /usr/local/cuda
NVCC: Build cuda_11.3.r11.3/compiler.29920130_0
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
PyTorch: 1.11.0
PyTorch compiling details: PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.5.2 (Git Hash a9302535553c73243c632ad3c4c80beec3d19a1e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.3
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.2
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.2.0, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

TorchVision: 0.12.0
OpenCV: 4.5.5
MMCV: 1.4.8
MMCV Compiler: GCC 7.5
MMCV CUDA Compiler: 11.3
MMDetection: 2.22.0
MMSegmentation: 0.22.1
MMDetection3D: 1.0.0rc0+

Divadi (Contributor, Author) commented Apr 16, 2022

That's great - if possible, could you share whether the performance is as you would expect?
I'm not sure how different spconv 1 and 2 are, theoretically.

ammaryasirnaich commented:

> That's great - if possible, could you share whether the performance is as you would expect? I'm not sure how different spconv 1 and 2 are, theoretically.

Well, to be honest, I couldn't notice any significant decrease in memory, and I think the performance is almost the same. This could probably be confirmed with the TensorBoard profiler; I tried to use the profiler logger, but I think there is a bug and it does not log the details. Also, I am not aware of the theoretical difference between spconv 1 and 2 myself.
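
Lacking the profiler, a rough way to compare the two setups is to record peak GPU memory around a few training iterations; a small sketch (the iteration loop is only a placeholder for whatever training step you run):

```python
import torch

# Reset the peak-memory counter, run a few iterations, then read the peak back.
torch.cuda.reset_peak_memory_stats()

# ... run a handful of training iterations here ...

peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
print(f"peak memory allocated by tensors: {peak_gib:.2f} GiB")
```

Comparing this number (and the wall-clock time per iteration) between the mmcv spconv path and the spconv v2 path should make the difference visible.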

Divadi (Contributor, Author) commented Apr 19, 2022

@ammaryasirnaich That's strange; are you sure the spconv 2 version is the one actually being used? Compared to the implementation in mmcv, spconv 2 should use substantially less memory and be noticeably faster for models with an spconv backbone.
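
One quick way to check is to look at which module a sparse conv layer in the built model comes from; something along these lines (a sketch only, assuming the model object has already been built, e.g. with mmdet3d's build_model):

```python
import os

# The patched code path is only taken when the toggle is set.
print("USE_SPCONV_v2 =", os.environ.get("USE_SPCONV_v2"))

# Inspect the first sparse conv layer found in the (already built) model:
#   'spconv.pytorch.conv'  -> spconv v2 is in use
#   'mmcv.ops.sparse_conv' -> the mmcv port is in use
for name, module in model.named_modules():
    if type(module).__name__ in ("SparseConv3d", "SubMConv3d"):
        print(name, type(module).__module__)
        break
```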

ammaryasirnaich commented:

@Divadi, your impression about the wrong spconv version was right. It was my mistake: a typo in the environment variable caused the spconv 2 import to be skipped. After correcting it, I saw a significant decrease in memory allocation, as you mentioned. Many thanks for pointing it out.

Tai-Wang (Member) commented:

Thanks for the discussion. We have decided to officially support the spconv 2.0 option via #1421. Please stay tuned.
