@SageMoore SageMoore commented Jul 3, 2025

This PR was done in collaboration with @LucasWilkinson

This PR adds experimental support for Dual-Batch Overlap (DBO) in vLLM. In its current state, it is only enabled when the user provides the --enable-microbatching flag. Furthermore, it is only used when all DP groups are running full-decode batches.

This PR is purely infrastructural, meaning it is slow. Assuming this approach is accepted, we will improve performance in follow-up PRs. The immediate next step is to add support for full CUDA graphs.

To implement Dual-Batch Overlap (DBO), at a high level, we split the batch into two microbatches. Then, using two threads and two CUDA streams (one for communication and one for computation), we overlap the dispatch and combine all-to-all kernels of one microbatch with the compute kernels of the other microbatch.
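As a rough illustration of why this helps, here is a back-of-the-envelope timing model. The numbers are made up, and it idealizes steady state (no startup/drain phases, no synchronization overhead):

```python
# Illustrative per-microbatch, per-layer kernel times (ms); not measurements.
compute_ms = 3.0  # compute kernels for one microbatch
comm_ms = 2.0     # dispatch + combine all-to-all for one microbatch

# Serial: each microbatch runs compute then comm, one after the other.
serial_ms = 2 * (compute_ms + comm_ms)

# Overlapped: one microbatch's all-to-all hides under the other's compute,
# so in steady state each microbatch costs roughly max(compute, comm).
overlapped_ms = 2 * max(compute_ms, comm_ms)

print(serial_ms, overlapped_ms)  # prints: 10.0 6.0
```

The win shrinks as compute and communication times diverge; the overlap can only hide the smaller of the two.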

When microbatching is enabled and supported, the GPUModelRunner will split the batch into two token_slices. These token_slices are then passed into the attention metadata builders during _prepare_inputs to generate one attention metadata object per microbatch. When actually running the model, the model runner spawns two microbatching threads that communicate with each other through a UBatchContext. Each thread then runs self.model with the appropriate attention metadata.
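A minimal sketch of the split step. The function name and the even halving are assumptions for illustration; the real token_slices carry additional bookkeeping, and the per-slice attention metadata construction is omitted:

```python
def split_into_microbatches(num_tokens: int) -> tuple[slice, slice]:
    """Split token indices [0, num_tokens) into two contiguous slices,
    one per microbatch. Each slice would then get its own attention
    metadata object during _prepare_inputs."""
    mid = num_tokens // 2
    return slice(0, mid), slice(mid, num_tokens)

first, second = split_into_microbatches(10)
print(first, second)  # prints: slice(0, 5, None) slice(5, 10, None)
```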

Without any additional modifications, this would just result in one microbatch running to completion before the other starts. To get overlap, we've added a "yield" call that can be inserted into the all-to-all kernels to interleave the two microbatches. The yield_and_switch_from_compute_to_comm function yields the CPU from the current thread (thread A) to the other microbatching thread (thread B). Once thread A resumes execution, either because thread B yielded the CPU or finished its execution, it swaps over to the communication stream and starts dispatching kernels there. yield_and_switch_from_comm_to_compute behaves similarly but in the opposite direction: it swaps from the communication stream to the compute stream.

Both GPU and CPU events are used to synchronize all of this. It is absolutely critical that only one microbatching thread runs at a time, meaning the other is always waiting on an event. It is equally critical that both microbatches execute exactly the same number of yields.
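The CPU side of this handoff can be modeled with one event per microbatch thread, so that exactly one thread is runnable at any moment. This is a toy sketch, not the actual UBatchContext: stream switching and GPU events are omitted, and all names here are illustrative.

```python
import threading

class HandoffSketch:
    """Toy model of the CPU-side microbatch handoff: one event per thread,
    and a thread may only run while its own event is set."""

    def __init__(self) -> None:
        self.events = [threading.Event(), threading.Event()]
        self.events[0].set()  # microbatch 0 gets the first turn

    def yield_to_other(self, my_id: int) -> None:
        # Clear our own event *before* waking the other thread, so a fast
        # yield-back from it cannot be lost.
        self.events[my_id].clear()
        self.events[1 - my_id].set()
        self.events[my_id].wait()

    def finish(self, my_id: int) -> None:
        # On exit, wake the other thread unconditionally so it never blocks
        # waiting for a yield that will not come.
        self.events[1 - my_id].set()

def run_microbatch(ctx: HandoffSketch, mb_id: int, steps: int, trace: list) -> None:
    ctx.events[mb_id].wait()         # wait for our first turn
    for step in range(steps):
        trace.append((mb_id, step))  # stand-in for launching compute kernels
        ctx.yield_to_other(mb_id)    # stand-in for yielding before comm
    ctx.finish(mb_id)

ctx = HandoffSketch()
trace: list = []
threads = [threading.Thread(target=run_microbatch, args=(ctx, i, 3, trace))
           for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(trace)  # prints: [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]
```

Note that the strict alternation only holds because both threads perform the same number of yields, which is exactly the invariant the paragraph above calls out.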

lm_eval results

With Microbatching

VLLM_ALL2ALL_BACKEND=pplx vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --gpu-memory-utilization 0.75 --port 4444 --disable-log-requests --enforce-eager --enable-microbatching

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.184|±  |0.0173|
|     |       |strict-match    |     5|exact_match|↑  |0.180|±  |0.0172|

VLLM_ALL2ALL_BACKEND=pplx vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --tensor-parallel-size 2 --enable-expert-parallel --gpu-memory-utilization 0.75 --port 4444 --disable-log-requests --enforce-eager --enable-microbatching

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.372|±  |0.0216|
|     |       |strict-match    |     5|exact_match|↑  |0.372|±  |0.0216|

Without Microbatching

VLLM_ALL2ALL_BACKEND=pplx vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --gpu-memory-utilization 0.75 --port 4444 --disable-log-requests --enforce-eager

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.164|±  |0.0166|
|     |       |strict-match    |     5|exact_match|↑  |0.160|±  |0.0164|

VLLM_ALL2ALL_BACKEND=pplx vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --tensor-parallel-size 2 --enable-expert-parallel --gpu-memory-utilization 0.75 --port 4444 --disable-log-requests --enforce-eager

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.374|±  |0.0217|
|     |       |strict-match    |     5|exact_match|↑  |0.372|±  |0.0216|

LucasWilkinson and others added 30 commits May 22, 2025 20:51
SageMoore added 3 commits July 8, 2025 18:59
@SageMoore SageMoore changed the title [WIP] Add experimental Dual-Batch Overlap mechanism to VLLM Add experimental Dual-Batch Overlap mechanism to VLLM Jul 8, 2025
@mergify mergify bot removed the needs-rebase label Jul 9, 2025
SageMoore added 2 commits July 9, 2025 20:12
mergify bot commented Jul 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @SageMoore.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


SageMoore commented Jul 18, 2025

I'm reverting this back to a draft state while I break off various components and merge them separately. The first will be #21153.

@mergify mergify bot removed the needs-rebase label Jul 25, 2025
mergify bot commented Jul 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @SageMoore.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
