ExpandableMemorySegments not working on H100s/A100s #122057
Not reproducible on A100 using the latest nightly binaries:

```
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' python -c "import torch; print(torch.randn(1).cuda()); print(torch.__version__)"
tensor([0.6288], device='cuda:0')
2.4.0.dev20240405+cu118
...
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' python -c "import torch; print(torch.randn(1).cuda()); print(torch.__version__)"
tensor([-1.6317], device='cuda:0')
2.4.0.dev20240405+cu121
```

Creating a profile for this simple code also shows the expected cuMem* driver API calls:

```
Time (%)  Total Time (ns)  Num Calls   Avg (ns)    Med (ns)   Min (ns)  Max (ns)  StdDev (ns)  Name
--------  ---------------  ---------  ----------  ----------  --------  --------  -----------  ---------------------------------
    75.0         85860531         20   4293026.6   4443887.0      6863  16916216    4680471.9  cudaLaunchKernel
    22.5         25785597          9   2865066.3   2258438.0   2142408   6395693    1426342.9  cudaGetDeviceProperties_v2_v12000
     1.6          1879258          2    939629.0    939629.0     13465   1865793    1309793.7  cudaHostAlloc
     0.4           505239          1    505239.0    505239.0    505239    505239          0.0  cuMemSetAccess
     0.1           136748          8     17093.5     18715.5      5721     33353      10284.1  cudaMemcpyAsync
     0.1           130117          1    130117.0    130117.0    130117    130117          0.0  cuMemCreate
     0.1            82436          8     10304.5     11296.5      3757     16101       3873.2  cudaStreamSynchronize
     0.0            27362          1     27362.0     27362.0     27362     27362          0.0  cuMemAddressReserve
     0.0             6192          1      6192.0      6192.0      6192      6192          0.0  cuMemMap
     0.0             2756          1      2756.0      2756.0      2756      2756          0.0  cuModuleGetLoadingMode
```

Calls without setting the env variable:

```
Time (%)  Total Time (ns)  Num Calls   Avg (ns)    Med (ns)   Min (ns)  Max (ns)  StdDev (ns)  Name
--------  ---------------  ---------  ----------  ----------  --------  --------  -----------  ---------------------------------
    77.4         76950244         20   3847512.2   4424450.0      6813  13703647    3900974.4  cudaLaunchKernel
    19.8         19650438          8   2456304.8   2223391.0   2065391   3934390     614051.4  cudaGetDeviceProperties_v2_v12000
     1.6          1585640          2    792820.0    792820.0     19066   1566574    1094253.4  cudaHostAlloc
     1.0           997744          1    997744.0    997744.0    997744    997744          0.0  cudaMalloc
     0.1           125379          8     15672.4     16461.5      5230     31550       9835.2  cudaMemcpyAsync
     0.1            80774          8     10096.8     11116.0      3386     15860       3699.4  cudaStreamSynchronize
     0.0            14888          1     14888.0     14888.0     14888     14888          0.0  cudaStreamIsCapturing_v10000
     0.0             2705          1      2705.0      2705.0      2705      2705          0.0  cuModuleGetLoadingMode
```
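For context, the cuMem* entries in the first table are the CUDA virtual memory management (VMM) driver APIs that expandable segments is built on. A minimal sketch of that sequence — illustrative structure only, error checking omitted, and it assumes a current CUDA context — looks like this:

```cpp
#include <cuda.h>
#include <cstddef>

// Reserve virtual address space, create physical backing, map it, and enable
// access: the cuMem* sequence visible in the first profile above.
// (Error checking omitted; assumes cuInit() has run and a context is current.)
CUdeviceptr reserve_and_map(size_t size, int device) {
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = device;

  // Sizes must be multiples of the allocation granularity.
  size_t gran = 0;
  cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
  size = (size + gran - 1) / gran * gran;

  CUdeviceptr addr = 0;
  cuMemAddressReserve(&addr, size, /*alignment=*/0, /*addr=*/0, /*flags=*/0);

  CUmemGenericAllocationHandle handle;
  cuMemCreate(&handle, size, &prop, /*flags=*/0);    // physical allocation
  cuMemMap(addr, size, /*offset=*/0, handle, /*flags=*/0);

  CUmemAccessDesc access = {};
  access.location = prop.location;
  access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  cuMemSetAccess(addr, size, &access, /*count=*/1);  // make range read/write
  return addr;
}
```

The point of this design is that more physical memory can later be created and mapped into the same reserved range, so a segment can grow in place; a plain cudaMalloc (seen in the second table instead) cannot.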
@ptrblck @chauhang I've also encountered this issue; below is my reproduction command:
Output:
It seems that this issue only occurs during the initialization of the process group. @ptrblck's command above runs fine.
This warning appears to originate from here: ProcessGroupNCCL.cpp#L777. This area specifically deals with the logic for 'expandable_segments', but it's possible that the macro guarding it is not defined in that translation unit.
A colleague at work mentioned also seeing this issue. I think @yitongh is right. To elaborate, I believe the problem is the following: expandable_segments() is both declared and defined in a header file, https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAAllocatorConfig.h#L26. However, CUDAAllocatorConfig.h is included inconsistently: in some places PYTORCH_C10_DRIVER_API_SUPPORTED is defined, and in others it is not. That macro is added to the c10_cuda library in pytorch/c10/cuda/CMakeLists.txt, line 64 (at commit 0d2ac97).

However, ProcessGroupNCCL.cpp is not part of the c10_cuda library, so it does not have access to this macro. I confirmed this by looking at my compile_commands.json file:

My hunch is that applying this diff will fix the issue:

However, the current way of doing things is error prone. Maybe the definition of expandable_segments() should be moved to the cpp file instead, or the macro should be marked PUBLIC. Personally, I would recommend moving the definition out of the header file; I don't know how far the PYTORCH_C10_DRIVER_API_SUPPORTED macro would propagate transitively if it were set to PUBLIC. As for whether this is a serious bug, I'm not sure it will cause a crash. Unfortunately, I don't know what the purpose of useTensorRegisterAllocatorHook_ is, or whether useTensorRegisterAllocatorHook_=false is compatible (intentionally or not!) with expandable_segments=true.
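To make the failure mode concrete, here is a minimal sketch — hypothetical file and function names, not the actual PyTorch sources — of how an inline function defined in a header, combined with a compile definition applied to only one target, produces exactly this symptom:

```cpp
// allocator_config.h -- hypothetical sketch of the hazard described above:
// an inline function whose body depends on a compile-time macro.
#pragma once

inline bool expandable_segments_supported() {
#ifdef DRIVER_API_SUPPORTED
  return true;   // translation units built with -DDRIVER_API_SUPPORTED get this...
#else
  return false;  // ...while those built without the flag silently get this.
#endif
}

// A target compiled with the definition (like c10_cuda) sees `true`;
// a target compiled without it (like the one containing ProcessGroupNCCL.cpp)
// sees `false` and reports "not supported" on a perfectly capable platform.
```

Strictly speaking, two differing bodies for the same inline function across translation units is also an ODR violation, so which definition the linker ultimately keeps is unspecified; moving the definition into a single .cpp file would eliminate both problems.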
cc @zdevito
@michaelroyzen as long as you don't set the environment variable TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true, this should not affect you.
I looked into this warning a bit more. Basically, the code is intended to disable NVLink SHARP support when expandable segments is turned on, even when the user requests it (via TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true): for reasons having to do with memory page alignment, tensors allocated via the expandable segments option cannot always be used with NVLink SHARP. The code fails to do this correctly right now, but I don't see how that is a problem unless you set TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true.
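The intended logic reads roughly like the sketch below; getEnvBool and expandableSegmentsEnabled are hypothetical stand-ins, not PyTorch's actual helpers:

```cpp
#include <cstdlib>
#include <cstring>
#include <cstdio>

// Hypothetical stand-ins for this sketch (not PyTorch's real helpers).
static bool getEnvBool(const char* name) {
  const char* v = std::getenv(name);
  return v && std::strcmp(v, "true") == 0;
}
static bool expandableSegmentsEnabled() { return true; }  // stub

void configureRegisterHook() {
  // Honor the user's request only when expandable segments is off:
  // expandable-segments tensors may not meet NVLink SHARP's page-alignment
  // requirements, so the register hook must be forced off in that case.
  bool useTensorRegisterAllocatorHook =
      getEnvBool("TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK");
  if (useTensorRegisterAllocatorHook && expandableSegmentsEnabled()) {
    useTensorRegisterAllocatorHook = false;
    std::fprintf(stderr,
                 "warning: expandable_segments is enabled; disabling the "
                 "tensor register allocator hook\n");
  }
  // useTensorRegisterAllocatorHook would then drive NCCL buffer
  // registration (NVLink SHARP) for the process group.
}
```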
What's the support matrix for expandable segments? I keep getting denied with "not supported on this platform" using the EasyDiffusion and EasyTraining GUIs, but it doesn't state whether I'm missing hardware or software support.
@galv I am not explicitly setting TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true. Expandable segments simply stopped working in PyTorch 2.2 due to the refactor https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAAllocatorConfig.h#L28. PyTorch 2.1.2 is the last version that works for me with expandable segments -- upgrading to 2.2+ gives this warning and expandable segments are not enabled (and I get OOMs).
I have the same experience. |
For those hitting this issue recently, it would help to get an environment report (e.g. the output of `python -m torch.utils.collect_env`).
Alright, here's mine. Not sure why it can't get the cuDNN version; it should be there.
I don't think this is the actual reason. I build PyTorch from source and …
I agree with @InhabitancyCocoon. I strongly suspect this is a bug in the refactor of expandable segments between 2.1.2 and 2.2.0. The official PyPI 2.1.2+cu121 works, but 2.2.0+cu121 and later do not.
Our way of accessing driver APIs doesn't work on Windows, so expandable segments don't work there either; I expect that is why this particular instance of the error shows up. However, for @InhabitancyCocoon and @michaelroyzen I suspect this is another manifestation of the error, due to something else. Since I can build PyTorch from source and have expandable segments available, I suspect something changed between 2.1.2 and 2.2.0 that altered driver API detection, but only in certain setups.
My setup is 8x H100, Ubuntu 22.04, CUDA 12.1 with drivers installed, and the official PyTorch 2.x+cu121 build from PyPI.
And what specifically is the warning? "expandable_segments not supported on this platform" should only show up if the torch binary was built in a way that didn't define PYTORCH_C10_DRIVER_API_SUPPORTED. It doesn't check anything about the machine.
Correct, "expandable_segments not supported on this platform" is the warning, which would then imply that the official PyPI packages are not being built correctly. All I know is that the refactor between 2.1.2 and 2.2.0 triggered this issue.
I tried
which is weird, because it suggests that the package is built with it enabled.
My experience is:
I second that. The sanity checks above all pass without any warnings, but real training causes problems.
Can someone give the backtrace you see during training?
Inspired by @galv, setting the macro PUBLIC in the CMakeLists.txt (i.e. marking the PYTORCH_C10_DRIVER_API_SUPPORTED compile definition PUBLIC instead of PRIVATE) and rebuilding PyTorch from source seems to solve my problem.
But I'm not sure whether this hack will cause unexpected errors in the future.
As mentioned by my colleague, I think the fact that the sanity check works but real training fails has something to do with …
Some are likely seeing this through NCCL. Opened #137828 to fix this.
Yep, can confirm that I am using NCCL when I see these errors. Thanks for finding it, @eqy.
Does this mean SHARP is not supported when this env var is enabled?
It seems this error occurs in …
In 2.6.0, this problem no longer appears. |
🐛 Describe the bug
On H100 and A100 instances, setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' is not taking effect with the latest nightlies.
The warning comes from: https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAAllocatorConfig.h#L28
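For reference, the check at that line behaves roughly as in the sketch below (paraphrased from the linked header, not the verbatim source): the accessor honors the parsed expandable_segments setting only if PYTORCH_C10_DRIVER_API_SUPPORTED was defined when the including translation unit was compiled, and warns otherwise.

```cpp
// Paraphrased sketch of the expandable_segments() accessor in
// c10/cuda/CUDAAllocatorConfig.h (not the verbatim PyTorch source).
static bool expandable_segments() {
#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
  if (instance().m_expandable_segments) {
    TORCH_WARN_ONCE("expandable_segments not supported on this platform");
  }
  return false;  // built without driver-API support: feature forced off
#else
  return instance().m_expandable_segments;  // honor the parsed setting
#endif
}
```

Note that the check is purely compile-time: nothing about the running machine is queried, which is why a wrongly propagated macro (rather than actual hardware capability) can trigger the warning.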
Versions
H100 machine:
A100 machine:
cc @ptrblck