# Dual Batch Overlap

## Motivation

The core motivation of the DBO system in vLLM is to overlap the sparse all-to-all communication in the MoE layer with the surrounding computation. This system currently only targets DP+EP deployments.

## Introduction

The Dual Batch Overlap system works by splitting the batch in the model runner, creating two worker threads, and then running the model on each of these worker threads. When DBO is enabled, yield points within the `FusedMoEModularKernel` allow the two CPU worker threads (also called UBatch threads) to ping-pong between each other so that when one is running compute, the other is waiting on communication. Throughout the code, "ubatch" is used as a short form of microbatch; it is an ASCII-friendly version of µ-batch.

The DBO system includes modifications to `GPUModelRunner` and `ModularKernel`, and defines two utility classes: `UBatchWrapper` and `UBatchContext`. `UBatchWrapper` manages thread lifecycle and CUDA graph execution of the model. `UBatchContext` wraps `ForwardContext` to coordinate synchronization between the two UBatch threads.

Below is the overlap schedule that is currently implemented in vLLM.

```python
# Schedule notation legend:
# S  = Shared expert
# A0 = MLA qkv proj
# A1 = Core attn + out proj + MoE gate
# D  = Dispatch
# C  = Combine

# Comp: |-A0₀-A1₀-||-MLP₁-||-S₁-MLP₀-||-S₀-A0₁-A1₁-|
# Comm: |----D₁---||--D₀--||----C₁---||-----C₀-----|
# Order: D₁ send, A0₀, A1₀, D₁ recv, D₀ send, MLP₁, D₀ recv,
#        C₁ send, S₁, MLP₀, C₁ recv, C₀ send, S₀, A0₁, A1₁, C₀ recv.
# MLP_SHARED_OVERLAP = "mlp_shared_overlap"
```

## Running with DBO

To enable DBO, pass the `--enable-dbo` argument to your `vllm serve` command. It must be used in conjunction with `--data-parallel-size N`, where N is greater than 1, and `--enable-expert-parallel`. Additionally, there are two configuration knobs, listed below; a sketch of how they gate DBO per batch follows the list.

* `--dbo-decode-token-threshold`: the minimum number of tokens in a decode-only batch required to enable DBO for that batch
* `--dbo-prefill-token-threshold`: the minimum number of tokens in a batch containing at least one prefill required to enable DBO for that batch
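
The sketch below shows how these thresholds could gate DBO for a given batch. The function and parameter names are illustrative only; the real decision is made inside vLLM's model runner.

```python
# Illustrative sketch only: the function and parameter names are placeholders,
# not the vLLM internals.
def should_enable_dbo(num_tokens: int,
                      has_prefill: bool,
                      decode_token_threshold: int,
                      prefill_token_threshold: int) -> bool:
    if has_prefill:
        # The batch contains at least one prefill request.
        return num_tokens >= prefill_token_threshold
    # Decode-only batch.
    return num_tokens >= decode_token_threshold
```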

Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `VLLM_ALL2ALL_BACKEND` environment variable must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests.

Below is a command that will spin up a two-DP-rank server with expert parallelism and DBO enabled.
EX: `VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo`

Note that at least two GPUs must be visible in `CUDA_VISIBLE_DEVICES`.

## DBO Components

* `GPUModelRunner`
* `UBatchWrapper`
* `UBatchContext`

### GPU Model Runner

The batch is split into microbatches by the `GPUModelRunner` class. This is accomplished in two steps. First, coordination across all DP ranks is performed to determine whether microbatching will be applied. Microbatching must be uniform across all DP ranks: if microbatching is not feasible for any DP rank, it is disabled for all ranks. If all DP ranks are going to microbatch, the total number of tokens is padded up to the maximum number of tokens amongst all ranks. If any rank would end up with an empty second microbatch after the padding is applied, microbatching is aborted and no ranks will microbatch. Once microbatching has been initiated by all ranks, the second step is performed: the `CommonAttentionMetadata` is sliced in half by the `GPUModelRunner` so that there is one attention metadata per microbatch.
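
A simplified sketch of the first step, the cross-DP-rank decision described above, is shown below. The helper name and its inputs are hypothetical; in vLLM this coordination happens inside the GPU model runner.

```python
# Simplified, hypothetical sketch of the cross-DP-rank microbatching decision
# described above; it is not the actual vLLM implementation.
def plan_microbatching(num_tokens_per_rank: list[int],
                       can_microbatch_per_rank: list[bool]) -> tuple[bool, int]:
    padded_num_tokens = max(num_tokens_per_rank)

    # Microbatching must be uniform: if any DP rank cannot microbatch,
    # it is disabled for all ranks.
    if not all(can_microbatch_per_rank):
        return False, padded_num_tokens

    # Every rank is padded up to the max token count across ranks, then the
    # padded batch is split in half, one slice per microbatch.
    first_ubatch_tokens = (padded_num_tokens + 1) // 2
    second_ubatch_tokens = padded_num_tokens - first_ubatch_tokens

    # If the second microbatch would be empty, abort: no rank microbatches.
    if second_ubatch_tokens == 0:
        return False, padded_num_tokens

    return True, padded_num_tokens
```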

### UBatchWrapper

`gpu_ubatch_wrapper`

The `UBatchWrapper` class is a model wrapper that's responsible for all of the thread, `UBatchContext`, and CUDA graph management for DBO. It's designed to be relatively transparent to the GPU Model Runner.

The implementation runs the model twice, once for each microbatch. Each model invocation occurs within a UBatch thread. These threads are launched in parallel and are synchronized using the `UBatchContext`. Each thread is provided with a sliced version of the attention metadata that is used to run its half of the batch.
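
Conceptually, the DBO path resembles the sketch below, with plain `threading` primitives standing in for the real implementation, which additionally manages per-thread CUDA streams, `UBatchContext`s, and sliced attention metadata.

```python
import threading

# Conceptual sketch only: run one model invocation per microbatch, each on its
# own UBatch thread, and collect the results.
def run_ubatches(model, ubatch_inputs):
    results = [None] * len(ubatch_inputs)

    def worker(idx, inputs):
        # Each UBatch thread runs the model on its slice of the batch.
        results[idx] = model(**inputs)

    threads = [
        threading.Thread(target=worker, args=(i, inputs))
        for i, inputs in enumerate(ubatch_inputs)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```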

CUDA graphs for DBO are entirely managed by the `UBatchWrapper`. Because of this, DBO only supports running with Full CUDA graphs. However, once a DBO CUDA graph has been captured, it can be replayed without any multithreading or CPU synchronization.
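
For reference, the capture-once, replay-many pattern for a full CUDA graph in PyTorch looks like the sketch below; after capture, `graph.replay()` is a single launch that involves no Python threading. This is a generic illustration, not the `UBatchWrapper` capture path itself.

```python
import torch

# Generic full-graph capture/replay sketch (not the DBO capture path).
model = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream so lazy initialization happens outside capture.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture one full forward pass into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: refill the static input buffer, then launch the captured graph.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
```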

#### Interfaces

The `__init__` method takes in the model, `VllmConfig`, `CUDAGraphMode`, and device.

The `forward` method exclusively takes in model arguments. It decides whether to run with DBO based on whether a `ubatch_slices` object is present in the `forward_context`; if it is not present, the model is run without DBO.
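
A rough sketch of that dispatch decision is shown below; `get_forward_context` stands for vLLM's forward-context accessor (the import path is assumed), and `_run_ubatches` is a placeholder for the DBO path, so this is illustrative rather than the exact implementation.

```python
from vllm.forward_context import get_forward_context  # assumed import path

# Rough sketch of the dispatch decision described above; _run_ubatches is a
# placeholder for the DBO path, not the exact vLLM code.
class UBatchWrapperSketch:
    def __init__(self, model):
        self.model = model

    def forward(self, *args, **kwargs):
        forward_context = get_forward_context()
        # DBO is used only when the model runner attached ubatch_slices to
        # the forward context.
        if getattr(forward_context, "ubatch_slices", None) is not None:
            return self._run_ubatches(*args, **kwargs)
        # Otherwise, run the wrapped model directly.
        return self.model(*args, **kwargs)

    def _run_ubatches(self, *args, **kwargs):
        # Placeholder: launch the two UBatch threads (see the earlier sketch).
        raise NotImplementedError
```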

### UBatchContext

`ubatch_context`

The `UBatchContext` class wraps `ForwardContext` and is used by `UBatchWrapper` to synchronize the two UBatch threads. It should only be instantiated via `make_ubatch_contexts`.

When one of the UBatch threads reaches a `dbo_yield` call, it pauses and wakes the other thread, which then runs until it reaches the same `dbo_yield` call. This "ping-pong" dynamic continues, with the threads swapping at each `dbo_yield` call, until the model's execution is complete.

The current implementation has all `dbo_yield` and `dbo_maybe_run_recv_hook` calls in the `FusedMoEModularKernel.forward` method.
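
The ping-pong can be pictured with ordinary thread events, as in the minimal sketch below. The real `UBatchContext` also coordinates GPU work via CUDA streams and events, so this is only an analogy.

```python
import threading

# Minimal analogy for the dbo_yield ping-pong: each thread owns an event, and
# yielding wakes the peer and then waits to be woken again. The real
# UBatchContext does this alongside CUDA stream/event synchronization.
class ToyUBatchContext:
    def __init__(self, my_turn: threading.Event, peer_turn: threading.Event):
        self.my_turn = my_turn
        self.peer_turn = peer_turn

    def dbo_yield(self):
        # Hand control to the other UBatch thread...
        self.peer_turn.set()
        # ...and sleep until it yields back.
        self.my_turn.wait()
        self.my_turn.clear()

# Pair two contexts so the two threads alternate at each yield point.
turn_a, turn_b = threading.Event(), threading.Event()
ctx_a = ToyUBatchContext(turn_a, turn_b)
ctx_b = ToyUBatchContext(turn_b, turn_a)
```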

#### Interfaces

The `make_ubatch_contexts` function initializes two `UBatchContext`s, one for each UBatch thread. It takes two CUDA streams, the preexisting `ForwardContext`s, and a CPU thread barrier. This function should be used exclusively to instantiate `UBatchContext`s; it handles all of the event initialization.

The `dbo_register_recv_hook` method registers a callback, typically one returned by the `FusedMoEPrepareAndFinalize` class, on the other UBatch thread's `UBatchContext`. The callback is run when the other thread calls `dbo_maybe_run_recv_hook`, and is typically used to wait on an all-to-all kernel.

The `dbo_maybe_run_recv_hook` method runs the callback registered by `dbo_register_recv_hook`, if one has been set.

The `dbo_yield` method puts the current thread to sleep and wakes up the other UBatch thread.
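
Put together, the call pattern inside a MoE forward conceptually looks like the sketch below. Everything other than the `dbo_*` methods is passed in as an opaque placeholder for the dispatch, expert, and combine steps described earlier; this is not the actual `FusedMoEModularKernel` code.

```python
# Conceptual call pattern only; dispatch/experts/combine are opaque callables
# standing in for the real MoE steps.
def moe_forward_pattern(ctx, hidden_states, dispatch, experts, combine):
    # Issue the dispatch all-to-all; it returns the dispatched tokens and a
    # hook that waits for the communication to finish.
    dispatched_tokens, dispatch_recv_hook = dispatch(hidden_states)

    # Run the recv hook the other UBatch thread registered on this context
    # (if any), register our own hook on its context, and hand over the GPU
    # while our dispatch is in flight.
    ctx.dbo_maybe_run_recv_hook()
    ctx.dbo_register_recv_hook(dispatch_recv_hook)
    ctx.dbo_yield()

    # By the time this thread resumes, the dispatched tokens are ready.
    expert_output = experts(dispatched_tokens)

    # The same pattern wraps the combine all-to-all.
    combined_output, combine_recv_hook = combine(expert_output)
    ctx.dbo_maybe_run_recv_hook()
    ctx.dbo_register_recv_hook(combine_recv_hook)
    ctx.dbo_yield()

    return combined_output
```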