3030from vllm .sequence import IntermediateTensors
3131
3232# yapf: disable
33+ from .idefics2_vision_model import Idefics2VisionConfig
3334from .idefics2_vision_model import (
3435 Idefics2VisionTransformer as Idefics3VisionTransformer )
3536# yapf: enable
@@ -50,6 +51,50 @@ class AriaImagePixelInputs(TypedDict):
5051 """
5152
5253
class AriaVisionTransformer(Idefics3VisionTransformer):
    """Vision tower used by Aria.

    Identical to the Idefics3 vision transformer except that the final
    ``post_layernorm`` is replaced with an identity op, since Aria does
    not normalize the vision features before projection.
    """

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config, prefix)
        # Aria skips the final layernorm; keep the attribute so the
        # module structure stays compatible with the parent class.
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, fusing q/k/v shards into ``qkv_proj``.

        Returns the set of (post-rename) parameter names that were loaded.
        """
        # (fused param name, checkpoint shard name, shard id)
        qkv_shard_map = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        named_params = dict(self.named_parameters())
        loaded: Set[str] = set()

        for ckpt_name, tensor in weights:
            # NOTE: post_layernorm is not used in Aria (replaced by
            # nn.Identity above), so its checkpoint entries are dropped.
            if "post_layernorm" in ckpt_name:
                continue

            target_name = ckpt_name
            for fused_name, shard_name, shard_id in qkv_shard_map:
                if shard_name in target_name:
                    # Route this shard into the fused qkv parameter.
                    target_name = target_name.replace(shard_name, fused_name)
                    param = named_params[target_name]
                    param.weight_loader(param, tensor, shard_id)
                    break
            else:
                # Not a q/k/v shard: load directly, honoring a custom
                # weight_loader when the parameter defines one.
                param = named_params[target_name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, tensor)
            loaded.add(target_name)
        return loaded
96+
97+
5398class AriaProjectorMLP (nn .Module ):
5499
55100 def __init__ (
@@ -228,8 +273,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
228273 router_output = torch .nn .functional .linear (hidden_states ,
229274 self .router_weight )
230275
276+ hidden_states_copy = hidden_states .clone ()
277+ # NOTE: hidden_states will be modified inplace by `FusedMoE`
231278 sparse_expert_output = self .experts (hidden_states , router_output )
232- shared_expert_output = self .shared_experts (hidden_states )
279+ shared_expert_output = self .shared_experts (hidden_states_copy )
233280
234281 return sparse_expert_output + shared_expert_output
235282
@@ -445,7 +492,7 @@ def __init__(
445492 quant_config = vllm_config .quant_config
446493
447494 self .config = config
448- self .vision_tower = Idefics3VisionTransformer (
495+ self .vision_tower = AriaVisionTransformer (
449496 config .vision_config ,
450497 quant_config ,
451498 prefix = f"{ prefix } .vision_tower" ,
0 commit comments