@@ -313,9 +313,10 @@ def __init__(self, runner):
         cache_config = runner.cache_config

         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = cache_config.enable_prefix_caching

-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            self.context_chunk_workspace_size = min(
                 # Make sure there is enough for 8 full-length requests or at
                 # least 4 pages of cache per request
                 max(
@@ -330,7 +331,7 @@ def __init__(self, runner):
                 # 2*(192*128)*(64*1024) = 3gb
                 # (assuming 192 QK head dim, 128 heads, and fp16)
                 128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
+            assert self.context_chunk_workspace_size >= \
                 scheduler_config.max_num_seqs * cache_config.block_size

     @contextmanager
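For reviewers, the sizing arithmetic is easier to check in isolation. Below is a minimal sketch with invented config values (plain stand-ins, not the real `model_config`/`scheduler_config`/`cache_config` objects; the `max(...)` arguments follow the comment's "8 full-length requests or 4 pages per request" reading):

```python
# Invented stand-in values for the real config objects.
max_model_len = 32_768   # model_config.max_model_len (assumed)
max_num_seqs = 256       # scheduler_config.max_num_seqs (assumed)
block_size = 16          # cache_config.block_size (assumed)

workspace_tokens = min(
    # enough for 8 full-length requests, or 4 cache pages per request
    max(8 * max_model_len, 4 * max_num_seqs * block_size),
    # hard cap so long-context models don't starve the KV cache
    128 * 1024,
)

# The assert above guarantees one block per scheduled sequence still fits.
assert workspace_tokens >= max_num_seqs * block_size
print(workspace_tokens)  # 131072 for these values
```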
@@ -430,23 +431,23 @@ def prepare_graph_input_buffers(self,
             "TritonMLAState does not support encoder/decoder yet")

     def begin_forward(self, model_input):
-        if self.chunked_prefill_enabled:
-            if not hasattr(self, "chunked_prefill_workspace"):
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
+            if not hasattr(self, "context_chunk_workspace"):
                 # NOTE: self.runner.device does not return the correct device
                 # for this process (init_device sets the correct device, but
                 # only on the Worker). The only way I've found to get the
                 # correct device is to allocate the workspace on the first
                 # call to begin_forward and use the device of the input tokens
                 assert model_input.input_tokens is not None
-                self.chunked_prefill_workspace = torch.empty(
-                    (self.chunked_prefill_workspace_size,
+                self.context_chunk_workspace = torch.empty(
+                    (self.context_chunk_workspace_size,
                      self.model_config.get_head_size()),
                     dtype=self.model_config.dtype,
                     device=model_input.input_tokens.device,
                 )

-            model_input.attn_metadata.chunked_prefill_workspace = \
-                self.chunked_prefill_workspace
+            model_input.attn_metadata.context_chunk_workspace = \
+                self.context_chunk_workspace


 @dataclass
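The lazy allocation in `begin_forward` is a general pattern: defer buffer creation to the first forward call and take the device from an input tensor rather than from configuration. Here is a self-contained sketch of the idea (the class and names are illustrative, not vLLM API):

```python
import torch

class LazyWorkspace:
    """Allocate a reusable buffer on first use, on the inputs' device."""

    def __init__(self, num_rows: int, head_size: int, dtype: torch.dtype):
        self.num_rows = num_rows
        self.head_size = head_size
        self.dtype = dtype

    def get(self, input_tokens: torch.Tensor) -> torch.Tensor:
        # Mirrors the hasattr check above: allocate exactly once, on
        # whatever device the first batch of input tokens lands on.
        if not hasattr(self, "workspace"):
            self.workspace = torch.empty(
                (self.num_rows, self.head_size),
                dtype=self.dtype,
                device=input_tokens.device,
            )
        return self.workspace

ws = LazyWorkspace(num_rows=1024, head_size=576, dtype=torch.float16)
tokens = torch.zeros(8, dtype=torch.long)  # CPU here; CUDA in practice
print(ws.get(tokens).shape)  # torch.Size([1024, 576])
```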
@@ -537,7 +538,7 @@ class MLACommonMetadata(AttentionMetadata):
     context_chunk_seq_tot: Optional[List[int]] = None
     context_chunk_max_seq_lens: Optional[List[int]] = None
     # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcast
-    chunked_prefill_workspace: Optional[torch.Tensor] = None
+    context_chunk_workspace: Optional[torch.Tensor] = None

     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
@@ -747,11 +748,13 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.block_size = input_builder.block_size
         self.chunked_prefill_enabled = \
             self.runner.scheduler_config.chunked_prefill_enabled
+        self.enable_prefix_caching = \
+            self.runner.cache_config.enable_prefix_caching

-        if self.chunked_prefill_enabled:
+        if self.chunked_prefill_enabled or self.enable_prefix_caching:
             attn_state = self.input_builder.runner.attn_state
-            self.chunked_prefill_workspace_size = \
-                attn_state.chunked_prefill_workspace_size
+            self.context_chunk_workspace_size = \
+                attn_state.context_chunk_workspace_size
             self.page_size = self.runner.block_size

     def prepare(self):
@@ -920,7 +923,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         context_chunk_seq_tot = None
         context_chunk_max_seq_lens = None

-        if self.chunked_prefill_enabled and self.num_prefills > 0 \
+        if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
+                and self.num_prefills > 0 \
                 and context_lens_tensor is not None \
                 and context_lens_tensor[:self.num_prefills].max() > 0:

@@ -936,7 +940,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             # algorithm here and allocate more workspace to prefills with
             # longer context lengths
             max_context_chunk = \
-                self.chunked_prefill_workspace_size // num_prefills_with_context
+                self.context_chunk_workspace_size // num_prefills_with_context

             # align max_context_chunk to page_size by rounding down;
             # currently the `gather_cache` kernel cannot handle
@@ -965,7 +969,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 chunk_seq_lens.max(dim=1).values.tolist()
             context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
             assert max(context_chunk_seq_tot) <= \
-                self.chunked_prefill_workspace_size
+                self.context_chunk_workspace_size

         return self.runner.attn_backend.make_metadata(
             # Required by ModelRunner
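The even-split scheduling in `build` is easiest to verify with concrete numbers. Below is a hedged worked example of the same arithmetic (all values invented; the tensor construction approximates the elided `chunk_seq_lens` computation):

```python
import torch

workspace_size = 128 * 1024   # context_chunk_workspace_size (assumed)
page_size = 16                # self.page_size (assumed)
context_lens = torch.tensor([10_000, 50_000, 2_000])  # three prefills
num_prefills_with_context = int((context_lens > 0).sum())

# Evenly split the workspace across prefills that have context...
max_context_chunk = workspace_size // num_prefills_with_context
# ...then round down to a page boundary, since `gather_cache`
# cannot start a copy mid-page.
max_context_chunk = (max_context_chunk // page_size) * page_size

num_chunks = -(-int(context_lens.max()) // max_context_chunk)  # ceil div

# Per-(chunk, prefill) sequence lengths, clamped to what remains.
chunk_starts = torch.arange(num_chunks).unsqueeze(1) * max_context_chunk
chunk_seq_lens = (context_lens.unsqueeze(0) - chunk_starts).clamp(
    min=0, max=max_context_chunk)

context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
# Matches the assert above: every chunk's total fits in the workspace.
assert max(context_chunk_seq_tot) <= workspace_size
print(max_context_chunk, num_chunks, context_chunk_seq_tot)
```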
@@ -1288,8 +1292,8 @@ def _compute_prefill_context(
         # Fetch from attn_metadata directly, since it is late-bound by
         # MLAAttentionState; grabbing it directly from `attn_metadata` avoids
         # any weirdness around prefill_metadata caching
-        assert attn_metadata.chunked_prefill_workspace is not None
-        workspace = attn_metadata.chunked_prefill_workspace
+        assert attn_metadata.context_chunk_workspace is not None
+        workspace = attn_metadata.context_chunk_workspace

         for i in range(iters):
             toks = prefill_metadata.context_chunk_seq_tot[i]
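The loop entered here attends over one context chunk at a time and merges the partial outputs. That merge is standard log-sum-exp reweighting; below is a small self-contained demonstration of the technique in plain PyTorch (not the fused kernels the real implementation uses):

```python
import torch

def merge_attn(o1, lse1, o2, lse2):
    """Merge two partial attention results via log-sum-exp weights."""
    lse = torch.logaddexp(lse1, lse2)
    return (o1 * torch.exp(lse1 - lse).unsqueeze(-1) +
            o2 * torch.exp(lse2 - lse).unsqueeze(-1)), lse

def attend(q, k, v):
    s = q @ k.T / k.shape[-1]**0.5
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)

torch.manual_seed(0)
q = torch.randn(4, 64)      # 4 queries, head dim 64
k, v = torch.randn(1000, 64), torch.randn(1000, 64)

ref, _ = attend(q, k, v)    # attention over the whole context at once

# Chunked: process the context in workspace-sized pieces and merge.
out, lse = None, None
for start in range(0, 1000, 256):   # 256 = workspace rows (invented)
    o_c, lse_c = attend(q, k[start:start + 256], v[start:start + 256])
    out, lse = (o_c, lse_c) if out is None else \
        merge_attn(out, lse, o_c, lse_c)

assert torch.allclose(out, ref, atol=1e-5)
```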
@@ -1502,12 +1506,12 @@ def forward(
                 "output is not yet supported for MLAImplBase")

         if attn_metadata.is_profile_run and \
-                attn_metadata.chunked_prefill_workspace is not None:
+                attn_metadata.context_chunk_workspace is not None:
             # During the profile run, try to simulate the worst-case output
             # size for `self.kv_b_proj(kv_c_normed)` in
             # `_compute_prefill_context`, since this can be large
             _ = torch.empty(
-                (attn_metadata.chunked_prefill_workspace.shape[0],
+                (attn_metadata.context_chunk_workspace.shape[0],
                  self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
                 device=k_c_normed.device,
                 dtype=k_c_normed.dtype,
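The dummy allocation here works because the memory profiler only records peak usage, so materializing a worst-case-shaped tensor is enough to reserve headroom for the transient `kv_b_proj` output. A minimal illustration of the trick, with invented shapes:

```python
import torch

# Invented shapes: workspace rows x num_heads x (qk_nope + v_head_dim).
rows, num_heads, qk_nope, v_dim = 4096, 16, 64, 64

def simulate_worst_case(device):
    # The tensor is never used; its only purpose is to make the
    # profiler's peak-memory measurement include this transient.
    _ = torch.empty((rows, num_heads, qk_nope + v_dim),
                    dtype=torch.float16, device=device)

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    simulate_worst_case("cuda")
    print(torch.cuda.max_memory_allocated() / 2**20, "MiB peak")
```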