@@ -29,6 +29,9 @@ def _is_disagg() -> bool:
     return os.getenv("TLLM_MULTIMODAL_DISAGGREGATED", "0") == "1"
 
 
+IMAGE_TOKEN_ID = 131072  # token id of the "<image>" placeholder
+
+
 class SquaredReLU(nn.Module):
 
     def forward(self, x):
@@ -77,25 +80,22 @@ def __init__(self,
         self.llm_hidden_size = config.llm_config.hidden_size
 
         self.mlp1 = nn.Sequential(
-            # nn.LayerNorm(self.vit_hidden_size *
-            #              int(1 / self.downsample_ratio)**2,
-            #              bias=False),
             RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                     eps=1e-5),
             nn.Linear(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                       self.vision_projection_hidden_size,
-                      bias=False),
-            SquaredReLU(),
+                      bias=False), SquaredReLU(),
             nn.Linear(self.vision_projection_hidden_size,
                       self.llm_hidden_size,
                       bias=False))
         self.mlp1 = self.mlp1.to(config.torch_dtype)
 
-        # self.img_context_token_id = None
         WITH_HF_CODES = False
         if WITH_HF_CODES:
             self.vision_model = transformers.AutoModel.from_config(
                 config.vision_config, trust_remote_code=True)
+            # Set the input_conditioner to an Identity module.
+            self.vision_model.radio_model.make_preprocessor_external()
             self.vision_model.to(config.torch_dtype)
 
             with open("hf_vision_encoder_arch.txt", "w") as f:
@@ -113,7 +113,6 @@ def __init__(self,
 
             with open("trtllm_vision_encoder_arch.txt", "w") as f:
                 f.write(str(self.vision_model))
-
         else:
             # Update the vision model with customized one.
             from .modeling_radio import RADIOModel
@@ -218,6 +217,7 @@ def __init__(self,
         self.img_context_token = "<image>"
         self.img_start_token = "<img>"
         self.img_end_token = "</img>"
+        self.dtype = model_config.torch_dtype
 
     @torch.inference_mode()
     def __call__(
@@ -258,7 +258,8 @@ def __call__(
 
         # Will package inputs for language model forward in AGGREGATE mode.
         multimodal_data = {}
-        multimodal_data['pixel_values'] = processed_images['pixel_values']
+        multimodal_data['pixel_values'] = processed_images['pixel_values'].to(
+            self.dtype)
         multimodal_data['num_patches'] = processed_images['num_patches']
         return input_ids[0].to(torch.int32).tolist(), {
             "multimodal_data": multimodal_data,
@@ -271,7 +272,7 @@ def __call__(
     model_type="NemotronH_Nano_VL_V2",
     placeholder_metadata=MultimodalPlaceholderMetadata(
         placeholder_map={
-            "image": "<image>",
+            "image": "<image>\n",
         },
         placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
         placeholders_separator="",
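
Note: with BEFORE_TEXT placement and an empty separator, each image now contributes "<image>\n" ahead of the user text. A toy sketch of the resulting prompt layout (the helper name is made up for this note):

def build_prompt(num_images: int, text: str) -> str:
    # One "<image>\n" placeholder per image, all placed before the text.
    return "<image>\n" * num_images + text

# build_prompt(2, "Describe both images.")
# -> '<image>\n<image>\nDescribe both images.'
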
@@ -321,38 +322,44 @@ def load_weights(self, weights):
             ) and m != "vision_model.radio_model.summary_idxs":
                 raise ValueError(f"Missing key: {m}")
         for u in unexpected_keys:
-            if not u.startswith('vision_model.radio_model.model.blocks.'):
+            if not u.startswith(
+                    'vision_model.radio_model.model.blocks.') and u not in [
+                        "vision_model.radio_model.input_conditioner.norm_mean",
+                        "vision_model.radio_model.input_conditioner.norm_std",
+                    ]:
                 raise ValueError(f"Unexpected key: {u}")
 
-        # Load weights for vision transformer module.
-        model_weights = {
-            k.replace('vision_model.radio_model.model.', ''): v
-            for k, v in weights.items()
-            if k.startswith('vision_model.radio_model.model.')
-        }
-        converted_weights = dict()
-        for name in model_weights:
-            # Handle with weights and bias for vision transformer's qkv projection.
-            if "attn.qkv." in name:
-                q_name = name.replace("attn.qkv.", "attn.q_proj.")
-                k_name = name.replace("attn.qkv.", "attn.k_proj.")
-                v_name = name.replace("attn.qkv.", "attn.v_proj.")
-                dim_shape = model_weights[name].shape[0] // 3
-                converted_weights[q_name] = model_weights[name][:dim_shape]
-                converted_weights[k_name] = model_weights[name][dim_shape:2 *
-                                                                dim_shape]
-                converted_weights[v_name] = model_weights[name][2 * dim_shape:]
-            else:
-                converted_weights[name] = model_weights[name]
-        pattern_mapping = {
-            r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
-            r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
-            r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
-        }
-        modeling_utils._load_weights_impl(
-            self.vision_encoder.vision_model.radio_model.model,
-            converted_weights,
-            params_map=pattern_mapping)
+        if len(unexpected_keys) > 0 or len(missing_keys) > 1:  # checkpoint needs conversion
+            # Load weights for vision transformer module.
+            model_weights = {
+                k.replace('vision_model.radio_model.model.', ''): v
+                for k, v in weights.items()
+                if k.startswith('vision_model.radio_model.model.')
+            }
+            converted_weights = dict()
+            for name in model_weights:
+                # Split the fused weight/bias of the vision transformer's qkv projection.
+                if "attn.qkv." in name:
+                    q_name = name.replace("attn.qkv.", "attn.q_proj.")
+                    k_name = name.replace("attn.qkv.", "attn.k_proj.")
+                    v_name = name.replace("attn.qkv.", "attn.v_proj.")
+                    dim_shape = model_weights[name].shape[0] // 3
+                    converted_weights[q_name] = model_weights[name][:dim_shape]
+                    converted_weights[k_name] = model_weights[name][
+                        dim_shape:2 * dim_shape]
+                    converted_weights[v_name] = model_weights[name][2 *
+                                                                    dim_shape:]
+                else:
+                    converted_weights[name] = model_weights[name]
+            pattern_mapping = {
+                r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
+                r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
+                r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
+            }
+            modeling_utils._load_weights_impl(
+                self.vision_encoder.vision_model.radio_model.model,
+                converted_weights,
+                params_map=pattern_mapping)
 
         # Load language model weights.
         filtered_weights = {
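
Note: the guarded block above splits each fused attn.qkv tensor into separate q/k/v projections along dim 0 before remapping names. A self-contained sketch of that conversion (toy shapes, illustrative helper name):

import torch

def split_fused_qkv(name: str, tensor: torch.Tensor) -> dict:
    # Fused checkpoints stack q, k and v along dim 0, so each part is a third.
    if "attn.qkv." not in name:
        return {name: tensor}
    dim = tensor.shape[0] // 3
    return {
        name.replace("attn.qkv.", "attn.q_proj."): tensor[:dim],
        name.replace("attn.qkv.", "attn.k_proj."): tensor[dim:2 * dim],
        name.replace("attn.qkv.", "attn.v_proj."): tensor[2 * dim:],
    }

# A [3 * 4, 4] fused weight yields three [4, 4] projections.
parts = split_fused_qkv("blocks.0.attn.qkv.weight", torch.randn(12, 4))
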
@@ -405,11 +412,8 @@ def forward(
             self.llm.model.embed_tokens,
             input_ids,
             mm_embedding,
-            mm_token_ids=torch.tensor([
-                131072
-            ], dtype=torch.int32),  # 131072 is the token id for the image token
+            mm_token_ids=torch.tensor([IMAGE_TOKEN_ID], dtype=torch.int32),
         )
-
         output_prob = self.llm.forward(
             attn_metadata=attn_metadata,
             input_ids=input_ids,
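
Note: mm_token_ids tells the embedding fusion which positions carry image tokens. A minimal sketch of the idea, not the actual fuse_input_embeds implementation (helper name and shapes are illustrative):

import torch

IMAGE_TOKEN_ID = 131072

def fuse_embeds_sketch(input_ids: torch.Tensor, text_embeds: torch.Tensor,
                       mm_embeds: torch.Tensor) -> torch.Tensor:
    # Replace the text embedding at every <image> slot with a vision embedding.
    mask = input_ids == IMAGE_TOKEN_ID
    assert int(mask.sum()) == mm_embeds.shape[0], "one vision row per <image> slot"
    fused = text_embeds.clone()
    fused[mask] = mm_embeds.to(fused.dtype)
    return fused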