 
 # A modified implementation of the AIMv2 Transformer
 # inserted here also the image tokenizer used by Ovis2
+from collections.abc import Iterable
 from typing import Optional
 
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
 
+from vllm.attention.layer import MultiHeadAttention
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.utils import divide
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.ovis import AIMv2Config
 
 
@@ -24,29 +31,27 @@ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
         in_features = config.hidden_size
         bias = config.use_bias
 
-        # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
-        self.fc1 = ReplicatedLinear(in_features,
-                                    hidden_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc1")
-        self.fc2 = ReplicatedLinear(hidden_features,
-                                    in_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc2")
-        self.fc3 = ReplicatedLinear(in_features,
-                                    hidden_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc3")
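+        # fc1 and fc3 are fused into a single column-parallel projection;
+        # SiluAndMul splits its output in half and computes
+        # silu(fc1(x)) * fc3(x), matching the unfused code above.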
+        self.fc13 = MergedColumnParallelLinear(
+            in_features,
+            [hidden_features] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc13",
+        )
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
+        self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_parallel, _ = self.fc1(x)
-        gate, _ = self.fc3(x)
-        x_parallel = F.silu(x_parallel) * gate
-        out, _ = self.fc2(x_parallel)
-        return out
+        x, _ = self.fc13(x)
+        x = self.act_fn(x)
+        x, _ = self.fc2(x)
+        return x
 
 
 class AIMv2PatchEmbed(nn.Module):
@@ -90,39 +95,45 @@ class AIMv2Attention(nn.Module):
     def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
                  prefix: str):
         super().__init__()
-        dim = config.hidden_size
-
-        # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
+        self.config = config
+        self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias)
-        # self.qkv = QKVParallelLinear(
-        #     hidden_size=dim,
-        #     head_size=dim // config.num_attention_heads,
-        #     total_num_heads=config.num_attention_heads,
-        #     bias=config.qkv_bias,
-        #     quant_config=quant_config,
-        #     prefix=f"{prefix}.qkv")
-        self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias)
-        # self.proj = RowParallelLinear(input_size=dim,
-        #                               output_size=dim,
-        #                               bias=config.use_bias,
-        #                               quant_config=quant_config,
-        #                               prefix=f"{prefix}.proj")
-
-    def forward(  # todo might implement multiple attn implementations
-            self,
-            x: torch.Tensor,
-            mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        B, N, C = x.shape
-        qkv, _ = self.qkv(x)
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                "embed_dim must be divisible by num_heads "
+                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads}).")
+        self.scale = self.head_dim**-0.5
+
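+        # Fused q/k/v projection, sharded column-wise across tensor-parallel
+        # ranks; the output projection below is the matching row-parallel part.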
+        self.qkv = QKVParallelLinear(
+            hidden_size=self.embed_dim,
+            head_size=self.head_dim,
+            total_num_heads=self.num_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+        )
+
+        self.proj = RowParallelLinear(
+            input_size=self.embed_dim,
+            output_size=self.embed_dim,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+        )
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
-        qkv = qkv.reshape(B, N, 3, self.num_heads,
-                          C // self.num_heads).permute(2, 0, 3, 1, 4)
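+        # Plain multi-head attention (no KV cache) over this rank's share of
+        # the heads.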
+        self.attn = MultiHeadAttention(self.num_heads_per_partition,
+                                       self.head_dim, self.scale)
 
-        q, k, v = qkv.unbind(0)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        qkv, _ = self.qkv(x)
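+        # The fused output lays out [q | k | v] along the last dim; q, k and v
+        # have equal widths here (no grouped KV heads), so an equal three-way
+        # chunk recovers them.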
+        q, k, v = qkv.chunk(3, dim=-1)
 
-        x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-        x = x.transpose(1, 2).contiguous().reshape(B, N, C)
+        x = self.attn(q, k, v)
         x, _ = self.proj(x)
         return x
 
@@ -141,37 +152,40 @@ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
                                   prefix=f"{prefix}.mlp")
         self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-    def forward(self,
-                x: torch.Tensor,
-                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm_1.forward_native(x), mask)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + self.attn(self.norm_1.forward_native(x))
         x = x + self.mlp(self.norm_2.forward_native(x))
         return x
 
 
 class AIMv2Transformer(nn.Module):
 
-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self,
+        config: AIMv2Config,
+        quant_config: QuantizationConfig,
+        *,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ):
         super().__init__()
 
         self.blocks = nn.ModuleList([
             AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
             for i in range(config.num_hidden_layers)
         ])
-        self.post_trunk_norm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
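+        # Build the trailing norm only when requested; the previous
+        # implementation never applied it (see the removed comment in forward).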
+        if require_post_norm:
+            self.post_trunk_norm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+        else:
+            self.post_trunk_norm = None
 
-    def forward(
-        self,
-        tokens: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         # they take the -1 as the ref embeddings, like a clip skip
         for block in self.blocks:
-            tokens = block(tokens, mask)
-        # NO NORM IN THE OG IMPLEMENTATION
-        # tokens = self.post_trunk_norm(tokens)
+            tokens = block(tokens)
+        if self.post_trunk_norm is not None:
+            tokens = self.post_trunk_norm(tokens)
         return tokens
 
 
@@ -180,20 +194,52 @@ class AIMv2Model(torch.nn.Module):
     def __init__(self,
                  config: AIMv2Config,
                  quant_config: QuantizationConfig,
+                 *,
+                 require_post_norm: Optional[bool] = None,
                  prefix: str = ""):
         super().__init__()
         self.preprocessor = AIMv2ViTPreprocessor(config)
         self.trunk = AIMv2Transformer(config,
                                       quant_config=quant_config,
+                                      require_post_norm=require_post_norm,
                                       prefix=f"{prefix}.trunk")
 
-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
 
         x = self.preprocessor(pixel_values)
-        x = self.trunk(x, mask)
+        x = self.trunk(x)
 
         return x
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
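+        # Checkpoints keep fc1/fc3 as separate tensors; they are mapped onto
+        # the two shards of the fused fc13 projection defined above.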
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".fc13", ".fc1", 0),
+            (".fc13", ".fc3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            # post_trunk_norm is optional, so skip its weights when the
+            # module was not built (require_post_norm=False)
+            if (name.startswith("trunk.post_trunk_norm")
+                    and self.trunk.post_trunk_norm is None):
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params