
[collective op][cuda graph] capturing collective ops fails with a HIP error: operation not permitted when stream is capturing #2022

@zejunchen-zejun


Problem Description

Hi all,

We are LLM developers running LLM models on a single node with 8x MI355 GPUs. We have hit a CUDA graph (HIP graph) capture issue when launching our LLM application. This is an urgent issue and may block the customer's project.
After a deep dive, we found that torch.distributed.all_reduce cannot be captured by a CUDA graph.
Here is the associated code:
https://github.com/ROCm/vllm/blob/dev/perf/vllm/v1/worker/dp_utils.py#L52

Here is a small Python reproducer:

import os
import torch
import torch.distributed as dist
import multiprocessing as mp

def worker(rank, world_size):
    # Minimal single-node rendezvous over env://
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29502'
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['LOCAL_RANK'] = str(rank)

    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # "nccl" maps to RCCL on ROCm builds of PyTorch
    backend = "nccl"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    tensor = torch.ones(1, device=device, dtype=torch.float32)
    group = dist.group.WORLD

    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()  # capture must run on a non-default stream

    try:
        with torch.cuda.stream(stream):
            graph.capture_begin()
            tensor += 1.0
            # This collective is what fails to be captured on ROCm
            dist.all_reduce(tensor, group=group)
            graph.capture_end()
        print(f"Rank {rank}: Unexpected success")
    except Exception as e:
        print(f"Rank {rank}: Expected failure - {e}")

    dist.destroy_process_group()

def main():
    world_size = 8
    if torch.cuda.device_count() < world_size:
        print(f"Error: {world_size} GPUs required, but only {torch.cuda.device_count()} available")
        return

    # Spawn one process per GPU
    ctx = mp.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(target=worker, args=(rank, world_size))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

if __name__ == "__main__":
    main()

The command to reproduce is python -u reproducer.py.

This reproduces the issue. Here is the error log; the core error message is operation not permitted when stream is capturing:

Rank 6: Expected failure - NCCL error in: /app/pytorch/torch/csrc/distributed/c10d/NCCLUtils.cpp:94, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.27.7
ncclUnhandledCudaError: Call to CUDA function failed.
Last error:
Cuda failure 'operation not permitted when stream is capturing'
terminate called after throwing an instance of 'c10::AcceleratorError'
  what():  HIP error: operation not permitted when stream is capturing
Search for `hipErrorStreamCaptureUnsupported' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__HIPRT__TYPES.html for more information.
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.

Exception raised from c10_hip_check_implementation at /app/pytorch/c10/hip/HIPException.cpp:45 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9c (0x7fe5b9c341bc in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x374e1 (0x7fe5eaa6d4e1 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_hip.so)
frame #2: c10::hip::c10_hip_check_implementation(int, char const*, char const*, int, bool) + 0x1f1 (0x7fe5eaa6d371 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_hip.so)
frame #3: at::cuda::CUDAGraph::~CUDAGraph() + 0xb9 (0x7fe5ed360e59 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_hip.so)
frame #4: <unknown function> + 0xd479b6 (0x7fe600a6d9b6 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0xd47c4a (0x7fe600a6dc4a in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x3c8520 (0x7fe6000ee520 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #7: <unknown function> + 0x3c8bc5 (0x7fe6000eebc5 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #8: /usr/bin/python() [0x59e53d]
frame #9: /usr/bin/python() [0x566be8]
frame #10: _PyEval_EvalFrameDefault + 0xb11 (0x54a6f1 in /usr/bin/python)
frame #11: PyEval_EvalCode + 0x99 (0x620799 in /usr/bin/python)
frame #12: /usr/bin/python() [0x65c44b]
frame #13: /usr/bin/python() [0x6574d6]
frame #14: PyRun_StringFlags + 0x63 (0x653403 in /usr/bin/python)
frame #15: PyRun_SimpleStringFlags + 0x3e (0x65310e in /usr/bin/python)
frame #16: Py_RunMain + 0x4b2 (0x650742 in /usr/bin/python)
frame #17: Py_BytesMain + 0x2d (0x60962d in /usr/bin/python)
frame #18: <unknown function> + 0x29d90 (0x7fe601887d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #19: __libc_start_main + 0x80 (0x7fe601887e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #20: _start + 0x25 (0x6094a5 in /usr/bin/python)

We have checked the RCCL history and found a commit that syncs with the NCCL 2.9 code and adds CUDA graph capture support for collective ops. Here is the associated commit:
6021329#diff-6445d3902b6d88df81be2bc5a58abf93b3aa3417132fdf27a0659815d20ec719

However, we still see the capture failure for the RCCL allreduce.
The RCCL version in use is 2.27.7, as reported by PyTorch:

root@smci355-ccs-aus-m01-25:/home/zejchen/rocm_vllm/vllm/evaluation/dp_attn# python -c "import torch; print(f'NCCL version: {torch.cuda.nccl.version()}')"
NCCL version: (2, 27, 7)
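
For reference, the capture pattern we are ultimately trying to use follows the torch.cuda.graph recipe from the PyTorch CUDA graphs documentation: warm the collective up on a side stream so the communicator finishes its lazy initialization, then capture. The snippet below is only a sketch of that pattern, reusing device and group from worker() in the reproducer above:

tensor = torch.ones(1, device=device, dtype=torch.float32)

# Warm up the collective outside of capture so NCCL/RCCL is fully
# initialized before any stream starts recording.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        dist.all_reduce(tensor, group=group)
torch.cuda.current_stream().wait_stream(s)

# Capture the increment + allreduce into a graph, then replay it.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    tensor += 1.0
    dist.all_reduce(tensor, group=group)

graph.replay()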

Thank you.

Operating System

Ubuntu 22.04.5 LTS

CPU

AMD EPYC 9575F 64-Core Processor

GPU

AMD MI355*8

ROCm Version

ROCm 7.1.0

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response
