
Commit bb9d305

LucasWilkinson, SageMoore, yewentao256, and tlrmchlsmth authored and committed
[Core/DBO][2/N] Dual-Batch Overlap add DeepEP High Throughput support and Prefill support (vllm-project#24845)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent c89a532 commit bb9d305
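
For orientation, here is a minimal sketch of how the pieces added in this commit fit together from the user side. It is illustrative only: the model name is a placeholder, DBO still requires a DeepEP install and expert parallelism, and the threshold values simply restate the new defaults.

import os

# The DBO check in vllm/config/__init__.py now also accepts the
# high-throughput DeepEP backend in addition to deepep_low_latency.
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_high_throughput"

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="org/some-moe-model",        # placeholder model name
    enable_expert_parallel=True,
    enable_dbo=True,
    dbo_decode_token_threshold=32,     # decode-only batches
    dbo_prefill_token_threshold=512,   # batches containing prefills (new here)
)
config = args.create_engine_config()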

File tree: 19 files changed (+604, -238 lines)


tests/v1/attention/test_attention_splitting.py

Lines changed: 82 additions & 1 deletion
@@ -5,11 +5,12 @@
 import torch
 
 from tests.v1.attention.test_attention_backends import BATCH_SPECS
-from tests.v1.attention.utils import create_common_attn_metadata
+from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
 from vllm.v1.attention.backends.utils import (UBatchSlice,
                                               _make_metadata_with_slice,
                                               slice_query_start_locs,
                                               split_attn_metadata)
+from vllm.v1.worker.ubatch_utils import create_ubatch_slices
 
 
 @pytest.fixture
@@ -155,3 +156,83 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
     assert results[1].num_reqs == mid_point
     assert results[1].num_actual_tokens == mid_point
     assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
+
+
+@pytest.mark.parametrize(
+    "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
+    [
+        # Split in the middle of request 1
+        ([32, 40], [8, 8], 12, 2, 1),
+        # Split inside the first request
+        ([32, 40], [8, 8], 4, 1, 2),
+    ],
+)
+def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
+                                       expected_first_reqs,
+                                       expected_second_reqs):
+    """Test splitting a prefill across ubatches"""
+    import numpy as np
+
+    device = torch.device("cpu")
+    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
+    common = create_common_attn_metadata(batch_spec,
+                                         block_size=16,
+                                         device=device)
+
+    num_scheduled_tokens = np.array(query_lens, dtype=np.int32)
+    qsl_np = common.query_start_loc_cpu.numpy()
+    num_tokens = common.num_actual_tokens
+
+    ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
+    assert len(ubatch_slices) == 2
+
+    first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
+    second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
+
+    # Token counts match the split
+    assert first_meta.num_actual_tokens == split_point
+    assert second_meta.num_actual_tokens == num_tokens - split_point
+
+    # Number of requests per ubatch
+    assert first_meta.num_reqs == expected_first_reqs
+    assert second_meta.num_reqs == expected_second_reqs
+
+    # Identify which request is split and how many tokens are in the first chunk
+    split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1)
+    tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx])
+    orig_q_lens = (common.query_start_loc_cpu[1:] -
+                   common.query_start_loc_cpu[:-1])
+
+    # Check query length continuity: first-chunk + second-chunk == original qlen
+    # First ubatch last request query length
+    qlen_first_last = int(first_meta.query_start_loc_cpu[-1] -
+                          first_meta.query_start_loc_cpu[-2])
+    # Second ubatch first request query length
+    qlen_second_first = int(second_meta.query_start_loc_cpu[1] -
+                            second_meta.query_start_loc_cpu[0])
+    assert qlen_first_last == tokens_in_first_chunk
+    assert qlen_first_last + qlen_second_first == int(
+        orig_q_lens[split_req_idx])
+
+    # Check seq_lens adjustments
+    # Context lengths per original request
+    context_lens = [s - q for s, q in zip(seq_lens, query_lens)]
+
+    # First ubatch: last request's seq_len should be
+    # context + tokens_in_first_chunk
+    expected_seqlen = context_lens[split_req_idx] + tokens_in_first_chunk
+    assert int(first_meta.seq_lens[-1]) == expected_seqlen
+
+    # For full preceding requests in first ubatch, seq_lens should match
+    # originals
+    for i in range(first_meta.num_reqs - 1):
+        assert int(first_meta.seq_lens[i]) == seq_lens[i]
+
+    # Second ubatch: first request (continuation) seq_len should be full
+    # original
+    assert int(second_meta.seq_lens[0]) == seq_lens[split_req_idx]
+    # Any following full requests in second ubatch should match originals
+    for j in range(1, second_meta.num_reqs):
+        # Map to original request index
+        orig_idx = split_req_idx + j
+        assert int(second_meta.seq_lens[j]) == seq_lens[orig_idx]
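
For intuition, the arithmetic behind the first parametrization ([32, 40], [8, 8], 12, 2, 1), worked by hand with plain numpy (values taken from the test above):

import numpy as np

seq_lens, query_lens, split_point = [32, 40], [8, 8], 12
query_start_loc = np.concatenate(([0], np.cumsum(query_lens)))  # [0, 8, 16]

# Request 1 straddles the split: 4 of its 8 query tokens land in ubatch 0.
split_req_idx = int(np.searchsorted(query_start_loc, split_point, side="right") - 1)
tokens_in_first_chunk = split_point - int(query_start_loc[split_req_idx])
assert (split_req_idx, tokens_in_first_chunk) == (1, 4)

# Ubatch 0 holds request 0 in full plus the 4-token chunk (2 requests, 12
# tokens); ubatch 1 holds only the 4-token continuation of request 1.
context_lens = [s - q for s, q in zip(seq_lens, query_lens)]      # [24, 32]
assert context_lens[split_req_idx] + tokens_in_first_chunk == 36  # ubatch 0 last seq_len
assert seq_lens[split_req_idx] == 40                              # ubatch 1 sees the full seq_len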

tests/v1/spec_decode/test_eagle.py

Lines changed: 2 additions & 4 deletions
@@ -532,9 +532,8 @@ def create_deterministic_logits(token_ids):
     # Mock runner for attention metadata building
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builders = [
+    proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
-    ]
 
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
@@ -659,9 +658,8 @@ def create_deterministic_logits(token_ids, k: int):
     # Mock runner for attention metadata building.
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builders = [
+    proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
-    ]
 
     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0,

vllm/config/__init__.py

Lines changed: 7 additions & 5 deletions
@@ -638,11 +638,13 @@ def __post_init__(self):
 
         if self.parallel_config.enable_dbo:
             a2a_backend = envs.VLLM_ALL2ALL_BACKEND
-            assert a2a_backend == "deepep_low_latency", \
-                "Microbatching currently only supports the deepep_low_latency "\
-                f"all2all backend. {a2a_backend} is not supported. To fix set "\
-                "the VLLM_ALL2ALL_BACKEND environment variable to "\
-                "deepep_low_latency and install the DeepEP kerenls."
+            assert a2a_backend in \
+                ["deepep_low_latency", "deepep_high_throughput"], \
+                "Microbatching currently only supports the deepep_low_latency and "\
+                f"deepep_high_throughput all2all backend. {a2a_backend} is not "\
+                "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\
+                "variable to deepep_low_latency or deepep_high_throughput and "\
+                "install the DeepEP kernels."
 
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]

vllm/config/parallel.py

Lines changed: 10 additions & 4 deletions
@@ -139,12 +139,18 @@ class ParallelConfig:
     """Disable the custom all-reduce kernel and fall back to NCCL."""
 
     enable_dbo: bool = False
-    """Enable microbatching for the model executor."""
+    """Enable dual batch overlap for the model executor."""
 
     dbo_decode_token_threshold: int = 32
-    """The threshold for microbatching. If the number of tokens in the
-    request is greater than this threshold, microbatching will be used.
-    Otherwise, the request will be processed in a single batch."""
+    """The threshold for dual batch overlap for batches only containing decodes.
+    If the number of tokens in the request is greater than this threshold,
+    microbatching will be used. Otherwise, the request will be processed in a
+    single batch."""
+    dbo_prefill_token_threshold: int = 512  # TODO(lucas): tune
+    """The threshold for dual batch overlap for batches that contain one or more
+    prefills. If the number of tokens in the request is greater than this
+    threshold, microbatching will be used. Otherwise, the request will be
+    processed in a single batch."""
 
     ray_workers_use_nsight: bool = False
     """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

vllm/distributed/device_communicators/all2all.py

Lines changed: 19 additions & 9 deletions
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.distributed as dist
 
+import vllm.envs as envs
 from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
@@ -200,12 +201,12 @@ def __init__(self, cpu_group):
 
     def _make_all2all_kwargs(self) -> dict[Any, Any]:
         # Defaults for internode and intranode are taken from DeepEP tests.
-        num_nvl_bytes = 1024 * 1024 * 1024
+        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
         num_rdma_bytes = None
         num_qps_per_rank = None
 
         if self.internode:
-            num_rdma_bytes = 1024 * 1024 * 1024
+            num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
             num_qps_per_rank = self.num_sms // 2
         else:
             num_rdma_bytes = 0
@@ -230,13 +231,18 @@ def get_handle(self, kwargs):
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
         handle: deep_ep.Buffer = self.handle_cache.get_or_create(
             buffer_kwargs, deep_ep.Buffer)
-        # It is dangerous to set num sms outside this function. num_sms is not
-        # a part of the hash-key that identifies this object. If we are in a
-        # situation where we make objects with different num_sms, the hash key
-        # in get_or_create must be updated.
-        handle.set_num_sms(self.num_sms)
         return handle
 
+    def set_num_sms(self, num_sms: int):
+        import deep_ep
+
+        # Right now the buffers are sized for only what the kernels were
+        # created with. So we can only reduce the number of SMS used
+        # but not increase it.
+        if num_sms > self.num_sms:
+            num_sms = self.num_sms
+        deep_ep.Buffer.set_num_sms(num_sms)
+
 
 class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
     """
@@ -265,7 +271,7 @@ def _make_all2all_kwargs(
         import deep_ep
 
         # Defaults for internode and intranode are taken from DeepEP tests.
-        num_nvl_bytes = 1024 * 1024 * 1024
+        num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
         num_qps_per_rank = num_local_experts
         num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
             num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
@@ -291,3 +297,7 @@ def get_handle(self, kwargs):
         handle: deep_ep.Buffer = self.handle_cache.get_or_create(
             buffer_kwargs, deep_ep.Buffer)
         return handle
+
+    # DeepEP LL uses RDMA so no SMs are used for communication
+    def max_sms_used(self) -> Optional[int]:
+        return 0
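
A standalone sketch of the two sizing rules introduced above; the helper names are invented for illustration and deep_ep itself is not touched:

DEFAULT_DEEPEP_BUFFER_SIZE_MB = 1024  # default of envs.VLLM_DEEPEP_BUFFER_SIZE_MB

def deepep_buffer_bytes(buffer_size_mb: int = DEFAULT_DEEPEP_BUFFER_SIZE_MB) -> int:
    # Same arithmetic as _make_all2all_kwargs: MB -> bytes. The default keeps
    # the previously hard-coded 1 GiB (1024 * 1024 * 1024).
    return buffer_size_mb * 1024 * 1024

def clamp_num_sms(requested: int, created_with: int) -> int:
    # Mirrors the set_num_sms clamp above: buffers were sized for the SM count
    # the kernels were created with, so a request can only shrink it.
    return min(requested, created_with)

assert deepep_buffer_bytes() == 1024 * 1024 * 1024
assert clamp_num_sms(8, 20) == 8    # reducing the SM budget is allowed
assert clamp_num_sms(64, 20) == 20  # requests above the creation-time count are clamped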

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 6 additions & 0 deletions
@@ -60,6 +60,12 @@ def get_handle(self, kwargs):
         # and reuse it for the same config.
         raise NotImplementedError
 
+    def set_num_sms(self, num_sms: int):
+        pass
+
+    def max_sms_used(self) -> Optional[int]:
+        return None  # None means it could use the whole GPU
+
     def dispatch(self, hidden_states: torch.Tensor,
                  router_logits: torch.Tensor):
         raise NotImplementedError
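
One plausible reading of the new hooks' contract, sketched standalone (how the runner actually combines them is outside this diff, so the helper below is hypothetical):

from typing import Optional

def effective_comm_sms(requested: int, backend_cap: Optional[int]) -> int:
    # max_sms_used() returning None means the backend may use the whole GPU,
    # so only the requested budget applies; otherwise the backend cap wins
    # (0 for the RDMA-only DeepEP low-latency path above).
    return requested if backend_cap is None else min(requested, backend_cap)

assert effective_comm_sms(20, None) == 20  # base-class default
assert effective_comm_sms(20, 0) == 0      # DeepEP LL: no SMs spent on communication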

vllm/engine/arg_utils.py

Lines changed: 6 additions & 0 deletions
@@ -330,6 +330,8 @@ class EngineArgs:
     enable_dbo: bool = ParallelConfig.enable_dbo
     dbo_decode_token_threshold: int = \
         ParallelConfig.dbo_decode_token_threshold
+    dbo_prefill_token_threshold: int = \
+        ParallelConfig.dbo_prefill_token_threshold
     eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
     enable_eplb: bool = ParallelConfig.enable_eplb
     expert_placement_strategy: ExpertPlacementStrategy = \
@@ -698,6 +700,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parallel_group.add_argument(
             "--dbo-decode-token-threshold",
             **parallel_kwargs["dbo_decode_token_threshold"])
+        parallel_group.add_argument(
+            "--dbo-prefill-token-threshold",
+            **parallel_kwargs["dbo_prefill_token_threshold"])
         parallel_group.add_argument("--enable-eplb",
                                     **parallel_kwargs["enable_eplb"])
         parallel_group.add_argument("--eplb-config",
@@ -1316,6 +1321,7 @@ def create_engine_config(
             enable_expert_parallel=self.enable_expert_parallel,
             enable_dbo=self.enable_dbo,
             dbo_decode_token_threshold=self.dbo_decode_token_threshold,
+            dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
             enable_eplb=self.enable_eplb,
             eplb_config=self.eplb_config,
             expert_placement_strategy=self.expert_placement_strategy,

vllm/envs.py

Lines changed: 11 additions & 0 deletions
@@ -189,6 +189,8 @@
     VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
     VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
     VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
+    VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
+    VLLM_DBO_COMM_SMS: int = 20
     GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
     VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
 
@@ -1392,6 +1394,15 @@ def get_vllm_port() -> Optional[int]:
     lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
                       "VLLM_OBJECT_STORAGE_SHM_BUFFER"),
 
+    # The size in MB of the buffers (NVL and RDMA) used by DeepEP
+    "VLLM_DEEPEP_BUFFER_SIZE_MB":
+    lambda: int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024")),
+
+    # The number of SMs to allocate for communication kernels when running DBO
+    # the rest of the SMs on the device will be allocated to compute
+    "VLLM_DBO_COMM_SMS":
+    lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")),
+
     # Valid values are container,code_interpreter,web_search_preview
     # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
     "GPT_OSS_SYSTEM_TOOL_MCP_LABELS":

0 commit comments
