@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
 
 @dataclass
 class FlexAttentionMetadata:
+    causal: bool
     num_actual_tokens: int  # Number of tokens excluding padding.
     max_query_len: int
     query_start_loc: torch.Tensor
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
     num_blocks = 0
     block_mask: Optional[BlockMask] = None
     score_mod: Optional[_score_mod_signature] = None
-    mask_mod: Optional[_mask_mod_signature] = None
     logical_mask_mod: _mask_mod_signature = causal_mask_mod
 
-    def get_mask_mod(self) -> _mask_mod_signature:
+    def get_causal_mask_mod(self) -> _mask_mod_signature:
         """Creates the mask_mod function for FlexAttention.
 
         This function creates the combined mask mod function that handles:
@@ -233,14 +233,39 @@ def final_mask_mod(
 
         return final_mask_mod
 
+    def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
+        """Creates the encoder mask_mod function for FlexAttention.
+
+        Since the encoder bidirectional attention doesn't run with
+        KV cache, this function creates a mask based on the
+        packed query sequences.
+        """
+        # Create a lookup mapping from query indices -> request number
+        request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
+
+        def final_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            kv_idx: torch.Tensor,
+        ) -> torch.Tensor:
+            return request_lookup[q_idx] == request_lookup[kv_idx]
+
+        return final_mask_mod
+
     def build_block_mask(self) -> BlockMask:
-        assert self.mask_mod is not None
+        if self.causal:
+            mask_mod = self.get_causal_mask_mod()
+            kv_len = self.total_cache_tokens
+        else:
+            mask_mod = self.get_bidirectional_mask_mod()
+            kv_len = self.num_actual_tokens
         return create_block_mask_compiled(
-            self.mask_mod,
+            mask_mod,
             None,
             None,
             self.num_actual_tokens,
-            self.total_cache_tokens,
+            kv_len,
             device=self.block_table.device,
         )
 
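For reference, here is a minimal standalone sketch of the block-diagonal mask that `get_bidirectional_mask_mod` describes. The `offsets_to_doc_ids` helper below is a hypothetical stand-in assumed to behave like `_offsets_to_doc_ids_tensor` (expanding cumulative sequence offsets into per-token request ids); the mask itself mirrors the `request_lookup[q_idx] == request_lookup[kv_idx]` check in the diff.

```python
import torch

def offsets_to_doc_ids(query_start_loc: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for _offsets_to_doc_ids_tensor:
    # [0, 2, 5] -> [0, 0, 1, 1, 1]
    seq_lens = query_start_loc[1:] - query_start_loc[:-1]
    return torch.repeat_interleave(
        torch.arange(len(seq_lens), device=query_start_loc.device), seq_lens)

# Two packed requests of lengths 2 and 3 (5 query tokens total, no KV cache).
query_start_loc = torch.tensor([0, 2, 5])
request_lookup = offsets_to_doc_ids(query_start_loc)

q_idx = torch.arange(5).view(-1, 1)
kv_idx = torch.arange(5).view(1, -1)
# A query token may attend to every token of its own request, so the dense
# mask is block-diagonal over the packed sequences.
mask = request_lookup[q_idx] == request_lookup[kv_idx]
print(mask.int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1]])
```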
@@ -251,7 +276,6 @@ def __post_init__(self):
         assert self.prefix_kv_lens is None, "Not implemented yet."
         assert self.suffix_kv_lens is None, "Not implemented yet."
         self.num_blocks = self.total_cache_tokens // self.block_size
-        self.mask_mod = self.get_mask_mod()
         self.block_mask = self.build_block_mask()
 
 
@@ -306,6 +330,7 @@ def build(self,
             self.device, non_blocking=True)
 
         out = FlexAttentionMetadata(
+            causal=common_attn_metadata.causal,
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
@@ -350,6 +375,12 @@ def __init__(
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_kv_heads
+        self.attn_type = attn_type
+
+        if attn_type not in (AttentionType.ENCODER_ONLY,
+                             AttentionType.DECODER):
+            raise NotImplementedError(
+                f"FlexAttention does not support {attn_type} attention")
 
         if alibi_slopes is not None:
             raise NotImplementedError(
@@ -425,26 +456,38 @@ def forward(
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        key_cache, value_cache = kv_cache.unbind(0)
-
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if not attn_metadata.causal:
+            assert self.attn_type == AttentionType.ENCODER_ONLY
+
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key, value),
+            )
+
+        else:
+            assert self.attn_type == AttentionType.DECODER
+            key_cache, value_cache = kv_cache.unbind(0)
+
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
+            # View out the block_size dim
+            key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
+            value_cache = value_cache.view(-1, self.num_kv_heads,
+                                           self.head_size)
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key_cache, value_cache),
+            )
 
-        # View out the block_size dim
-        key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
-        value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
-        query, key_cache, value_cache = map(
-            lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
-            (query, key_cache, value_cache),
-        )
         query = query[:, :, :num_actual_tokens, :]
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)
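As a rough illustration of the `view_as_4d(x).permute(0, 2, 1, 3)` reshaping used in both branches above, the sketch below assumes `view_as_4d` prepends a batch dimension of 1 to vLLM's packed `(num_tokens, num_heads, head_size)` layout; the actual helper may differ.

```python
import torch

num_tokens, num_heads, head_size = 5, 4, 32
x = torch.randn(num_tokens, num_heads, head_size)  # packed token layout

# Hypothetical equivalent of view_as_4d(x).permute(0, 2, 1, 3): add a leading
# batch dim of 1, then move heads ahead of tokens to get the
# (batch, heads, tokens, head_dim) layout FlexAttention expects.
x_4d = x.unsqueeze(0).permute(0, 2, 1, 3)
print(x_4d.shape)  # torch.Size([1, 4, 5, 32])
```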
@@ -465,8 +508,8 @@ def forward(
 
         out = flex_attention_compiled(
             query,
-            key_cache,
-            value_cache,
+            key_tensor,
+            value_tensor,
             attn_metadata.score_mod,
             attn_metadata.block_mask,
             self.scale,
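Finally, a self-contained sketch (not vLLM code) of the kind of call the diff ends with: PyTorch's `flex_attention` consuming a `BlockMask` built from a `mask_mod`. It assumes a recent PyTorch (2.5+) with the FlexAttention API; the shapes and the simple causal `mask_mod` are illustrative only.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 1, 2, 128, 64  # illustrative shapes
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

def causal(b, h, q_idx, kv_idx):
    # Same signature as the mask_mod functions above: True keeps the pair.
    return q_idx >= kv_idx

# B=None / H=None broadcast the mask over batch and heads.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S,
                               device="cpu")
out = flex_attention(q, k, v, block_mask=block_mask, scale=D ** -0.5)
print(out.shape)  # torch.Size([1, 2, 128, 64])
```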