|
35 | 35 | from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group |
36 | 36 | from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks |
37 | 37 | from vllm.distributed.parallel_state import ( |
| 38 | + get_dcp_group, |
38 | 39 | get_pp_group, |
39 | 40 | get_tp_group, |
40 | 41 | graph_capture, |
|
92 | 93 | AttentionMetadataBuilder, |
93 | 94 | CommonAttentionMetadata, |
94 | 95 | create_fast_prefill_custom_backend, |
| 96 | + get_dcp_local_seq_lens, |
95 | 97 | reorder_batch_to_split_decodes_and_prefills, |
96 | 98 | split_attn_metadata, |
97 | 99 | ) |
@@ -256,6 +258,11 @@ def __init__( |
256 | 258 | self.is_multimodal_pruning_enabled = False |
257 | 259 | self.max_model_len = model_config.max_model_len |
258 | 260 | self.dcp_world_size = self.parallel_config.decode_context_parallel_size |
| 261 | + try: |
| 262 | + self.dcp_rank = get_dcp_group().rank_in_group |
| 263 | + except AssertionError: |
| 264 | + # DCP might not be initialized in testing |
| 265 | + self.dcp_rank = 0 |
259 | 266 | self.max_num_tokens = scheduler_config.max_num_batched_tokens |
260 | 267 | self.max_num_reqs = scheduler_config.max_num_seqs |
261 | 268 |
|
@@ -372,6 +379,7 @@ def __init__( |
372 | 379 | # uses output token ids so we set this conservatively. |
373 | 380 | logitsprocs_need_output_token_ids=bool(custom_logitsprocs), |
374 | 381 | is_pooling_model=self.is_pooling_model, |
| 382 | + cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, |
375 | 383 | ) |
376 | 384 |
|
377 | 385 | self.use_async_scheduling = self.scheduler_config.async_scheduling |
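The `cp_kv_cache_interleave_size` value threaded into the batch above determines, by its name, how many consecutive KV-cache tokens a decode-context-parallel (DCP) rank holds before ownership rotates to the next rank. Below is a minimal sketch of that interleaved layout, using a hypothetical `owner_rank` helper that is not part of vLLM:

```python
# Hypothetical helper (not part of vLLM) illustrating the interleaved KV-cache
# layout implied by cp_kv_cache_interleave_size: blocks of `interleave_size`
# consecutive tokens are striped round-robin across the DCP ranks.
def owner_rank(token_idx: int, interleave_size: int, dcp_world_size: int) -> int:
    return (token_idx // interleave_size) % dcp_world_size

# Example with interleave_size=16 and dcp_world_size=2:
#   tokens  0..15 -> rank 0
#   tokens 16..31 -> rank 1
#   tokens 32..47 -> rank 0, and so on.
```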
@@ -1276,6 +1284,15 @@ def _prepare_inputs( |
1276 | 1284 | logits_indices |
1277 | 1285 | ) |
1278 | 1286 |
|
| 1287 | +        # Update the per-rank (DCP-local) seq_lens of decode reqs.
| 1288 | + if self.dcp_world_size > 1: |
| 1289 | + self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( |
| 1290 | + self.seq_lens.cpu[:num_reqs], |
| 1291 | + self.dcp_world_size, |
| 1292 | + self.parallel_config.cp_kv_cache_interleave_size, |
| 1293 | + )[:, self.dcp_rank] |
| 1294 | + self.dcp_local_seq_lens.copy_to_gpu(num_reqs) |
| 1295 | + |
1279 | 1296 | attn_metadata: PerLayerAttnMetadata = {} |
1280 | 1297 | if ubatch_slices is not None: |
1281 | 1298 | attn_metadata = [dict() for _ in range(len(ubatch_slices))] |
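For context on the hunk above: `get_dcp_local_seq_lens` yields, per request, the number of KV-cache tokens held by each DCP rank, and every rank then keeps only its own column. The following is a minimal, hypothetical re-implementation under the interleaved round-robin striping assumption sketched earlier; the actual vLLM helper may differ in details.

```python
import torch

def dcp_local_seq_lens_sketch(
    seq_lens: torch.Tensor,   # [num_reqs], total sequence length per request
    dcp_world_size: int,      # number of decode-context-parallel ranks
    interleave_size: int,     # consecutive tokens per rank before rotating
) -> torch.Tensor:
    """Per-rank KV lengths, shape [num_reqs, dcp_world_size], assuming tokens
    are striped across ranks in blocks of `interleave_size`."""
    ranks = torch.arange(dcp_world_size)
    stride = dcp_world_size * interleave_size
    full_stripes = seq_lens[:, None] // stride      # complete round-robin passes
    remainder = seq_lens[:, None] % stride          # tokens in the partial pass
    extra = (remainder - ranks[None, :] * interleave_size).clamp(0, interleave_size)
    return full_stripes * interleave_size + extra

# Usage mirroring the diff: each rank keeps only its own column.
seq_lens = torch.tensor([40, 17, 5])
local = dcp_local_seq_lens_sketch(seq_lens, dcp_world_size=2, interleave_size=16)
print(local[:, 0], local[:, 1])
# tensor([24, 16,  5]) tensor([16,  1,  0])
```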
|