2121from ..model_config import ModelConfig
2222from .modeling_auto import AutoModelForCausalLM
2323from .modeling_multimodal_utils import fuse_input_embeds
24+ from .modeling_radio import RADIOVisionModel
2425from .modeling_utils import register_auto_model
2526
2627
@@ -66,19 +67,12 @@ def __init__(self,
6667 self .num_image_token = int ((self .image_size // self .patch_size )** 2 *
6768 (config .downsample_ratio ** 2 ))
6869 self .downsample_ratio = config .downsample_ratio
69- self .ps_version = config .ps_version
70- # self.image_tag_type = config.image_tag_type
71-
72- logger .info (f'num_image_token: { self .num_image_token } ' )
73- logger .info (f'ps_version: { self .ps_version } ' )
74-
75- # self.drop_vision_class_token = True
70+ self .ps_version = config .ps_version # Pixel shuffle version.
7671
7772 # Construct the vision projection.
7873 self .vit_hidden_size = config .vit_hidden_size
7974 self .vision_projection_hidden_size = config .projector_hidden_size
8075 self .llm_hidden_size = config .llm_config .hidden_size
81-
8276 self .mlp1 = nn .Sequential (
8377 RMSNorm (self .vit_hidden_size * int (1 / self .downsample_ratio )** 2 ,
8478 eps = 1e-5 ),
@@ -90,37 +84,19 @@ def __init__(self,
9084 bias = False ))
9185 self .mlp1 = self .mlp1 .to (config .torch_dtype )
9286
93- WITH_HF_CODES = False
94- if WITH_HF_CODES :
87+ # Construct the vision encoder.
88+ self .with_hf_codes = os .getenv ("WITH_HF_CODES" , "0" ) == "1"
89+ if self .with_hf_codes :
9590 self .vision_model = transformers .AutoModel .from_config (
9691 config .vision_config , trust_remote_code = True )
9792 # set input_condition as Identity module.
9893 self .vision_model .radio_model .make_preprocessor_external ()
9994 self .vision_model .to (config .torch_dtype )
100-
101- with open ("hf_vision_encoder_arch.txt" , "w" ) as f :
102- f .write (str (self .vision_model ))
10395 else :
104- WITH_TRTLLM_CODES = True
105- if WITH_TRTLLM_CODES :
106- from .modeling_radio import RADIOVisionModel
107-
108- vision_model_config = copy .deepcopy (model_config )
109- vision_model_config .pretrained_config = vision_model_config .pretrained_config .vision_config
110-
111- self .vision_model = RADIOVisionModel (vision_model_config )
112- self .vision_model .to (config .torch_dtype )
113-
114- with open ("trtllm_vision_encoder_arch.txt" , "w" ) as f :
115- f .write (str (self .vision_model ))
116- else :
117- # Update the vision model with customized one.
118- from .modeling_radio import RADIOModel
119- self .vision_model = RADIOModel (config .vision_config )
120- self .vision_model .to (config .torch_dtype )
121-
122- with open ("user_vision_encoder_arch.txt" , "w" ) as f :
123- f .write (str (self .vision_model ))
96+ vision_model_config = copy .deepcopy (model_config )
97+ vision_model_config .pretrained_config = vision_model_config .pretrained_config .vision_config
98+ self .vision_model = RADIOVisionModel (vision_model_config )
99+ self .vision_model .to (config .torch_dtype )
124100
125101 @torch .compile
126102 def pixel_shuffle (self , x , scale_factor = 0.5 ):
@@ -141,8 +117,12 @@ def pixel_shuffle(self, x, scale_factor=0.5):
141117 return x
142118
143119 def extract_feature (self , pixel_values ):
144- vit_embeds = self .vision_model (pixel_values ).features
120+ if self .with_hf_codes :
121+ vit_embeds = self .vision_model (pixel_values ).features
122+ else :
123+ vit_embeds = self .vision_model (pixel_values )
145124 vit_embeds = vit_embeds .to (dtype = torch .bfloat16 )
125+ # Down-sampling and projection.
146126 h = w = int (vit_embeds .shape [1 ]** 0.5 )
147127 vit_embeds = vit_embeds .reshape (vit_embeds .shape [0 ], h , w , - 1 )
148128 vit_embeds = self .pixel_shuffle (vit_embeds ,
@@ -317,7 +297,11 @@ def load_weights(self, weights):
317297 }
318298 missing_keys , unexpected_keys = self .vision_encoder .load_state_dict (
319299 filter_weights , strict = False )
320- missing_keys .remove ("vision_model.radio_model.summary_idxs" )
300+ try :
301+ missing_keys .remove ("vision_model.radio_model.summary_idxs" )
302+ except ValueError :
303+ pass
304+
321305 unexpected_keys .remove (
322306 "vision_model.radio_model.input_conditioner.norm_mean" )
323307 unexpected_keys .remove (
0 commit comments