@@ -274,6 +274,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -300,25 +302,8 @@ def __init__(
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.proj",
                                       disable_tp=use_data_parallel)
-
-        # Detect attention implementation.
-        self.attn_backend = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head,
-            dtype=torch.get_default_dtype())
-        self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
-
-        if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-                _Backend.ROCM_AITER_FA
-        }:
-            raise RuntimeError(
-                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
-            )
+        self.attn_backend = attn_backend
+        self.use_upstream_fa = use_upstream_fa
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
         }
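Note for readers: with this hunk, Qwen2_5_VisionAttention no longer probes for a ViT attention backend itself; whoever builds it decides once and injects the choice through the two new constructor arguments. A minimal, self-contained sketch of that inversion (toy names only, not vLLM's API):

from enum import Enum, auto

class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()
    ROCM_AITER_FA = auto()

class ToyVisionAttention:
    """After the change: trust the caller's choice instead of re-detecting it."""

    def __init__(self,
                 attn_backend: Backend = Backend.TORCH_SDPA,
                 use_upstream_fa: bool = False) -> None:
        self.attn_backend = attn_backend
        self.use_upstream_fa = use_upstream_fa
        self.is_flash_attn_backend = attn_backend in {
            Backend.FLASH_ATTN, Backend.ROCM_AITER_FA
        }

# The owner detects once (a stand-in for get_vit_attn_backend) and shares the
# result with every layer, instead of each layer repeating the same probe.
backend = Backend.TORCH_SDPA
layers = [ToyVisionAttention(attn_backend=backend) for _ in range(32)]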
@@ -443,6 +428,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -455,7 +442,9 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen2_5_VisionMLP(dim,
                                      mlp_hidden_dim,
                                      act_fn=act_fn,
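Qwen2_5_VisionBlock does not interpret the two new arguments; it only threads them through to its attention submodule, which also makes it possible to pin a backend per instance (for example in a unit test). A tiny self-contained illustration of the pass-through, with toy names:

from enum import Enum, auto

class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()

class ToyAttention:
    def __init__(self, attn_backend: Backend, use_upstream_fa: bool) -> None:
        self.attn_backend = attn_backend
        self.use_upstream_fa = use_upstream_fa

class ToyBlock:
    """Mirrors the block: accept the choice, hand it straight to the attention."""

    def __init__(self,
                 attn_backend: Backend = Backend.TORCH_SDPA,
                 use_upstream_fa: bool = False) -> None:
        self.attn = ToyAttention(attn_backend, use_upstream_fa)

block = ToyBlock(attn_backend=Backend.FLASH_ATTN, use_upstream_fa=True)
assert block.attn.attn_backend is Backend.FLASH_ATTN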
@@ -627,17 +616,35 @@ def __init__(
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
 
+        use_upstream_fa = False
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
+            )
+
         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(dim=self.hidden_size,
-                                num_heads=self.num_heads,
-                                mlp_hidden_dim=vision_config.intermediate_size,
-                                act_fn=get_act_and_mul_fn(
-                                    vision_config.hidden_act),
-                                norm_layer=norm_layer,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.blocks.{layer_idx}",
-                                use_data_parallel=use_data_parallel)
-            for layer_idx in range(depth)
+            Qwen2_5_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=get_act_and_mul_fn(vision_config.hidden_act),
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
@@ -648,12 +655,6 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype())
-        if self.attn_backend != _Backend.FLASH_ATTN and \
-            check_upstream_fa_availability(
-                torch.get_default_dtype()):
-            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
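The selection and validation that previously ran in every attention layer (and again, redundantly, after the blocks were built) now happens exactly once in the vision transformer constructor, before the blocks are created. A hedged, self-contained sketch of the selection order this hunk implements; the helper below is only a stand-in for vLLM's get_vit_attn_backend / check_upstream_fa_availability (the real availability check also considers dtype), and flash_attn refers to the upstream flash-attention package:

import importlib.util
from enum import Enum, auto

class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()
    ROCM_AITER_FA = auto()
    PALLAS = auto()  # example of a choice the model rejects

SUPPORTED = {Backend.FLASH_ATTN, Backend.TORCH_SDPA, Backend.XFORMERS,
             Backend.ROCM_AITER_FA}

def resolve_vit_backend(platform_choice: Backend) -> tuple[Backend, bool]:
    """Return (backend, use_upstream_fa), mirroring the one-time selection above."""
    backend, use_upstream_fa = platform_choice, False
    # Prefer flash-attention whenever the upstream flash_attn package is
    # importable, even if the platform default was something else.
    if backend != Backend.FLASH_ATTN and importlib.util.find_spec("flash_attn"):
        backend, use_upstream_fa = Backend.FLASH_ATTN, True
    if backend not in SUPPORTED:
        raise RuntimeError(f"Qwen2.5-VL does not support {backend} backend now.")
    return backend, use_upstream_fa

# The transformer resolves once, then passes the pair into every block it builds.
backend, use_upstream_fa = resolve_vit_backend(Backend.TORCH_SDPA)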