diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index f240074f252d..8184b073275c 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -313,9 +313,10 @@ def __init__(self, runner):
         cache_config = runner.cache_config
 
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = cache_config.enable_prefix_caching
 
-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            self.context_chunk_workspace_size = min(
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
                 max(
@@ -330,7 +331,7 @@ def __init__(self, runner):
                 # 2*(192*128)*(64*1024) = 3gb
                 # (assuming 192 QK head dim, 128 heads, and fp16)
                 128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
+            assert self.context_chunk_workspace_size >= \
                 scheduler_config.max_num_seqs * cache_config.block_size
 
     @contextmanager
@@ -430,23 +431,23 @@ def prepare_graph_input_buffers(self,
             "TritonMLAState does not support encoder/decoder yet")
 
     def begin_forward(self, model_input):
-        if self.chunked_prefill_enabled:
-            if not hasattr(self, "chunked_prefill_workspace"):
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            if not hasattr(self, "context_chunk_workspace"):
                 # not self.runner.device does not return the correct device
                 # for this process, (init_device sets the correct device but
                 # only on the Worker). The only way Ive figured out to get the
                 # correct device is to allocate the workspace on the first call
                 # to begin_forward and use the device of the input tokens
                 assert model_input.input_tokens is not None
-                self.chunked_prefill_workspace = torch.empty(
-                    (self.chunked_prefill_workspace_size,
+                self.context_chunk_workspace = torch.empty(
+                    (self.context_chunk_workspace_size,
                      self.model_config.get_head_size()),
                     dtype=self.model_config.dtype,
                     device=model_input.input_tokens.device,
                 )
 
-            model_input.attn_metadata.chunked_prefill_workspace = \
-                self.chunked_prefill_workspace
+            model_input.attn_metadata.context_chunk_workspace = \
+                self.context_chunk_workspace
 
 
 @dataclass
@@ -537,7 +538,7 @@ class MLACommonMetadata(AttentionMetadata):
     context_chunk_seq_tot: Optional[List[int]] = None
     context_chunk_max_seq_lens: Optional[List[int]] = None
     # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
-    chunked_prefill_workspace: Optional[torch.Tensor] = None
+    context_chunk_workspace: Optional[torch.Tensor] = None
 
     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
@@ -747,11 +748,13 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.block_size = input_builder.block_size
         self.chunked_prefill_enabled = \
             self.runner.scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = \
+            self.runner.cache_config.enable_prefix_caching
 
-        if self.chunked_prefill_enabled:
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
             attn_state = self.input_builder.runner.attn_state
-            self.chunked_prefill_workspace_size = \
-                attn_state.chunked_prefill_workspace_size
+            self.context_chunk_workspace_size = \
+                attn_state.context_chunk_workspace_size
             self.page_size = self.runner.block_size
 
     def prepare(self):
@@ -920,7 +923,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         context_chunk_seq_tot = None
         context_chunk_max_seq_lens = None
 
-        if self.chunked_prefill_enabled and self.num_prefills > 0 \
+        if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
+                and self.num_prefills > 0 \
                 and context_lens_tensor is not None \
                 and context_lens_tensor[:self.num_prefills].max() > 0:
 
@@ -936,7 +940,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
            # algorithm here and allocate more workspace to prefills with
            # longer context lengths
            max_context_chunk = \
-                self.chunked_prefill_workspace_size // num_prefills_with_context
+                self.context_chunk_workspace_size // num_prefills_with_context
 
            # align max_context_chunk to page_size by rounding down,
            # currently the `gather_cache` kernel cannot handle
@@ -965,7 +969,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 chunk_seq_lens.max(dim=1).values.tolist()
             context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
             assert max(context_chunk_seq_tot) <= \
-                self.chunked_prefill_workspace_size
+                self.context_chunk_workspace_size
 
         return self.runner.attn_backend.make_metadata(
             # Required by ModelRunner
@@ -1288,8 +1292,8 @@ def _compute_prefill_context(
         # Fetch from attn_metadata directly, since it late bound by
         # MLAAttentionState, grabbing it directly `attn_metadata` can avoid
         # any weirdness around prefill_metadata caching
-        assert attn_metadata.chunked_prefill_workspace is not None
-        workspace = attn_metadata.chunked_prefill_workspace
+        assert attn_metadata.context_chunk_workspace is not None
+        workspace = attn_metadata.context_chunk_workspace
 
         for i in range(iters):
             toks = prefill_metadata.context_chunk_seq_tot[i]
@@ -1502,12 +1506,12 @@ def forward(
                 "output is not yet supported for MLAImplBase")
 
         if attn_metadata.is_profile_run and \
-                attn_metadata.chunked_prefill_workspace is not None:
+                attn_metadata.context_chunk_workspace is not None:
             # During the profile run try to simulate to worse case output size
             # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
             # since this can be large
             _ = torch.empty(
-                (attn_metadata.chunked_prefill_workspace.shape[0],
+                (attn_metadata.context_chunk_workspace.shape[0],
                  self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
                 device=k_c_normed.device,
                 dtype=k_c_normed.dtype,
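
Note (not part of the patch): the `build()` hunks above only rename the budget from `chunked_prefill_workspace_size` to `context_chunk_workspace_size`, but the chunk-splitting arithmetic they touch is easier to follow end to end. Below is a minimal standalone Python sketch of that arithmetic; the function name `split_context_chunks` and the simplified list-based bookkeeping are illustrative assumptions, not vLLM API.

    # Minimal sketch (illustrative, not vLLM code) of the per-prefill context
    # chunking that `context_chunk_workspace_size` budgets in `build()`.
    import math
    from typing import List

    def split_context_chunks(context_lens: List[int], workspace_size: int,
                             page_size: int) -> List[List[int]]:
        """Split each prefill's cached context into per-iteration chunks so the
        tokens gathered in any iteration never exceed the shared workspace."""
        num_prefills_with_context = sum(1 for c in context_lens if c > 0)
        if num_prefills_with_context == 0:
            return []

        # Equal share of the workspace per prefill, rounded down to whole
        # cache pages (mirrors the page_size round-down in the diff).
        max_context_chunk = workspace_size // num_prefills_with_context
        max_context_chunk = (max_context_chunk // page_size) * page_size
        assert max_context_chunk > 0

        # Number of passes needed to cover the longest cached context.
        num_iters = math.ceil(max(context_lens) / max_context_chunk)

        chunk_seq_lens = []  # chunk_seq_lens[i][j]: tokens of request j in pass i
        for i in range(num_iters):
            start = i * max_context_chunk
            lens = [min(max(c - start, 0), max_context_chunk)
                    for c in context_lens]
            # Per-pass total must fit in the workspace (the assert in the diff).
            assert sum(lens) <= workspace_size
            chunk_seq_lens.append(lens)
        return chunk_seq_lens

    # Example: two prefills with 96K and 20K cached tokens, 64K-token workspace.
    print(split_context_chunks([96 * 1024, 20 * 1024],
                               workspace_size=64 * 1024, page_size=16))

With prefix caching enabled but chunked prefill off, this same budgeted path can now run, which is why the patch gates the workspace allocation and metadata build on `chunked_prefill_enabled or enable_prefix_caching`.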