 
 import vllm.envs as envs
 from vllm.attention import AttentionType
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -80,6 +81,7 @@ def __init__(
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        attn_backend: Optional[type[AttentionBackend]] = None,
         **extra_impl_args,
     ) -> None:
         """
@@ -137,15 +139,6 @@ def __init__(
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window
 
-        # For v1 we have backend agnostic iRoPE (local chunked attention)
-        # we have to store the flag on the layer so gpu model runner can
-        # set KVSpec appropriately (and pop it so it doesnt get passed to
-        # the backends)
-        if envs.VLLM_USE_V1:
-            self.use_irope = extra_impl_args.pop("use_irope", False)
-        else:
-            self.use_irope = extra_impl_args.get("use_irope", False)
-
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
         if quant_method is not None and not isinstance(
@@ -166,18 +159,22 @@ def __init__(
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype,
-                                        block_size,
-                                        is_attention_free,
-                                        use_mla=use_mla)
-        impl_cls = attn_backend.get_impl_cls()
+        if attn_backend is None:
+            self.attn_backend = get_attn_backend(head_size,
+                                                 dtype,
+                                                 kv_cache_dtype,
+                                                 block_size,
+                                                 is_attention_free,
+                                                 use_mla=use_mla)
+        else:
+            self.attn_backend = attn_backend
+
+        impl_cls = self.attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              logits_soft_cap, attn_type,
                              kv_sharing_target_layer_name, **extra_impl_args)
-        self.backend = backend_name_to_enum(attn_backend.get_name())
+        self.backend = backend_name_to_enum(self.attn_backend.get_name())
         self.dtype = dtype
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
@@ -187,7 +184,7 @@ def __init__(
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()
 
-        self.use_output = attn_backend.accept_output_buffer
+        self.use_output = self.attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
@@ -309,6 +306,9 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
         if hasattr(self.impl, "process_weights_after_loading"):
             self.impl.process_weights_after_loading(act_dtype)
 
+    def get_attn_backend(self) -> type[AttentionBackend]:
+        return self.attn_backend
+
 
 class MultiHeadAttention(nn.Module):
     """Multi-headed attention without any cache, used for ViT."""