Commit f6f8399

Deepseek-v3 Batch Invariant on 8xH100
Signed-off-by: Bram Wasti <bwasti@meta.com>
1 parent d1fcab6 commit f6f8399

File tree

25 files changed: +937 −149 lines changed

tests/v1/generation/test_batch_invariance.py

Lines changed: 245 additions & 74 deletions
Large diffs are not rendered by default.
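Since the large diff for this test file is not rendered, here is a minimal sketch of the kind of property such a test checks: a greedy request must produce token-for-token identical output whether it is decoded alone or mixed into a larger batch. The model id, prompts, and batch shape below are illustrative assumptions, not the actual test code.

# Hypothetical sketch of a batch-invariance check (not the actual test).
from vllm import LLM, SamplingParams

def check_batch_invariance():
    # Assumed model id; the commit title targets DeepSeek-V3 on 8xH100.
    llm = LLM(model="deepseek-ai/DeepSeek-V3", tensor_parallel_size=8)
    params = SamplingParams(temperature=0.0, max_tokens=32)

    # Decode the probe prompt alone (batch size 1).
    solo = llm.generate(["The capital of France is"], params)

    # Decode the same prompt mixed into a larger batch.
    mixed = llm.generate(
        ["The capital of France is"] + ["filler prompt"] * 7, params
    )

    # Batch-invariant kernels must make the two runs bit-identical.
    assert solo[0].outputs[0].token_ids == mixed[0].outputs[0].token_ids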

vllm/config/model.py

Lines changed: 8 additions & 0 deletions
@@ -426,6 +426,14 @@ def __post_init__(
         skip_mm_profiling: Optional[bool],
         video_pruning_rate: Optional[float],
     ) -> None:
+        # Enable batch invariance settings if requested
+        from vllm.model_executor.layers.batch_invariant import (
+            vllm_kernel_override_batch_invariant,
+        )
+
+        if vllm_kernel_override_batch_invariant():
+            self.enforce_eager = True
+
         # Set the default seed to 0 in V1.
         # NOTE(woosuk): In V0, we set the default seed to None because the
         # driver worker shares the same process as the user process, and thus
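Each file in this commit gates its override on the same helper, imported from vllm/model_executor/layers/batch_invariant.py; its body is not shown in the hunks here. A plausible sketch, assuming it is a simple environment-variable check (the variable name is an assumption):

# Plausible sketch of the gate used throughout this commit; the real
# implementation lives in batch_invariant.py and is not part of this diff.
import os

def vllm_kernel_override_batch_invariant() -> bool:
    # Assumed env var name; treat any non-empty truthy value as "enabled".
    value = os.getenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "0")
    return value.strip().lower() not in ("", "0", "false")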

vllm/config/parallel.py

Lines changed: 7 additions & 0 deletions
@@ -531,8 +531,15 @@ def use_ray(self) -> bool:
     def _verify_args(self) -> Self:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
+        from vllm.model_executor.layers.batch_invariant import (
+            vllm_kernel_override_batch_invariant,
+        )
         from vllm.platforms import current_platform

+        # Enable batch invariance settings if requested
+        if vllm_kernel_override_batch_invariant():
+            self.disable_custom_all_reduce = True
+
         if (
             self.distributed_executor_backend is not None
             and not isinstance(self.distributed_executor_backend, str)
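A plausible rationale, not stated in the commit: the custom all-reduce kernel can pick different algorithms (and therefore different floating-point summation orders) depending on message size and topology, which breaks run-to-run bit-equality. Pinning communication to the stock path keeps reductions on one deterministic code path; the two communicator changes below disable the symmetric-memory all-reduce variants for the same reason.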

vllm/config/scheduler.py

Lines changed: 8 additions & 0 deletions
@@ -170,12 +170,20 @@ def compute_hash(self) -> str:
         return hash_str

     def __post_init__(self, is_encoder_decoder: bool) -> None:
+        from vllm.model_executor.layers.batch_invariant import (
+            vllm_kernel_override_batch_invariant,
+        )
+
         if self.max_model_len is None:
             self.max_model_len = 8192

         if self.max_num_seqs is None:
             self.max_num_seqs = 128

+        # Enable batch invariance settings if requested
+        if vllm_kernel_override_batch_invariant():
+            self.enable_chunked_prefill = False
+
         if is_encoder_decoder:
             # Chunked prefill should be disabled for encoder-decoder models.
             self.disable_chunked_mm_input = True
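Taken together with the model and parallel config hunks above, this means a single flag reconfigures the engine for batch-invariant execution. A minimal usage sketch, again assuming the environment-variable name:

# Assumed flow: set the flag before engine construction so that the
# __post_init__ hooks patched in this commit can observe it.
import os
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"  # assumed name

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V3", tensor_parallel_size=8)
# The engine now comes up with enforce_eager=True,
# disable_custom_all_reduce=True, and chunked prefill disabled.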

vllm/distributed/device_communicators/all_reduce_utils.py

Lines changed: 6 additions & 0 deletions
@@ -70,6 +70,12 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
     from vllm.distributed.device_communicators.pynccl_allocator import (
         is_symmetric_memory_enabled,
     )
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_kernel_override_batch_invariant,
+    )
+
+    if vllm_kernel_override_batch_invariant():
+        return False

     if not is_symmetric_memory_enabled():
         return False

vllm/distributed/device_communicators/symm_mem.py

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,9 @@
     SYMM_MEM_ALL_REDUCE_MAX_SIZES,
 )
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform

 try:
@@ -96,6 +99,8 @@ def __init__(
             return
         self.force_multimem = force_multimem
         self.disabled = False
+        if vllm_kernel_override_batch_invariant():
+            self.disabled = True

     def should_use_symm_mem(self, inp: torch.Tensor):
         if self.disabled:

vllm/engine/arg_utils.py

Lines changed: 3 additions & 2 deletions
@@ -1686,11 +1686,12 @@ def _set_default_args(
     ) -> None:
         """Set Default Arguments for V1 Engine."""

-        # V1 always uses chunked prefills and prefix caching
+        # V1 uses chunked prefills and prefix caching by default
         # for non-pooling tasks.
         # For pooling tasks the default is False
         if model_config.runner_type != "pooling":
-            self.enable_chunked_prefill = True
+            if self.enable_chunked_prefill is None:
+                self.enable_chunked_prefill = True

         # TODO: When prefix caching supports prompt embeds inputs, this
         # check can be removed.
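This turns an unconditional assignment into a true default: enable_chunked_prefill = True is applied only while the option is still unset (None), so an explicitly configured value, such as the False required for batch-invariant runs, is respected rather than overwritten.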
