@@ -8,7 +8,6 @@
 
 import torch
 import torch.nn as nn
-from transformers import PretrainedConfig
 
 from vllm.attention.layer import MultiHeadAttention
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -21,12 +20,13 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.transformers_utils.configs.ovis import AIMv2Config
 
 
 class AIMv2SwiGLUFFN(nn.Module):
 
-    def __init__(self, config: PretrainedConfig,
-                 quant_config: QuantizationConfig, prefix: str):
+    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
+                 prefix: str):
         super().__init__()
         hidden_features = config.intermediate_size
         in_features = config.hidden_size
@@ -57,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class AIMv2PatchEmbed(nn.Module):
 
-    def __init__(self, config: PretrainedConfig):
+    def __init__(self, config: AIMv2Config):
         super().__init__()
         self.proj = nn.Conv2d(
             config.num_channels,
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class AIMv2ViTPreprocessor(nn.Module):
 
-    def __init__(self, config: PretrainedConfig):
+    def __init__(self, config: AIMv2Config):
         super().__init__()
         num_patches = (config.image_size // config.patch_size)**2
 
@@ -93,8 +93,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class AIMv2Attention(nn.Module):
 
-    def __init__(self, config: PretrainedConfig,
-                 quant_config: QuantizationConfig, prefix: str):
+    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
+                 prefix: str):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -141,8 +141,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class AIMv2Block(nn.Module):
 
-    def __init__(self, config: PretrainedConfig,
-                 quant_config: QuantizationConfig, prefix: str):
+    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
+                 prefix: str):
         super().__init__()
         self.attn = AIMv2Attention(config,
                                    quant_config=quant_config,
@@ -163,7 +163,7 @@ class AIMv2Transformer(nn.Module):
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: AIMv2Config,
         quant_config: QuantizationConfig,
         *,
         require_post_norm: Optional[bool] = None,
@@ -193,7 +193,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
 class AIMv2Model(torch.nn.Module):
 
     def __init__(self,
-                 config: PretrainedConfig,
+                 config: AIMv2Config,
                  quant_config: QuantizationConfig,
                  *,
                  require_post_norm: Optional[bool] = None,
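
For reference, a minimal sketch of what the typed config provides is below. The actual `AIMv2Config` lives in `vllm/transformers_utils/configs/ovis.py`, and the field defaults here are illustrative placeholders, not the model's real hyperparameters. The point of the swap is that every attribute the vision modules read above (`hidden_size`, `intermediate_size`, `num_channels`, `image_size`, `patch_size`) becomes a declared field of a concrete config class rather than a dynamic attribute on a generic `PretrainedConfig`, so type checkers can verify those accesses.

```python
# Illustrative sketch only -- the real class is defined in
# vllm/transformers_utils/configs/ovis.py and may carry more fields.
from transformers import PretrainedConfig


class AIMv2Config(PretrainedConfig):
    """Typed config exposing the attributes read by the AIMv2 vision tower."""

    model_type = "aimv2"

    def __init__(
        self,
        hidden_size: int = 1024,        # read by AIMv2SwiGLUFFN / AIMv2Attention
        intermediate_size: int = 2816,  # read by AIMv2SwiGLUFFN
        num_channels: int = 3,          # read by AIMv2PatchEmbed
        image_size: int = 224,          # read by AIMv2ViTPreprocessor
        patch_size: int = 14,           # read by AIMv2PatchEmbed / preprocessor
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
```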