Problem Description
Hi all,
We are LLM developers running LLM models on a single node with 8x MI355 GPUs. We have hit a CUDA graph (HIP graph) capture issue when launching our LLM application. It is urgent and may block the customer's project.
After a deep dive, we found that torch.distributed.all_reduce cannot be captured by a CUDA graph.
Here is the associated code:
https://github.com/ROCm/vllm/blob/dev/perf/vllm/v1/worker/dp_utils.py#L52
Here is a small reproducer Python file:
import os
import torch
import torch.distributed as dist
import multiprocessing as mp


def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29502'
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['LOCAL_RANK'] = str(rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    backend = "nccl"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    tensor = torch.ones(1, device=device, dtype=torch.float32)
    group = dist.group.WORLD
    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()
    try:
        with torch.cuda.stream(stream):
            graph.capture_begin()
            tensor += 1.0
            dist.all_reduce(tensor, group=group)
            graph.capture_end()
        print(f"Rank {rank}: Unexpected success")
    except Exception as e:
        print(f"Rank {rank}: Expected failure - {e}")
    dist.destroy_process_group()


def main():
    world_size = 8
    if torch.cuda.device_count() < world_size:
        print(f"Error: {world_size} GPUs required, but only {torch.cuda.device_count()} available")
        return
    ctx = mp.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(target=worker, args=(rank, world_size))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


if __name__ == "__main__":
    main()
Running python -u reproducer.py reproduces the issue. Here is the error log; the core error message is "operation not permitted when stream is capturing":
Rank 6: Expected failure - NCCL error in: /app/pytorch/torch/csrc/distributed/c10d/NCCLUtils.cpp:94, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.27.7
ncclUnhandledCudaError: Call to CUDA function failed.
Last error:
Cuda failure 'operation not permitted when stream is capturing'
terminate called after throwing an instance of 'c10::AcceleratorError'
what(): HIP error: operation not permitted when stream is capturing
Search for `hipErrorStreamCaptureUnsupported' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__HIPRT__TYPES.html for more information.
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
Exception raised from c10_hip_check_implementation at /app/pytorch/c10/hip/HIPException.cpp:45 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9c (0x7fe5b9c341bc in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x374e1 (0x7fe5eaa6d4e1 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_hip.so)
frame #2: c10::hip::c10_hip_check_implementation(int, char const*, char const*, int, bool) + 0x1f1 (0x7fe5eaa6d371 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_hip.so)
frame #3: at::cuda::CUDAGraph::~CUDAGraph() + 0xb9 (0x7fe5ed360e59 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_hip.so)
frame #4: <unknown function> + 0xd479b6 (0x7fe600a6d9b6 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0xd47c4a (0x7fe600a6dc4a in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x3c8520 (0x7fe6000ee520 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #7: <unknown function> + 0x3c8bc5 (0x7fe6000eebc5 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #8: /usr/bin/python() [0x59e53d]
frame #9: /usr/bin/python() [0x566be8]
frame #10: _PyEval_EvalFrameDefault + 0xb11 (0x54a6f1 in /usr/bin/python)
frame #11: PyEval_EvalCode + 0x99 (0x620799 in /usr/bin/python)
frame #12: /usr/bin/python() [0x65c44b]
frame #13: /usr/bin/python() [0x6574d6]
frame #14: PyRun_StringFlags + 0x63 (0x653403 in /usr/bin/python)
frame #15: PyRun_SimpleStringFlags + 0x3e (0x65310e in /usr/bin/python)
frame #16: Py_RunMain + 0x4b2 (0x650742 in /usr/bin/python)
frame #17: Py_BytesMain + 0x2d (0x60962d in /usr/bin/python)
frame #18: <unknown function> + 0x29d90 (0x7fe601887d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #19: __libc_start_main + 0x80 (0x7fe601887e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #20: _start + 0x25 (0x6094a5 in /usr/bin/python)
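For comparison, here is a single-GPU sanity-check sketch (not part of the failing application, and the function name capture_without_collective is ours): it captures the same in-place update without any collective. We would expect this to succeed, which would confirm the failure is specific to the all_reduce rather than to HIP graph capture in general.

```python
# Sketch of a single-GPU sanity check: capture the same in-place update
# without any collective. If this succeeds, the capture failure above is
# specific to the all_reduce, not to HIP graph capture in general.
import torch


def capture_without_collective():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    tensor = torch.ones(1, device=device, dtype=torch.float32)

    graph = torch.cuda.CUDAGraph()
    # torch.cuda.graph handles capture_begin/capture_end and captures
    # on its own side stream.
    with torch.cuda.graph(graph):
        tensor += 1.0

    # Kernels are only recorded during capture; each replay applies +1.0.
    graph.replay()
    torch.cuda.synchronize(device)
    print(tensor)  # expected: tensor([2.], device='cuda:0')


if __name__ == "__main__":
    capture_without_collective()
```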
We checked the RCCL history and found a commit that syncs with the NCCL 2.9 code and adds CUDA graph capture support for collective ops. Here is the associated commit:
6021329#diff-6445d3902b6d88df81be2bc5a58abf93b3aa3417132fdf27a0659815d20ec719
However, we still see the capture failure for the RCCL allreduce. The RCCL version in use is 2.27.7:
root@smci355-ccs-aus-m01-25:/home/zejchen/rocm_vllm/vllm/evaluation/dp_attn# python -c "import torch; print(f'NCCL version: {torch.cuda.nccl.version()}')"
NCCL version: (2, 27, 7)
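On the NCCL side, graph capture of collectives generally requires the communicator to already be initialized before capture starts, because the lazy first-call setup performs allocations and synchronizations that are not allowed while a stream is capturing. We are not sure whether the same requirement explains the RCCL behaviour here, but for reference, here is a sketch of the warm-up-then-capture pattern we would expect to be capturable (worker_with_warmup is a hypothetical variant of the worker above, launched with the same spawn harness):

```python
# Hypothetical variant of worker(): the only change is a warm-up all_reduce
# before capture, so communicator creation happens outside of stream capture.
import os
import torch
import torch.distributed as dist


def worker_with_warmup(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29502'
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    tensor = torch.ones(1, device=device, dtype=torch.float32)

    # Warm-up collective outside of capture: forces RCCL/NCCL communicator
    # creation (allocations, synchronization) to complete up front.
    dist.all_reduce(tensor)
    torch.cuda.synchronize(device)

    graph = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream()
    # Capture the collective on a side stream via the context manager.
    with torch.cuda.graph(graph, stream=capture_stream):
        tensor += 1.0
        dist.all_reduce(tensor)

    # Replay the captured collective a few times.
    for _ in range(3):
        graph.replay()
    torch.cuda.synchronize(device)

    dist.destroy_process_group()
```

On CUDA with NCCL >= 2.9.6 this pattern is documented to work; whether it behaves the same with RCCL 2.27.7 on MI355 is exactly what we are trying to establish.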
Thank you.
Operating System
Ubuntu 22.04.5 LTS
CPU
AMD EPYC 9575F 64-Core Processor
GPU
AMD MI355*8
ROCm Version
ROCm 7.1.0
ROCm Component
No response
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response