ExpandableMemorySegments not working on H100s/A100s #122057

Open
chauhang opened this issue Mar 17, 2024 · 30 comments
Labels
module: cuda Related to torch.cuda, and CUDA support in general module: CUDACachingAllocator triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@chauhang
Contributor

chauhang commented Mar 17, 2024

🐛 Describe the bug

On H100 and A100 instances, setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' is not taking effect with the latest nightlies.

[W CUDAAllocatorConfig.h:28] Warning: expandable_segments not supported on this platform (function operator())

Warning is coming from: https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAAllocatorConfig.h#L28
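
A quick way to confirm from Python whether the setting was rejected: a minimal sketch, assuming the allocator warning surfaces as a Python UserWarning on first CUDA use (later traces in this thread show it does):

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # must be set before the allocator is first used

import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.randn(1, device="cuda")  # first allocation parses the allocator config

rejected = any("expandable_segments not supported" in str(w.message) for w in caught)
print("expandable_segments rejected:", rejected)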

Versions

H100 machine:

<frozen runpy>:128: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
Collecting environment information...
PyTorch version: 2.4.0.dev20240317+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
Clang version: Could not collect
CMake version: version 3.26.5
Libc version: glibc-2.34

Python version: 3.11.8 (main, Feb 26 2024, 21:39:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-0_fbk12_zion_11583_g0bef9520ca2b-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100
GPU 1: NVIDIA H100
GPU 2: NVIDIA H100
GPU 3: NVIDIA H100

Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.8.0
/usr/lib64/libcudnn_adv_infer.so.8.8.0
/usr/lib64/libcudnn_adv_train.so.8.8.0
/usr/lib64/libcudnn_cnn_infer.so.8.8.0
/usr/lib64/libcudnn_cnn_train.so.8.8.0
/usr/lib64/libcudnn_ops_infer.so.8.8.0
/usr/lib64/libcudnn_ops_train.so.8.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   52 bits physical, 57 bits virtual
Byte Order:                      Little Endian
CPU(s):                          184
On-line CPU(s) list:             0-183
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 9654 96-Core Processor
CPU family:                      25
Model:                           17
Thread(s) per core:              1
Core(s) per socket:              184
Socket(s):                       1
Stepping:                        1
BogoMIPS:                        4792.78
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm arch_capabilities
Virtualization:                  AMD-V
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       11.5 MiB (184 instances)
L1i cache:                       11.5 MiB (184 instances)
L2 cache:                        92 MiB (184 instances)
L3 cache:                        2.9 GiB (184 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-183
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Vulnerable, IBPB: conditional, IBRS_FW, STIBP: disabled, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.4.0.dev20240317+cu121
[pip3] torchao-nightly==2024.3.17
[pip3] torchaudio==2.2.0.dev20240317+cu121
[pip3] torchtune==0.0.1
[pip3] torchvision==0.18.0.dev20240317+cu121
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pytorch-triton            3.0.0+989adb9a29          pypi_0    pypi
[conda] torch                     2.4.0.dev20240317+cu121          pypi_0    pypi
[conda] torchao-nightly           2024.3.17                pypi_0    pypi
[conda] torchaudio                2.2.0.dev20240317+cu121          pypi_0    pypi
[conda] torchtune                 0.0.1                    pypi_0    pypi
[conda] torchvision               0.18.0.dev20240317+cu121          pypi_0    pypi

A100 machine:

python -m torch.utils.collect_env
<frozen runpy>:128: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
Collecting environment information...
PyTorch version: 2.4.0.dev20240317+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.27.7
Libc version: glibc-2.31

Python version: 3.11.8 (main, Feb 26 2024, 21:39:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 535.104.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             96
On-line CPU(s) list:                0-95
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              85
Model name:                         Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping:                           7
CPU MHz:                            3599.998
BogoMIPS:                           6000.00
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          1.5 MiB
L1i cache:                          1.5 MiB
L2 cache:                           48 MiB
L3 cache:                           71.5 MiB
NUMA node0 CPU(s):                  0-23,48-71
NUMA node1 CPU(s):                  24-47,72-95
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        KVM: Mitigation: VMX unsupported
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.4.0.dev20240317+cu118
[pip3] torchao-nightly==2024.3.17
[pip3] torchtune==0.0.1
[pip3] torchvision==0.18.0.dev20240317+cu118
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] pytorch-triton            3.0.0+989adb9a29          pypi_0    pypi
[conda] torch                     2.4.0.dev20240317+cu118          pypi_0    pypi
[conda] torchao-nightly           2024.3.17                pypi_0    pypi
[conda] torchtune                 0.0.1                    pypi_0    pypi
[conda] torchvision               0.18.0.dev20240317+cu118          pypi_0    pypi

cc @ptrblck

@chauhang chauhang changed the title from "ExpandableMemorySegments not working on H100s" to "ExpandableMemorySegments not working on H100s/A100s" Mar 17, 2024
@bdhirsh bdhirsh added module: cuda, module: CUDACachingAllocator, and triaged labels Mar 19, 2024
@ptrblck
Collaborator

ptrblck commented Apr 5, 2024

Not reproducible on A100 using the latest nightly binaries:

PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' python -c "import torch; print(torch.randn(1).cuda()); print(torch.__version__)"
tensor([0.6288], device='cuda:0')
2.4.0.dev20240405+cu118
...
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' python -c "import torch; print(torch.randn(1).cuda()); print(torch.__version__)"
tensor([-1.6317], device='cuda:0')
2.4.0.dev20240405+cu121

Creating a profile for this simple code also shows the expected cuMemCreate etc. usage:

 Time (%)  Total Time (ns)  Num Calls  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                Name               
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ---------------------------------
     75.0         85860531         20  4293026.6  4443887.0      6863  16916216    4680471.9  cudaLaunchKernel                 
     22.5         25785597          9  2865066.3  2258438.0   2142408   6395693    1426342.9  cudaGetDeviceProperties_v2_v12000
      1.6          1879258          2   939629.0   939629.0     13465   1865793    1309793.7  cudaHostAlloc                    
      0.4           505239          1   505239.0   505239.0    505239    505239          0.0  cuMemSetAccess                   
      0.1           136748          8    17093.5    18715.5      5721     33353      10284.1  cudaMemcpyAsync                  
      0.1           130117          1   130117.0   130117.0    130117    130117          0.0  cuMemCreate                      
      0.1            82436          8    10304.5    11296.5      3757     16101       3873.2  cudaStreamSynchronize            
      0.0            27362          1    27362.0    27362.0     27362     27362          0.0  cuMemAddressReserve              
      0.0             6192          1     6192.0     6192.0      6192      6192          0.0  cuMemMap                         
      0.0             2756          1     2756.0     2756.0      2756      2756          0.0  cuModuleGetLoadingMode         

Calls without setting the PYTORCH_CUDA_ALLOC_CONF env variable, as a reference:

 Time (%)  Total Time (ns)  Num Calls  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                Name               
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ---------------------------------
     77.4         76950244         20  3847512.2  4424450.0      6813  13703647    3900974.4  cudaLaunchKernel                 
     19.8         19650438          8  2456304.8  2223391.0   2065391   3934390     614051.4  cudaGetDeviceProperties_v2_v12000
      1.6          1585640          2   792820.0   792820.0     19066   1566574    1094253.4  cudaHostAlloc                    
      1.0           997744          1   997744.0   997744.0    997744    997744          0.0  cudaMalloc                       
      0.1           125379          8    15672.4    16461.5      5230     31550       9835.2  cudaMemcpyAsync                  
      0.1            80774          8    10096.8    11116.0      3386     15860       3699.4  cudaStreamSynchronize            
      0.0            14888          1    14888.0    14888.0     14888     14888          0.0  cudaStreamIsCapturing_v10000     
      0.0             2705          1     2705.0     2705.0      2705      2705          0.0  cuModuleGetLoadingMode    
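
For reference, a CUDA API summary like the tables above can be captured with Nsight Systems (assuming nsys is installed; exact flags may vary by version):

PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' nsys profile --stats=true python -c "import torch; print(torch.randn(1).cuda())"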

@yitongh

yitongh commented May 14, 2024

@ptrblck @chauhang I've also encountered this issue; below is my reproduction command:

docker run --gpus all --rm -t --ipc=host pytorch/pytorch:latest bash -c "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True RANK=0 WORLD_SIZE=1 MASTER_ADDR=127.0.0.1 MASTER_PORT=9001 python -c \"import torch; torch.distributed.init_process_group(backend='nccl')\""

Output:

[W CUDAAllocatorConfig.h:30] Warning: expandable_segments not supported on this platform (function operator())

It seems that this issue only occurs during the initialization of the process group. @ptrblck's command above runs fine.
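
For anyone without Docker, the same repro as a standalone sketch (assumes at least one visible GPU; the env values mirror the command above and are set before torch is imported):

import os

# set before importing torch so the allocator and ProcessGroupNCCL see the same config
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "9001")

import torch.distributed as dist

dist.init_process_group(backend="nccl")  # the warning fires here, not at the first allocation
dist.destroy_process_group()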

@yitongh

yitongh commented May 14, 2024

This warning appears to originate from here: ProcessGroupNCCL.cpp#L777. This area specifically deals with the logic for 'expandable_segments', but it's possible that the macro PYTORCH_C10_DRIVER_API_SUPPORTED was not set during the compilation of this file, leading to the allocator and NCCL getting different results.

@galv
Collaborator

galv commented May 23, 2024

A colleague at work mentioned also seeing this issue. I think @yitongh is right. To elaborate, I believe the problem is the following:

Basically, expandable_segments() is both declared and defined in a header file, https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAAllocatorConfig.h#L26. However, CUDAAllocatorConfig.h is included both in places where PYTORCH_C10_DRIVER_API_SUPPORTED is defined and in places where it is not. That macro is added to the c10_cuda library in c10/cuda/CMakeLists.txt with:

target_compile_options(c10_cuda PRIVATE "-DPYTORCH_C10_DRIVER_API_SUPPORTED")

However, ProcessGroupNCCL.cpp is not part of the c10_cuda library, so it does not have access to this macro. I confirmed this by looking at my compile_commands.json file:

{
  "directory": "/home/dgalvez/scratch/code/asr/pytorch/build",
  "command": "/usr/bin/c++ -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS\
 -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DUSE_C10D_GLOO -DUSE_C10D_MPI -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_MEM_EFF_\
ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/dgalvez/scratch/code/asr/pytorch/build/aten/src -I/home/dgalvez/scratch/code/asr/pytorch/aten/src -I/ho\
me/dgalvez/scratch/code/asr/pytorch/build -I/home/dgalvez/scratch/code/asr/pytorch -I/home/dgalvez/scratch/code/asr/pytorch/cmake/../third_party/benchmark/include -I/home/dgalvez/scratch/code/asr/pytorch\
/third_party/onnx -I/home/dgalvez/scratch/code/asr/pytorch/build/third_party/onnx -I/home/dgalvez/scratch/code/asr/pytorch/third_party/foxi -I/home/dgalvez/scratch/code/asr/pytorch/build/third_party/foxi\
 -I/home/dgalvez/scratch/code/asr/pytorch/aten/src/THC -I/home/dgalvez/scratch/code/asr/pytorch/aten/src/ATen/cuda -I/home/dgalvez/scratch/code/asr/pytorch/aten/src/ATen/../../../third_party/cutlass/incl\
ude -I/home/dgalvez/scratch/code/asr/pytorch/build/caffe2/aten/src -I/home/dgalvez/scratch/code/asr/pytorch/aten/src/ATen/.. -I/home/dgalvez/scratch/code/asr/pytorch/build/nccl/include -I/home/dgalvez/sc\
ratch/code/asr/pytorch/c10/cuda/../.. -I/home/dgalvez/scratch/code/asr/pytorch/c10/.. -I/home/dgalvez/scratch/code/asr/pytorch/third_party/tensorpipe -I/home/dgalvez/scratch/code/asr/pytorch/build/third_\
party/tensorpipe -I/home/dgalvez/scratch/code/asr/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/dgalvez/scratch/code/asr/pytorch/torch/csrc/api -I/home/dgalvez/scratch/code/asr/pytorc\
h/torch/csrc/api/include -isystem /home/dgalvez/scratch/code/asr/pytorch/build/third_party/gloo -isystem /home/dgalvez/scratch/code/asr/pytorch/cmake/../third_party/gloo -isystem /home/dgalvez/scratch/co\
de/asr/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/dgalvez/scratch/code/asr/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/dgalvez/scratch/\
code/asr/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/dgalvez/scratch/code/asr/pytorch/third_party/protobuf/src -isystem /home/dgalvez/scratch/code/asr/pytorch/third_party/ge\
mmlowp -isystem /home/dgalvez/scratch/code/asr/pytorch/third_party/neon2sse -isystem /home/dgalvez/scratch/code/asr/pytorch/third_party/XNNPACK/include -isystem /home/dgalvez/scratch/code/asr/pytorch/thi\
rd_party/ittapi/include -isystem /home/dgalvez/scratch/code/asr/pytorch/cmake/../third_party/eigen -isystem /home/dgalvez/scratch/cuda/cuda-12.4/include -isystem /home/dgalvez/scratch/code/asr/pytorch/to\
rch/include -isystem /home/dgalvez/scratch/code/asr/pytorch/third_party/ideep/include -isystem /home/dgalvez/scratch/code/asr/pytorch/torch/include/oneapi/dnnl -isystem /usr/local/cuda/include  -D_GLIBCX\
X_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -\
fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unk\
nown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=ol\
d-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow -D\
HAVE_AVX512_CPU_DEFINITION -DHAVE_AVX2_CPU_DEFINITION -g -fno-omit-frame-pointer -O0 -std=gnu++17 -fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-unused-fu\
nction -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-type-limits -Wno-array-bounds -Wno-strict-overflow -Wno-strict-aliasing -Wno-maybe-uninitialized -fvisibility=hidden -Wno-deprecated-copy \
-o caffe2/CMakeFiles/torch_cuda.dir/__/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp.o -c /home/dgalvez/scratch/code/asr/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
  "file": "/home/dgalvez/scratch/code/asr/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
  "output": "caffe2/CMakeFiles/torch_cuda.dir/__/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp.o"
},

My hunch is that applying this diff will fix the issue:

modified   caffe2/CMakeLists.txt
@@ -614,10 +614,15 @@ if(USE_CUDA)                                                                                                                                                                          
    append_filelist("libtorch_cuda_distributed_base_sources" Caffe2_GPU_SRCS)                                                                                                                              
    if(NOT WIN32)                                                                                                                                                                                          
      append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)                                                                                                                           
+      # It is very strange that this compile flag is specifically set here.
      set_source_files_properties(                                                                                                                                                                         
        ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp                                                                                                                                         
        PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"                                                                                                                                    
      )                                                                                                                                                                                                    
+      set_source_files_properties(                                                                                                                                                                         
+        ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp                                                                                                                                        
+        PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"                                                                                                                                    
+      )                                                                                                                                                                                                    
    endif()                                                                                                                                                                                                
  endif()                                                                                                                                                                                                  
  set_source_files_properties(                                                                                                                                                                             

However, the current way of doing things is error-prone. Maybe the definition of expandable_segments() should be moved to the cpp file instead, or the macro should be marked PUBLIC. Speaking personally, I would recommend moving the definition out of the header file; I don't know how far the PYTORCH_C10_DRIVER_API_SUPPORTED macro would propagate transitively if it were set to PUBLIC.

As far as whether this is a serious bug, I'm not sure if it will cause a crash. Unfortunately, I don't know what the purpose of useTensorRegisterAllocatorHook_ is, or whether useTensorRegisterAllocatorHook_=false is compatible (intentionally or not!) with expandable_segments=true.

@eellison
Contributor

cc @zdevito

@michaelroyzen

Can also confirm that this issue exists in versions 2.2 and later. The expandable segments failures prevent us from upgrading from 2.1.2. @ptrblck @bdhirsh

@galv
Collaborator

galv commented Jul 18, 2024

@michaelroyzen as long as you don't set the environment variable TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true (the default is false, per getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false);), I do not believe you will encounter any issues. Or are you saying that your code is crashing? Please elaborate if so. Otherwise, it's just a spurious warning, which I can fix in a jiffy.

I looked into this warning a bit more, and basically the code is intended to disable NVLink SHARP support when expandable segments is turned on, even when the user requests it (via TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true). For reasons having to do with memory page alignment, tensors allocated via the expandable segments option cannot always be used with NVLink SHARP. The code fails to do this right now, but I don't see how this is a problem unless you set TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true.
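
A pre-launch check along these lines, a sketch that only reads the env var consulted by the getCvarBool call quoted above (the accepted true spellings below are an assumption):

import os

hook = os.environ.get("TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK", "false")
if hook.strip().lower() in ("1", "y", "yes", "t", "true"):
    # per the analysis above, this is the combination to avoid with expandable segments
    print("warning: tensor-register allocator hook enabled alongside expandable_segments:True")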

@Seedmanc

What's the support matrix for expandable segments? I keep getting denied with "not supported on this platform" using the EasyDiffusion and EasyTraining GUIs, but it's not stating whether I'm missing hardware or software support.
I'm running CUDA 12.1 on Win10 x64 with an RTX 2070S 8GB, torch-2.3.1+cu121, torchvision-0.18.1+cu121.

@michaelroyzen

@galv I am not explicitly setting TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true. Expandable segments simply stopped working in PyTorch 2.2 due to the refactor https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAAllocatorConfig.h#L28. PyTorch 2.1.2 is the last version that works for me with expandable segments -- upgrading to 2.2+ gives this warning and expandable segments are not enabled (and I get OOMs).

@pzelasko

@galv I am not explicitly setting TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true. Expandable segments simply stopped working in PyTorch 2.2 due to the refactor https://github.com/pytorch/pytorch/blob/main/c10/cuda/CUDAAllocatorConfig.h#L28. PyTorch 2.1.2 is the last version that works for me with expandable segments -- upgrading to 2.2+ gives this warning and expandable segments are not enabled (and I get OOMs).

I have the same experience.

@zdevito
Contributor

zdevito commented Sep 3, 2024

For those hitting this issue recently, it would help to get a python -m torch.utils.collect_env to understand how pytorch was built. This can happen if for some reason PyTorch was built without access to the driver API, and maybe also if it can't find the driver binary(?).

@Seedmanc

Seedmanc commented Sep 4, 2024

Alright, here's mine. Not sure why it can't get the cudnn version, it should be there.
I'm not building anything though, it's just installed as is.

Collecting environment information...
PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.7 (tags/v3.11.7:fa7a6f2, Dec 4 2023, 19:24:49) [MSC v.1937 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2070 SUPER
Nvidia driver version: 560.81
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=3401
DeviceID=CPU0
Family=191
L2CacheSize=1024
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=3401
Name=Intel(R) Core(TM) i5-2550K CPU @ 3.40GHz
ProcessorType=3
Revision=10759

Versions of relevant libraries:
[pip3] came-pytorch==0.1.3
[pip3] lion-pytorch==0.0.6
[pip3] numpy==1.26.3
[pip3] pytorch-lightning==1.9.0
[pip3] pytorch_optimizer==3.0.0
[pip3] torch==2.3.1+cu121
[pip3] torchmetrics==1.4.1
[pip3] torchvision==0.18.1+cu121
[conda] Could not collect

@InhabitancyCocoon

For those hitting this issue recently, it would help to get a python -m torch.utils.collect_env to understand how pytorch was built. This can happen if for some reason PyTorch was built without access to the driver API, and maybe also if it can't find the driver binary(?).

I don't think this is the actual reason. I built PyTorch from source, and python -m torch.utils.collect_env shows everything I expect.

@michaelroyzen

I agree with @InhabitancyCocoon. I strongly suspect this is a bug in the refactor of the expandable segments between 2.1.2 and 2.2.0. The official PyPI 2.1.2+cu121 works but 2.2.0+cu121 and onwards does not.

@zdevito
Contributor

zdevito commented Sep 5, 2024

Alright, here's mine. Not sure why it can't get the cudnn version, it should be there. I'm not building anything though, it's just installed as is.

OS: Microsoft Windows 10 Home

Our way to access driver APIs doesn't work on Windows, so expandable_segments does not work there either. I expect that is why this particular instance of the error shows up.

However, for @InhabitancyCocoon @michaelroyzen I suspect this is another manifestation of the error due to something else. Since I can build pytorch from source and have expandable segments available, I suspect something changed between 2.1.2 and 2.2.0 that caused driver API detection to change but only in certain setups.

@michaelroyzen

My setup is 8xH100, Ubuntu 22.04, CUDA 12.1+drivers installed, and the official PyTorch 2.x+cu121 build from PyPI.

@zdevito
Contributor

zdevito commented Sep 5, 2024

My setup is 8xH100, Ubuntu 22.04, CUDA 12.1+drivers installed, and the official PyTorch 2.x+cu121 build from PyPI.

And what specifically is the warning? "expandable_segments not supported on this platform" should only show up if the torch binary was built in a way that didn't define PYTORCH_C10_DRIVER_API_SUPPORTED. It doesn't check anything about the machine.

@michaelroyzen

Correct, "expandable_segments not supported on this platform" is the warning. Which would then imply that the official PyPI packages are not being built correctly. All I know is that the refactor between 2.1.2 and 2.2.0 triggered this issue.

@zdevito
Contributor

zdevito commented Sep 6, 2024

I tried pip install torch in a new environment, and I am able to do

$ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' python 
> import torch
> torch.rand(3, 4, device='cuda')

which is weird because it suggests that the package is built with it enabled.

@InhabitancyCocoon

My experience is:
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' python works, but when it comes to training, expandable_segments does not work anymore.

@michaelroyzen

I second that. The sanity checks above all pass without any warnings, but real training causes problems.
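
One way to probe this during actual training: a sketch assuming the running build exposes the private torch.cuda.memory._snapshot() API and an is_expandable field on each segment (both are version-dependent, so treat it as best-effort):

import torch

def any_expandable_segments() -> bool:
    # private API; the shape of the returned dict can change between releases
    snap = torch.cuda.memory._snapshot()
    return any(seg.get("is_expandable", False) for seg in snap.get("segments", []))

x = torch.randn(1024, 1024, device="cuda")  # stand-in for a real training step
print("expandable segments in use:", any_expandable_segments())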

@galv
Collaborator

galv commented Sep 6, 2024

Can someone give the backtrace you see during training?

@Seedmanc

Seedmanc commented Sep 6, 2024

OS: Microsoft Windows 10 Home

Our way to access driver APIs doesn't work on windows, so expandable_segments do not work there either.

That's weird, a friend of mine on Win10 does not have this error. Here's what he tried:
[screenshot]

But when I try running the 2.1.2 mentioned here as the most compatible one, I get an error:

pip3 install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121
<omitted>
Successfully installed torch-2.1.2+cu121
set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python
Python 3.11.7 (tags/v3.11.7:fa7a6f2, Dec  4 2023, 19:24:49) [MSC v.1937 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.randn((1, 2, 40), device="cuda")
<stdin>:1: UserWarning: expandable_segments not supported on this platform (Triggered internally at ..\c10\cuda\CUDACachingAllocator.cpp:803.)

Is there a Windows build for this feature?

@InhabitancyCocoon

InhabitancyCocoon commented Sep 9, 2024

However, the current way of doing things is error prone. Maybe the defintion of expandable_segments() should be moved to the cpp file instead. Or the macro should be marked PUBLIC. I would recommend moving to definition out of the header file, speaking personally. I don't know how far transitively the PYTORCH_C10_DRIVER_API_SUPPORTED macro would propagate if it were set to PUBLIC.

Inspired by @galv, setting the macro to PUBLIC in CMakeLists.txt and rebuilding PyTorch from source seems to solve my problem.

target_compile_options(c10_cuda PUBLIC "-DPYTORCH_C10_DRIVER_API_SUPPORTED")

But I'm not sure whether this hack will cause unexpected errors in the future.

https://github.com/pytorch/pytorch/blob/release/2.4/c10/cuda/CMakeLists.txt#L72

As mentioned by my colleague, I think the fact that the sanity check works but real training fails has something to do with the cudaMalloc method.

@eqy
Collaborator

eqy commented Oct 12, 2024

Some are likely seeing this through NCCL as the ProcessGroupNCCL.cpp build does not have the macro properly set.

Opened #137828 to fix

@michaelroyzen

Yep, can confirm that I am using NCCL when I see these errors. Thanks for finding it @eqy.

@zaptrem

zaptrem commented Nov 10, 2024

@michaelroyzen as long as you don't set the environment variable TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true (the default is false, per getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false);), I do not believe you will encounter any issues. Or are you saying that your code is crashing? Please elaborate if so. Otherwise, it's just a spurious warning, which I can fix in a jiffy.

I looked into this warning a bit more, and basically the code is intended to disable NVLink SHARP support when expandable segments is turned on, even when the user requests it (via TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true). For reasons having to do with memory page alignment, tensors allocated via the expandable segments option cannot always be used with NVLink SHARP. The code fails to do this right now, but I don't see how this is a problem unless you set TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=true.

Does this mean SHARP is not supported when this env var is enabled?

@sytelus

sytelus commented Jan 17, 2025

It seems this error occurs in the pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel image but not in the nvcr.io/nvidia/pytorch:24.07-py3 image.

@eqy
Collaborator

eqy commented Jan 17, 2025

@sytelus please check if you are still seeing it in nightlies, as 2.5.1 may not contain the PR with the fix, #137828.

In either case, IIRC it's just a warning and doesn't really have any bearing on actual functionality.

@ad8e
Contributor

ad8e commented Feb 22, 2025

In 2.6.0, this problem no longer appears.
