88
99import  torch 
1010import  torch .nn  as  nn 
11- from  transformers  import  PretrainedConfig 
1211
1312from  vllm .attention .layer  import  MultiHeadAttention 
1413from  vllm .distributed  import  get_tensor_model_parallel_world_size 
2120from  vllm .model_executor .layers .quantization .base_config  import  (
2221    QuantizationConfig )
2322from  vllm .model_executor .model_loader .weight_utils  import  default_weight_loader 
23+ from  vllm .transformers_utils .configs .ovis  import  AIMv2Config 
2424
2525
2626class  AIMv2SwiGLUFFN (nn .Module ):
2727
28-     def  __init__ (self , config : PretrainedConfig ,
29-                  quant_config :  QuantizationConfig ,  prefix : str ):
28+     def  __init__ (self , config : AIMv2Config ,  quant_config :  QuantizationConfig ,
29+                  prefix : str ):
3030        super ().__init__ ()
3131        hidden_features  =  config .intermediate_size 
3232        in_features  =  config .hidden_size 
@@ -57,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
5757
5858class  AIMv2PatchEmbed (nn .Module ):
5959
60-     def  __init__ (self , config : PretrainedConfig ):
60+     def  __init__ (self , config : AIMv2Config ):
6161        super ().__init__ ()
6262        self .proj  =  nn .Conv2d (
6363            config .num_channels ,
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7575
7676class  AIMv2ViTPreprocessor (nn .Module ):
7777
78-     def  __init__ (self , config : PretrainedConfig ):
78+     def  __init__ (self , config : AIMv2Config ):
7979        super ().__init__ ()
8080        num_patches  =  (config .image_size  //  config .patch_size )** 2 
8181
@@ -93,8 +93,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
9393
9494class  AIMv2Attention (nn .Module ):
9595
96-     def  __init__ (self , config : PretrainedConfig ,
97-                  quant_config :  QuantizationConfig ,  prefix : str ):
96+     def  __init__ (self , config : AIMv2Config ,  quant_config :  QuantizationConfig ,
97+                  prefix : str ):
9898        super ().__init__ ()
9999        self .config  =  config 
100100        self .embed_dim  =  config .hidden_size 
@@ -141,8 +141,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
141141
142142class  AIMv2Block (nn .Module ):
143143
144-     def  __init__ (self , config : PretrainedConfig ,
145-                  quant_config :  QuantizationConfig ,  prefix : str ):
144+     def  __init__ (self , config : AIMv2Config ,  quant_config :  QuantizationConfig ,
145+                  prefix : str ):
146146        super ().__init__ ()
147147        self .attn  =  AIMv2Attention (config ,
148148                                   quant_config = quant_config ,
@@ -163,7 +163,7 @@ class AIMv2Transformer(nn.Module):
163163
164164    def  __init__ (
165165        self ,
166-         config : PretrainedConfig ,
166+         config : AIMv2Config ,
167167        quant_config : QuantizationConfig ,
168168        * ,
169169        require_post_norm : Optional [bool ] =  None ,
@@ -193,7 +193,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
193193class  AIMv2Model (torch .nn .Module ):
194194
195195    def  __init__ (self ,
196-                  config : PretrainedConfig ,
196+                  config : AIMv2Config ,
197197                 quant_config : QuantizationConfig ,
198198                 * ,
199199                 require_post_norm : Optional [bool ] =  None ,
0 commit comments