Skip to content

Commit bea5949

Browse files
fix vllm graph register and add test (#1894)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> The all-reduce buffers are incorrectly registered during the CUDA graph capture phase when using vLLM custom all reduce. The base_ptr is not initialized and leads to illegal memory access error. vLLM's implementation with proper memory initialization - https://github.com/vllm-project/vllm/blob/0d21b9b51eccabfa1f8114eab2df61d75459bee7/csrc/custom_all_reduce.cuh#L452 This PR ensures that the memory pointer is initialized correctly and adds tests for vLLM custom all reduce during CUDA graph capture ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent a88349f commit bea5949

File tree

2 files changed

+124
-1
lines changed

2 files changed

+124
-1
lines changed

include/flashinfer/comm/vllm_custom_all_reduce.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ struct cuda_error : public std::runtime_error {
4444
namespace vllm {
4545

4646
constexpr int kMaxBlocks = 36;
47+
constexpr CUpointer_attribute rangeStartAddrAttr = CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
48+
4749
// Counter may overflow, but it's fine since unsigned int overflow is
4850
// well-defined behavior.
4951
using FlagType = uint32_t;
@@ -366,6 +368,9 @@ class CustomAllreduce {
366368
void* base_ptr;
367369
// note: must share the base address of each allocation, or we get wrong
368370
// address
371+
if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr, (CUdeviceptr)ptr) != CUDA_SUCCESS)
372+
throw std::runtime_error("failed to get pointer attr");
373+
369374
CHECK_CUDA_SUCCESS(
370375
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
371376
offsets[i] = ((char*)ptr) - ((char*)base_ptr);

tests/comm/test_vllm_custom_allreduce.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21-
def _run_correctness_worker(world_size, rank, distributed_init_port):
21+
def _initialize_process_group(world_size, rank, distributed_init_port):
2222
device = torch.device(f"cuda:{rank}")
2323
torch.cuda.set_device(device)
2424
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
@@ -29,8 +29,12 @@ def _run_correctness_worker(world_size, rank, distributed_init_port):
2929
world_size=world_size,
3030
)
3131
group = dist.group.WORLD
32+
return group
33+
3234

35+
def _run_correctness_worker(world_size, rank, distributed_init_port):
3336
try:
37+
group = _initialize_process_group(world_size, rank, distributed_init_port)
3438
device = torch.device(f"cuda:{rank}")
3539
max_size = 8192 * 1024
3640
meta_ptrs = comm.create_shared_buffer(
@@ -104,6 +108,103 @@ def get_open_port() -> int:
104108
return s.getsockname()[1]
105109

106110

111+
def _run_graph_buffer_ipc_meta_worker(
112+
world_size: int, rank: int, distributed_init_port: int
113+
):
114+
"""Test get_graph_buffer_ipc_meta function with CUDA graph capture."""
115+
116+
custom_ptr = None
117+
meta_ptrs = None
118+
119+
try:
120+
# Setup
121+
group = _initialize_process_group(world_size, rank, distributed_init_port)
122+
device = torch.device(f"cuda:{rank}")
123+
max_size = 8192 * 1024
124+
meta_ptrs = comm.create_shared_buffer(
125+
comm.vllm_meta_size() + max_size, group=group
126+
)
127+
rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
128+
custom_ptr = comm.vllm_init_custom_ar(meta_ptrs, rank_data, rank, True)
129+
130+
# Test 1: Empty state before graph capture
131+
handle_bytes, offsets = comm.vllm_get_graph_buffer_ipc_meta(custom_ptr)
132+
assert len(handle_bytes) == 0 and len(offsets) == 0, (
133+
"Expected empty buffers before graph capture"
134+
)
135+
136+
# Test 2: Capture graph and validate IPC metadata structure
137+
test_size = 4096
138+
num_cta = 16
139+
dtype = torch.float16
140+
141+
inp1 = torch.randn(test_size, dtype=dtype, device=device)
142+
inp2 = torch.randn(test_size, dtype=dtype, device=device)
143+
out1 = torch.empty_like(inp1)
144+
out2 = torch.empty_like(inp2)
145+
146+
g = torch.cuda.CUDAGraph()
147+
with torch.cuda.graph(g, pool=None):
148+
comm.vllm_all_reduce(custom_ptr, inp1, out1, 0, 0, num_cta)
149+
comm.vllm_all_reduce(custom_ptr, inp2, out2, 0, 0, num_cta)
150+
151+
handle_bytes, offsets = comm.vllm_get_graph_buffer_ipc_meta(custom_ptr)
152+
153+
# Validate structure: 2 buffers, correct handle size (64 bytes each)
154+
ipc_handle_size = 64
155+
expected_num_buffers = 2
156+
assert len(offsets) == expected_num_buffers, (
157+
f"Expected {expected_num_buffers} offsets, got {len(offsets)}"
158+
)
159+
assert len(handle_bytes) == ipc_handle_size * expected_num_buffers, (
160+
f"Expected {ipc_handle_size * expected_num_buffers} handle bytes"
161+
)
162+
assert all(isinstance(o, int) and o >= 0 for o in offsets), (
163+
"All offsets should be non-negative integers"
164+
)
165+
166+
# Test 3: Distributed gather and register graph buffers
167+
all_handle_bytes = [None] * world_size
168+
all_offsets = [None] * world_size
169+
170+
dist.all_gather_object(all_handle_bytes, handle_bytes, group=group)
171+
dist.all_gather_object(all_offsets, offsets, group=group)
172+
173+
# All ranks should have same number of buffers
174+
assert all(len(off) == expected_num_buffers for off in all_offsets), (
175+
"All ranks should have same number of buffers"
176+
)
177+
178+
comm.vllm_register_graph_buffers(custom_ptr, all_handle_bytes, all_offsets)
179+
180+
# Test 4: Graph replay produces correct results
181+
inp1_test = torch.randn(test_size, dtype=dtype, device=device)
182+
inp2_test = torch.randn(test_size, dtype=dtype, device=device)
183+
184+
inp1.copy_(inp1_test)
185+
inp2.copy_(inp2_test)
186+
187+
g.replay()
188+
torch.cuda.synchronize()
189+
190+
# Verify with NCCL reference
191+
inp1_ref = inp1_test.clone()
192+
inp2_ref = inp2_test.clone()
193+
dist.all_reduce(inp1_ref, group=group)
194+
dist.all_reduce(inp2_ref, group=group)
195+
196+
torch.testing.assert_close(out1, inp1_ref, rtol=1e-3, atol=1e-3)
197+
torch.testing.assert_close(out2, inp2_ref, rtol=1e-3, atol=1e-3)
198+
199+
finally:
200+
dist.barrier(group=group)
201+
if custom_ptr is not None:
202+
comm.vllm_dispose(custom_ptr)
203+
if meta_ptrs:
204+
comm.free_shared_buffer(meta_ptrs, group)
205+
dist.destroy_process_group(group=group)
206+
207+
107208
def multi_process_parallel(
108209
world_size: int, test_target: Any, target_args: tuple = ()
109210
) -> None:
@@ -138,3 +239,20 @@ def test_vllm_custom_allreduce(world_size):
138239
target_args=(),
139240
)
140241
print(f"custom allreduce tp = {world_size}: OK")
242+
243+
244+
@pytest.mark.parametrize("world_size", [2, 4])
245+
def test_get_graph_buffer_ipc_meta(world_size: int):
246+
"""Test get_graph_buffer_ipc_meta function with CUDA graph capture."""
247+
available_gpus = torch.cuda.device_count()
248+
if world_size > available_gpus:
249+
pytest.skip(
250+
f"world_size {world_size} is greater than available_gpus {available_gpus}"
251+
)
252+
print(f"Running get_graph_buffer_ipc_meta test for world_size={world_size}")
253+
multi_process_parallel(
254+
world_size,
255+
_run_graph_buffer_ipc_meta_worker,
256+
target_args=(),
257+
)
258+
print(f"get_graph_buffer_ipc_meta test for world_size={world_size}: OK")

0 commit comments

Comments (0)