 import transformers
 from PIL import Image
 
-from tensorrt_llm._torch.models import modeling_utils
 from tensorrt_llm._torch.models.checkpoints import NemotronHHfWeightMapper
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
@@ -30,6 +29,7 @@ def _is_disagg() -> bool:
     return os.getenv("TLLM_MULTIMODAL_DISAGGREGATED", "0") == "1"
 
 
+# TODO: update the reference config path once Nano v2 VLM is released.
 IMAGE_TOKEN_ID = 131072
 
 
@@ -63,7 +63,6 @@ def __init__(self,
         super().__init__(config)
         self.image_size = config.force_image_size
         self.patch_size = config.patch_size
-        # self.template = config.template
         self.num_image_token = int((self.image_size // self.patch_size)**2 *
                                    (config.downsample_ratio**2))
         self.downsample_ratio = config.downsample_ratio
@@ -85,18 +84,25 @@ def __init__(self,
         self.mlp1 = self.mlp1.to(config.torch_dtype)
 
         # Construct the vision encoder.
-        self.with_hf_codes = os.getenv("WITH_HF_CODES", "0") == "1"
-        if self.with_hf_codes:
-            self.vision_model = transformers.AutoModel.from_config(
-                config.vision_config, trust_remote_code=True)
-            # set input_condition as Identity module.
-            self.vision_model.radio_model.make_preprocessor_external()
-            self.vision_model.to(config.torch_dtype)
-        else:
-            vision_model_config = copy.deepcopy(model_config)
-            vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
-            self.vision_model = RADIOVisionModel(vision_model_config)
-            self.vision_model.to(config.torch_dtype)
+        vision_model_config = copy.deepcopy(model_config)
+        vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
+        self.vision_model = RADIOVisionModel(vision_model_config)
+        self.vision_model.to(config.torch_dtype)
+
+    def load_weights(self, weights):
+        # Load mlp1 weights.
+        mlp1_weights = {
+            k.replace('mlp1.', ''): v
+            for k, v in weights.items() if k.startswith('mlp1.')
+        }
+        self.mlp1.load_state_dict(mlp1_weights, strict=True)
+
+        # Load vision encoder weights.
+        vision_encoder_weights = {
+            k.replace('vision_model.', ''): v
+            for k, v in weights.items() if k.startswith('vision_model.')
+        }
+        self.vision_model.load_weights(vision_encoder_weights)
 
     @torch.compile
     def pixel_shuffle(self, x, scale_factor=0.5):
@@ -117,10 +123,7 @@ def pixel_shuffle(self, x, scale_factor=0.5):
         return x
 
     def extract_feature(self, pixel_values):
-        if self.with_hf_codes:
-            vit_embeds = self.vision_model(pixel_values).features
-        else:
-            vit_embeds = self.vision_model(pixel_values)
+        vit_embeds = self.vision_model(pixel_values)
         vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
         # Down-sampling and projection.
         h = w = int(vit_embeds.shape[1]**0.5)
@@ -134,40 +137,28 @@ def extract_feature(self, pixel_values):
 
     def forward(self, multimodal_params: List[MultimodalParams]):
         mm_embedding = []
-
-        BATCH_INFERENCE = True
-        if BATCH_INFERENCE:
-            # Batch data.
-            batched_pixel_values = torch.cat([
-                multimodal_param.multimodal_data["pixel_values"]
-                for multimodal_param in multimodal_params
-            ],
-                                             dim=0)
-            # -> [num_patches, channel, height, width]
-            batched_num_patches = torch.cat([
-                multimodal_param.multimodal_data["num_patches"]
-                for multimodal_param in multimodal_params
-            ],
-                                            dim=0).tolist()
-            # -> list of[num_patches1, num_patches2, ...]
-            batched_image_embeds = self.extract_feature(batched_pixel_values)
-            # -> [num_patches, num_image_token, hidden_size]
-            mm_embedding = torch.split(batched_image_embeds,
-                                       batched_num_patches,
-                                       dim=0)
-            mm_embedding = [
-                m.reshape(-1, self.llm_hidden_size) for m in mm_embedding
-            ]
-            # -> list of [num_patches*num_image_token, hidden_size]
-        else:
-            # Inference per sample.
-            for multimodal_param in multimodal_params:
-                pixel_values = multimodal_param.multimodal_data["pixel_values"]
-                image_embeds = self.extract_feature(pixel_values)
-                # -> [num_patches, num_image_token, hidden_size]
-                image_embeds = image_embeds.reshape(-1, self.llm_hidden_size)
-                # -> [num_patches*num_image_token, hidden_size]
-                mm_embedding.append(image_embeds)
+        # Batch data.
+        batched_pixel_values = torch.cat([
+            multimodal_param.multimodal_data["pixel_values"]
+            for multimodal_param in multimodal_params
+        ],
+                                         dim=0)
+        # -> [num_patches, channel, height, width]
+        batched_num_patches = torch.cat([
+            multimodal_param.multimodal_data["num_patches"]
+            for multimodal_param in multimodal_params
+        ],
+                                        dim=0).tolist()
+        # -> list of [num_patches1, num_patches2, ...]
+        batched_image_embeds = self.extract_feature(batched_pixel_values)
+        # -> [num_patches, num_image_token, hidden_size]
+        mm_embedding = torch.split(batched_image_embeds,
+                                   batched_num_patches,
+                                   dim=0)
+        mm_embedding = [
+            m.reshape(-1, self.llm_hidden_size) for m in mm_embedding
+        ]
+        # -> list of [num_patches*num_image_token, hidden_size]
         return mm_embedding
 
 
@@ -361,63 +352,8 @@ def __init__(self, model_config: ModelConfig):
         self.is_loaded = True
 
     def load_weights(self, weights):
-        # TODO: move vision encoder weights loading to vision encoder class.
-
-        # Load vision encoder weights for pytorch modules.
-        filter_weights = {
-            k: v
-            for k, v in weights.items()
-            if k.startswith('vision') or k.startswith('mlp1')
-        }
-        missing_keys, unexpected_keys = self.vision_encoder.load_state_dict(
-            filter_weights, strict=False)
-        try:
-            missing_keys.remove("vision_model.radio_model.summary_idxs")
-        except ValueError:
-            pass
-
-        unexpected_keys.remove(
-            "vision_model.radio_model.input_conditioner.norm_mean")
-        unexpected_keys.remove(
-            "vision_model.radio_model.input_conditioner.norm_std")
-        for m in missing_keys:
-            if not m.startswith('vision_model.radio_model.model.blocks.'):
-                raise ValueError(f"Missing key: {m}")
-        for u in unexpected_keys:
-            if not u.startswith('vision_model.radio_model.model.blocks.'):
-                raise ValueError(f"Unexpected key: {u}")
-
-        if len(unexpected_keys) > 0 or len(missing_keys) > 0:
-            # Load weights for vision transformer module.
-            model_weights = {
-                k.replace('vision_model.radio_model.model.', ''): v
-                for k, v in weights.items()
-                if k.startswith('vision_model.radio_model.model.')
-            }
-            converted_weights = dict()
-            for name in model_weights:
-                # Handle with weights and bias for vision transformer's qkv projection.
-                if "attn.qkv." in name:
-                    q_name = name.replace("attn.qkv.", "attn.q_proj.")
-                    k_name = name.replace("attn.qkv.", "attn.k_proj.")
-                    v_name = name.replace("attn.qkv.", "attn.v_proj.")
-                    dim_shape = model_weights[name].shape[0] // 3
-                    converted_weights[q_name] = model_weights[name][:dim_shape]
-                    converted_weights[k_name] = model_weights[name][
-                        dim_shape:2 * dim_shape]
-                    converted_weights[v_name] = model_weights[name][2 *
-                                                                    dim_shape:]
-                else:
-                    converted_weights[name] = model_weights[name]
-            pattern_mapping = {
-                r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
-                r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
-                r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
-            }
-            modeling_utils._load_weights_impl(
-                self.vision_encoder.vision_model.radio_model.model,
-                converted_weights,
-                params_map=pattern_mapping)
+        # Load vision encoder weights.
+        self.vision_encoder.load_weights(weights)
 
         # Load language model weights.
         filtered_weights = {
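
Note on the refactor: the new load_weights methods partition a flat checkpoint dict by key prefix and delegate each slice to the sub-module that owns it. Below is a minimal, self-contained sketch of that pattern; TinyVision and TinyEncoder are hypothetical stand-ins for illustration only, and the real RADIOVisionModel.load_weights does more than a plain load_state_dict.

# Minimal sketch of the prefix-split weight loading used above.
# TinyVision / TinyEncoder are hypothetical stand-ins, not TensorRT-LLM classes.
import torch
import torch.nn as nn


class TinyVision(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_weights(self, weights):
        # Receives keys with the 'vision_model.' prefix already stripped.
        self.load_state_dict(weights, strict=True)


class TinyEncoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.mlp1 = nn.Sequential(nn.LayerNorm(4), nn.Linear(4, 4))
        self.vision_model = TinyVision()

    def load_weights(self, weights):
        # Split the flat checkpoint dict by prefix and delegate per sub-module.
        mlp1_weights = {
            k.replace('mlp1.', ''): v
            for k, v in weights.items() if k.startswith('mlp1.')
        }
        self.mlp1.load_state_dict(mlp1_weights, strict=True)

        vision_weights = {
            k.replace('vision_model.', ''): v
            for k, v in weights.items() if k.startswith('vision_model.')
        }
        self.vision_model.load_weights(vision_weights)


if __name__ == "__main__":
    src, dst = TinyEncoder(), TinyEncoder()
    # state_dict() is flat and keyed 'mlp1.*' / 'vision_model.*', like a checkpoint.
    dst.load_weights(src.state_dict())
    assert torch.equal(dst.mlp1[1].weight, src.mlp1[1].weight)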