 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type

-import flashinfer
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    from vllm_flash_attn import flash_attn_varlen_func
+except ImportError:
+    flash_attn_varlen_func = None
+    BatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from vllm_flash_attn import flash_attn_varlen_func

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
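With FlashInfer now optional, the wrapper names fall back to `None` when the import fails. As a minimal sketch (not part of this diff; the helper name is hypothetical), a caller could gate backend selection on those names:

```python
def flashinfer_is_available() -> bool:
    # True only when both wrapper classes were importable above.
    return (BatchDecodeWithPagedKVCacheWrapper is not None
            and BatchPrefillWithPagedKVCacheWrapper is not None)
```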
@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
     # requests only.
     max_prefill_seq_len: int

-    use_cuda_graph: bool = False
+    use_cuda_graph: bool = True

+    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

-    # Metadata for the prefill stage since we still
-    # use flash attention for prefill.
+    # Metadata for the prefill stage.
     seq_start_loc: Optional[torch.Tensor] = None
+    query_start_loc: Optional[torch.Tensor] = None
     block_tables: Optional[torch.Tensor] = None

-    # Metadata for the decode stage.
-    # Workspace buffer required by the kernel; the buffer should not
-    # be allocated/deallocated by the FlashInferMetadata object.
-    workspace_buffer: Optional[torch.Tensor] = None
     # An example for paged_kv_indices, paged_kv_indptr:
     # request 1, page indices [0, 5, 8]
     # request 2, page indices [1, 6, 7]
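The example in the comment above continues past this hunk. As a self-contained sketch of the layout it describes (the `page_size` and sequence lengths here are made up for illustration, not taken from the diff):

```python
# Paged-KV metadata for: request 1 on pages [0, 5, 8],
#                        request 2 on pages [1, 6, 7].
page_size = 16                    # illustrative
pages = [[0, 5, 8], [1, 6, 7]]    # per-request page indices
seq_lens = [40, 35]               # hypothetical sequence lengths

# All page indices, concatenated across requests.
paged_kv_indices = [p for request in pages for p in request]
# -> [0, 5, 8, 1, 6, 7]

# indptr[i]:indptr[i + 1] is the slice of indices owned by request i.
paged_kv_indptr = [0]
for request in pages:
    paged_kv_indptr.append(paged_kv_indptr[-1] + len(request))
# -> [0, 3, 6]

# Number of tokens actually stored in each request's last page.
paged_kv_last_page_len = [n % page_size or page_size for n in seq_lens]
# -> [8, 3]
```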
@@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata):
     page_size: Optional[int] = None
     # The data type of the paged kv cache
     data_type: torch.dtype = None
+    device: torch.device = torch.device("cuda")

     def __post_init__(self):
         # Refer to
@@ -109,13 +113,35 @@ def __post_init__(self):
109113 f"Only { supported_head_sizes } are supported for head_dim," ,
110114 f"received { self .head_dim } ." )
111115
112- # When using flashinfer, we are also creating the FlashInferMetadata,
113- # which will also call post_init by default, here we want to skip the
114- # post_init if it's the prefill phase.
115- if self .num_prefills == 0 :
116- assert self .num_decode_tokens > 0
117- self .decode_wrapper = flashinfer .BatchDecodeWithPagedKVCacheWrapper (
118- self .workspace_buffer , "NHD" )
116+ def begin_forward (self ):
117+ if self .num_prefill_tokens > 0 :
118+ if self .paged_kv_indices is None :
119+ return
120+
121+ assert self .prefill_wrapper is not None
122+ assert self .paged_kv_indices is not None
123+ assert self .paged_kv_indptr is not None
124+ assert self .paged_kv_last_page_len is not None
125+ self .paged_kv_indices = self .paged_kv_indices .to (self .device )
126+ self .paged_kv_indptr = self .paged_kv_indptr .to (self .device )
127+ self .paged_kv_last_page_len = self .paged_kv_last_page_len .to (
128+ self .device )
129+ self .prefill_wrapper .begin_forward (
130+ self .query_start_loc , self .paged_kv_indptr ,
131+ self .paged_kv_indices , self .paged_kv_last_page_len ,
132+ self .num_qo_heads , self .num_kv_heads , self .head_dim ,
133+ self .page_size )
134+ else :
135+ if not self .use_cuda_graph :
136+ assert self .paged_kv_indices is not None
137+ assert self .paged_kv_indptr is not None
138+ assert self .paged_kv_last_page_len is not None
139+ self .paged_kv_indices = self .paged_kv_indices .to (self .device )
140+ self .paged_kv_indptr = self .paged_kv_indptr .to (self .device )
141+ self .paged_kv_last_page_len = self .paged_kv_last_page_len .to (
142+ self .device )
143+
144+ assert self .decode_wrapper is not None
119145 self .decode_wrapper .begin_forward (
120146 self .paged_kv_indptr ,
121147 self .paged_kv_indices ,
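`begin_forward` assumes the prefill/decode wrappers and their shared workspace buffer were created elsewhere (previously `__post_init__` built the decode wrapper itself from `workspace_buffer`). A minimal sketch of that lifecycle, assuming `attn_metadata` is the batch's `FlashInferMetadata` built by the model runner; the 128 MiB workspace size is illustrative:

```python
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

# The workspace buffer is owned by the runner, not by FlashInferMetadata,
# and is reused across batches.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Per batch: attach the wrappers to the metadata, then plan the kernels.
attn_metadata.prefill_wrapper = prefill_wrapper
attn_metadata.decode_wrapper = decode_wrapper
attn_metadata.begin_forward()
```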
@@ -133,8 +159,9 @@ def asdict_zerocopy(self,
                         ) -> Dict[str, Any]:
         if skip_fields is None:
             skip_fields = set()
-        # We need to skip the decode_wrapper field since it cannot be
+        # We need to skip the prefill/decode wrapper fields since they cannot be
         # broadcasted with nccl when TP is enabled.
+        skip_fields.add('prefill_wrapper')
         skip_fields.add('decode_wrapper')
         return super().asdict_zerocopy(skip_fields)

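The wrapper objects hold handles to CUDA workspace state, which is why they cannot travel through an nccl broadcast; each tensor-parallel rank must re-attach its own. A sketch of that round trip (`local_prefill_wrapper` and `local_decode_wrapper` are hypothetical worker-local objects, not names from this diff):

```python
# Driver side: wrapper fields are stripped before the broadcast.
meta_dict = attn_metadata.asdict_zerocopy()   # no wrapper objects inside
# ... broadcast meta_dict to the other TP ranks ...

# Worker side: rebuild the metadata and attach this rank's own wrappers.
rebuilt = FlashInferMetadata(**meta_dict)
rebuilt.prefill_wrapper = local_prefill_wrapper
rebuilt.decode_wrapper = local_decode_wrapper
```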
@@ -168,6 +195,7 @@ def __init__(
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -217,10 +245,14 @@ def forward(
             self.kv_cache_dtype,
         )

+        query = query.contiguous(
+        )  # Flashinfer requires query to be contiguous
         if prefill_meta := attn_metadata.prefill_metadata:
-            # Prompt run.
-            assert prefill_meta.block_tables is not None
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            # We will use flash attention for prefill
+            # when kv_cache is not provided.
+            # This happens when vllm runs the profiling to
+            # determine the number of blocks.
+            if kv_cache is None:
                 output = flash_attn_varlen_func(
                     q=query,
                     k=key,
@@ -235,13 +267,14 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                 )
             else:
-                raise NotImplementedError(
-                    "Prefix caching is not supported with flashinfer yet.")
+                assert prefill_meta is not None
+                assert prefill_meta.prefill_wrapper is not None
+                output = prefill_meta.prefill_wrapper.forward(query,
+                                                              kv_cache,
+                                                              causal=True)
         else:
             assert attn_metadata.decode_metadata is not None
             assert attn_metadata.decode_metadata.decode_wrapper is not None
-            query = query.contiguous(
-            )  # Flashinfer requires query to be contiguous
             output = attn_metadata.decode_metadata.decode_wrapper.forward(
                 query,
                 kv_cache,
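After this change, `forward` dispatches three ways; summarized as a sketch (the decode call's trailing arguments are cut off by the hunk above, so they are left elided here):

```python
# Dispatch after this change, using only names that appear in the diff:
#   1. kv_cache is None               -> flash_attn_varlen_func
#      (vllm's memory-profiling run; no cache blocks allocated yet)
#   2. prefill batch with a kv_cache  -> prefill_meta.prefill_wrapper.forward(
#                                            query, kv_cache, causal=True)
#   3. decode batch                   -> decode_wrapper.forward(query,
#                                            kv_cache, ...)
```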