 import transformers
 from PIL import Image
 
+from tensorrt_llm._torch.models import modeling_utils
 from tensorrt_llm._torch.models.checkpoints import NemotronHHfWeightMapper
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
@@ -34,10 +35,27 @@ def forward(self, x):
         return torch.pow(torch.nn.functional.relu(x), 2)
 
 
+class RMSNorm(nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-5):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        # Normalize in float32 for numerical stability, then cast back to the input dtype.
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
+
+
 class NanoV2VLVisionEncoder(transformers.PreTrainedModel,
                             transformers.generation.GenerationMixin):
 
-    def __init__(self, config: transformers.PretrainedConfig):
+    def __init__(self,
+                 model_config: ModelConfig[transformers.PretrainedConfig]):
+        config = model_config.pretrained_config
         super().__init__(config)
         self.image_size = config.force_image_size
         self.patch_size = config.patch_size
@@ -59,12 +77,15 @@ def __init__(self, config: transformers.PretrainedConfig):
         self.llm_hidden_size = config.llm_config.hidden_size
 
         self.mlp1 = nn.Sequential(
-            nn.LayerNorm(self.vit_hidden_size *
-                         int(1 / self.downsample_ratio)**2,
-                         bias=False),
+            # nn.LayerNorm(self.vit_hidden_size *
+            #              int(1 / self.downsample_ratio)**2,
+            #              bias=False),
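+            # The original LayerNorm is kept above (commented out) for reference; RMSNorm is used in its place.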
+            RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
+                    eps=1e-5),
             nn.Linear(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                       self.vision_projection_hidden_size,
-                      bias=False), SquaredReLU(),
+                      bias=False),
+            SquaredReLU(),
             nn.Linear(self.vision_projection_hidden_size,
                       self.llm_hidden_size,
                       bias=False))
@@ -80,13 +101,27 @@ def __init__(self, config: transformers.PretrainedConfig):
             with open("hf_vision_encoder_arch.txt", "w") as f:
                 f.write(str(self.vision_model))
         else:
-            # Update the vision model with customized one.
-            from .modeling_radio import RADIOModel
-            self.vision_model = RADIOModel(config.vision_config)
-            self.vision_model.to(config.torch_dtype)
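+            # WITH_TRTLLM_CODES selects the TRT-LLM RADIOVisionModel; set it to False to fall back to the original RADIOModel.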
+            WITH_TRTLLM_CODES = True
+            if WITH_TRTLLM_CODES:
+                from .modeling_radio import RADIOVisionModel
 
-            with open("user_vision_encoder_arch.txt", "w") as f:
-                f.write(str(self.vision_model))
+                vision_model_config = copy.deepcopy(model_config)
+                vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
+
+                self.vision_model = RADIOVisionModel(vision_model_config)
+                self.vision_model.to(config.torch_dtype)
+
+                with open("trtllm_vision_encoder_arch.txt", "w") as f:
+                    f.write(str(self.vision_model))
+
+            else:
+                # Update the vision model with customized one.
+                from .modeling_radio import RADIOModel
+                self.vision_model = RADIOModel(config.vision_config)
+                self.vision_model.to(config.torch_dtype)
+
+                with open("user_vision_encoder_arch.txt", "w") as f:
+                    f.write(str(self.vision_model))
 
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
@@ -258,7 +293,7 @@ def __init__(self, model_config: ModelConfig):
             return
 
         if not _is_disagg():
-            self.vision_encoder = NanoV2VLVisionEncoder(config).eval()
+            self.vision_encoder = NanoV2VLVisionEncoder(model_config).eval()
             self.vision_encoder.to(config.torch_dtype)
 
         llm_model_config = copy.deepcopy(model_config)
@@ -272,19 +307,53 @@ def __init__(self, model_config: ModelConfig):
         self.is_loaded = True
 
     def load_weights(self, weights):
-        # Load vision encoder weights.
+        # Load vision encoder weights for the PyTorch modules.
         filter_weights = {
             k: v
             for k, v in weights.items()
             if k.startswith('vision') or k.startswith('mlp1')
         }
         missing_keys, unexpected_keys = self.vision_encoder.load_state_dict(
             filter_weights, strict=False)
-        if len(unexpected_keys) > 0:
-            raise ValueError(f"Unexpected keys: {unexpected_keys}")
-        if len(missing_keys) > 1 and missing_keys[
-                0] != 'vision_model.radio_model.summary_idxs':
-            raise ValueError(f"Missing keys: {missing_keys}")
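+        # Keys under 'vision_model.radio_model.model.blocks.' are tolerated here;
+        # those weights are loaded separately into the TRT-LLM vision transformer below.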
+        for m in missing_keys:
+            if not m.startswith(
+                    'vision_model.radio_model.model.blocks.'
+            ) and m != "vision_model.radio_model.summary_idxs":
+                raise ValueError(f"Missing key: {m}")
+        for u in unexpected_keys:
+            if not u.startswith('vision_model.radio_model.model.blocks.'):
+                raise ValueError(f"Unexpected key: {u}")
+
+        # Load weights for the vision transformer module.
+        model_weights = {
+            k.replace('vision_model.radio_model.model.', ''): v
+            for k, v in weights.items()
+            if k.startswith('vision_model.radio_model.model.')
+        }
+        converted_weights = dict()
+        for name in model_weights:
+            # Handle the weights and bias of the vision transformer's fused qkv projection.
+            if "attn.qkv." in name:
+                q_name = name.replace("attn.qkv.", "attn.q_proj.")
+                k_name = name.replace("attn.qkv.", "attn.k_proj.")
+                v_name = name.replace("attn.qkv.", "attn.v_proj.")
+                dim_shape = model_weights[name].shape[0] // 3
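+                # The fused qkv tensor stacks q, k and v along dim 0 in equal thirds.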
+                converted_weights[q_name] = model_weights[name][:dim_shape]
+                converted_weights[k_name] = model_weights[name][dim_shape:2 *
+                                                                dim_shape]
+                converted_weights[v_name] = model_weights[name][2 * dim_shape:]
+            else:
+                converted_weights[name] = model_weights[name]
+        pattern_mapping = {
+            r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
+            r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
+            r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
+        }
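+        # Copy the converted weights into the TRT-LLM ViT, renaming attention/MLP parameters via the regex map above.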
+        modeling_utils._load_weights_impl(
+            self.vision_encoder.vision_model.radio_model.model,
+            converted_weights,
+            params_map=pattern_mapping)
+
         # Load language model weights.
         filtered_weights = {
             k.replace('language_model.', ''): v