 from vllm.sequence import IntermediateTensors
 
 # yapf: disable
+from .idefics2_vision_model import Idefics2VisionConfig
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
@@ -50,6 +51,53 @@ class AriaImagePixelInputs(TypedDict):
     """
 
 
+class AriaVisionTransformer(Idefics3VisionTransformer):
+
+    def __init__(
+        self,
+        config: Idefics2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, quant_config, prefix)
+        # Unlike Idefics3VisionTransformer which uses LayerNorm after the
+        # final layer, Aria omits this normalization, so we replace it with
+        # an Identity layer
+        self.post_layernorm = nn.Identity()
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+
+            # NOTE: post_layernorm is not used in Aria
+            if "post_layernorm" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
 class AriaProjectorMLP(nn.Module):
 
     def __init__(
@@ -228,8 +276,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_output = torch.nn.functional.linear(hidden_states,
                                                    self.router_weight)
 
+        hidden_states_copy = hidden_states.clone()
+        # NOTE: hidden_states will be modified in-place by `FusedMoE`
         sparse_expert_output = self.experts(hidden_states, router_output)
-        shared_expert_output = self.shared_experts(hidden_states)
+        shared_expert_output = self.shared_experts(hidden_states_copy)
 
         return sparse_expert_output + shared_expert_output
 
@@ -445,7 +495,7 @@ def __init__(
         quant_config = vllm_config.quant_config
 
         self.config = config
-        self.vision_tower = Idefics3VisionTransformer(
+        self.vision_tower = AriaVisionTransformer(
             config.vision_config,
             quant_config,
             prefix=f"{prefix}.vision_tower",
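For context on the load_weights hunk above: here is a minimal, self-contained sketch (not vLLM code; the remap helper is hypothetical) of what stacked_params_mapping does. Checkpoints store separate q_proj/k_proj/v_proj weights, but the model holds one fused qkv_proj parameter, so each checkpoint name is rewritten to the fused name and a shard id tells the weight loader which slice of the fused tensor to fill.

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def remap(name):
    # Mirror the rename-and-shard logic in load_weights above.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # no match: falls through to the default loader path

print(remap("encoder.layers.0.self_attn.k_proj.weight"))
# ('encoder.layers.0.self_attn.qkv_proj.weight', 'k')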
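Likewise, a minimal sketch of the aliasing bug the hidden_states.clone() hunk guards against (the mutating_moe function below is a hypothetical stand-in for FusedMoE): if the shared experts read the same buffer after the fused MoE kernel has mutated it in place, they operate on corrupted activations, so a copy must be taken first.

import torch

def mutating_moe(x):
    # Stand-in for FusedMoE: writes its result into the input buffer.
    return x.mul_(2)

x = torch.ones(4)
x_copy = x.clone()        # snapshot taken before the in-place op
sparse = mutating_moe(x)  # x is now modified in place
shared = x_copy * 3       # reads the original activations, not the mutated buffer
print(sparse + shared)    # tensor([5., 5., 5., 5.])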