
Commit a3eb066

[TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend
* Clean up code
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 7ea064f commit a3eb066

File tree: 4 files changed (+183, -174 lines)

.github/CODEOWNERS

Lines changed: 2 additions & 0 deletions
@@ -99,6 +99,8 @@
 /tests/unittest/_torch/modeling/test_modeling_pixtral.py @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs
 
 ### TensorRT-LLM Pytorch - Models - Nemotron
+/tensorrt_llm/_torch/models/modeling_nanov2vlm.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_radio.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
 /tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
 /tensorrt_llm/_torch/models/modeling_nemotron_h.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
 /tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -66,5 +66,6 @@ blake3
 soundfile
 triton==3.3.1; platform_machine == "x86_64"
 tiktoken
+timm
 blobfile
 openai-harmony==0.0.4

tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Lines changed: 45 additions & 109 deletions
@@ -7,7 +7,6 @@
 import transformers
 from PIL import Image
 
-from tensorrt_llm._torch.models import modeling_utils
 from tensorrt_llm._torch.models.checkpoints import NemotronHHfWeightMapper
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
@@ -30,6 +29,7 @@ def _is_disagg() -> bool:
     return os.getenv("TLLM_MULTIMODAL_DISAGGREGATED", "0") == "1"
 
 
+# TODO: update the reference config path once Nano v2 VLM is released.
 IMAGE_TOKEN_ID = 131072
 
 
@@ -63,7 +63,6 @@ def __init__(self,
         super().__init__(config)
         self.image_size = config.force_image_size
         self.patch_size = config.patch_size
-        # self.template = config.template
         self.num_image_token = int((self.image_size // self.patch_size)**2 *
                                    (config.downsample_ratio**2))
         self.downsample_ratio = config.downsample_ratio
@@ -85,18 +84,25 @@ def __init__(self,
         self.mlp1 = self.mlp1.to(config.torch_dtype)
 
         # Construct the vision encoder.
-        self.with_hf_codes = os.getenv("WITH_HF_CODES", "0") == "1"
-        if self.with_hf_codes:
-            self.vision_model = transformers.AutoModel.from_config(
-                config.vision_config, trust_remote_code=True)
-            # set input_condition as Identity module.
-            self.vision_model.radio_model.make_preprocessor_external()
-            self.vision_model.to(config.torch_dtype)
-        else:
-            vision_model_config = copy.deepcopy(model_config)
-            vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
-            self.vision_model = RADIOVisionModel(vision_model_config)
-            self.vision_model.to(config.torch_dtype)
+        vision_model_config = copy.deepcopy(model_config)
+        vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
+        self.vision_model = RADIOVisionModel(vision_model_config)
+        self.vision_model.to(config.torch_dtype)
+
+    def load_weights(self, weights):
+        # Load mlp1 weights.
+        mlp1_weights = {
+            k.replace('mlp1.', ''): v
+            for k, v in weights.items() if k.startswith('mlp1.')
+        }
+        self.mlp1.load_state_dict(mlp1_weights, strict=True)
+
+        # Load vision encoder weights.
+        vision_encoder_weights = {
+            k.replace('vision_model.', ''): v
+            for k, v in weights.items() if k.startswith('vision_model.')
+        }
+        self.vision_model.load_weights(vision_encoder_weights)
 
     @torch.compile
     def pixel_shuffle(self, x, scale_factor=0.5):
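
Note: the new VisionEncoder.load_weights routes checkpoint entries by key prefix. The snippet below is a minimal, standalone sketch of that routing; the key names and tensor shapes are made up for illustration and are not code from this commit.

import torch

# Hypothetical checkpoint dict; real keys come from the model checkpoint.
weights = {
    "mlp1.0.weight": torch.zeros(4, 4),
    "mlp1.0.bias": torch.zeros(4),
    "vision_model.radio_model.model.blocks.0.attn.qkv.weight": torch.zeros(12, 4),
    "language_model.embed_tokens.weight": torch.zeros(8, 4),
}

# 'mlp1.' keys go to the projector and 'vision_model.' keys to the vision tower,
# each with its prefix stripped; everything else is left for the language model.
mlp1_weights = {
    k.replace("mlp1.", "", 1): v
    for k, v in weights.items() if k.startswith("mlp1.")
}
vision_weights = {
    k.replace("vision_model.", "", 1): v
    for k, v in weights.items() if k.startswith("vision_model.")
}

assert set(mlp1_weights) == {"0.weight", "0.bias"}
assert set(vision_weights) == {"radio_model.model.blocks.0.attn.qkv.weight"}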
@@ -117,10 +123,7 @@ def pixel_shuffle(self, x, scale_factor=0.5):
         return x
 
     def extract_feature(self, pixel_values):
-        if self.with_hf_codes:
-            vit_embeds = self.vision_model(pixel_values).features
-        else:
-            vit_embeds = self.vision_model(pixel_values)
+        vit_embeds = self.vision_model(pixel_values)
         vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
         # Down-sampling and projection.
         h = w = int(vit_embeds.shape[1]**0.5)
@@ -134,40 +137,28 @@
 
     def forward(self, multimodal_params: List[MultimodalParams]):
         mm_embedding = []
-
-        BATCH_INFERENCE = True
-        if BATCH_INFERENCE:
-            # Batch data.
-            batched_pixel_values = torch.cat([
-                multimodal_param.multimodal_data["pixel_values"]
-                for multimodal_param in multimodal_params
-            ],
-                                             dim=0)
-            # -> [num_patches, channel, height, width]
-            batched_num_patches = torch.cat([
-                multimodal_param.multimodal_data["num_patches"]
-                for multimodal_param in multimodal_params
-            ],
-                                            dim=0).tolist()
-            # -> list of [num_patches1, num_patches2, ...]
-            batched_image_embeds = self.extract_feature(batched_pixel_values)
-            # -> [num_patches, num_image_token, hidden_size]
-            mm_embedding = torch.split(batched_image_embeds,
-                                       batched_num_patches,
-                                       dim=0)
-            mm_embedding = [
-                m.reshape(-1, self.llm_hidden_size) for m in mm_embedding
-            ]
-            # -> list of [num_patches*num_image_token, hidden_size]
-        else:
-            # Inference per sample.
-            for multimodal_param in multimodal_params:
-                pixel_values = multimodal_param.multimodal_data["pixel_values"]
-                image_embeds = self.extract_feature(pixel_values)
-                # -> [num_patches, num_image_token, hidden_size]
-                image_embeds = image_embeds.reshape(-1, self.llm_hidden_size)
-                # -> [num_patches*num_image_token, hidden_size]
-                mm_embedding.append(image_embeds)
+        # Batch data.
+        batched_pixel_values = torch.cat([
+            multimodal_param.multimodal_data["pixel_values"]
+            for multimodal_param in multimodal_params
+        ],
+                                         dim=0)
+        # -> [num_patches, channel, height, width]
+        batched_num_patches = torch.cat([
+            multimodal_param.multimodal_data["num_patches"]
+            for multimodal_param in multimodal_params
+        ],
+                                        dim=0).tolist()
+        # -> list of [num_patches1, num_patches2, ...]
+        batched_image_embeds = self.extract_feature(batched_pixel_values)
+        # -> [num_patches, num_image_token, hidden_size]
+        mm_embedding = torch.split(batched_image_embeds,
+                                   batched_num_patches,
+                                   dim=0)
+        mm_embedding = [
+            m.reshape(-1, self.llm_hidden_size) for m in mm_embedding
+        ]
+        # -> list of [num_patches*num_image_token, hidden_size]
         return mm_embedding
 
 
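Note: the forward path above concatenates all image patches across requests into one encoder call and then splits the result back per request. Below is a minimal standalone sketch of that shape bookkeeping with dummy sizes; it is illustrative only, not code from this commit.

import torch

num_image_token, hidden_size = 16, 64       # illustrative sizes only
num_patches = [3, 1, 5]                     # patches contributed by each request
pixel_values = [torch.randn(n, 3, 32, 32) for n in num_patches]

batched = torch.cat(pixel_values, dim=0)    # [sum(num_patches), C, H, W]
# Stand-in for extract_feature(batched): one token grid per patch.
batched_embeds = torch.randn(batched.shape[0], num_image_token, hidden_size)

per_request = torch.split(batched_embeds, num_patches, dim=0)
mm_embedding = [m.reshape(-1, hidden_size) for m in per_request]

for n, m in zip(num_patches, mm_embedding):
    assert m.shape == (n * num_image_token, hidden_size)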
@@ -361,63 +352,8 @@ def __init__(self, model_config: ModelConfig):
         self.is_loaded = True
 
     def load_weights(self, weights):
-        # TODO: move vision encoder weights loading to vision encoder class.
-
-        # Load vision encoder weights for pytorch modules.
-        filter_weights = {
-            k: v
-            for k, v in weights.items()
-            if k.startswith('vision') or k.startswith('mlp1')
-        }
-        missing_keys, unexpected_keys = self.vision_encoder.load_state_dict(
-            filter_weights, strict=False)
-        try:
-            missing_keys.remove("vision_model.radio_model.summary_idxs")
-        except ValueError:
-            pass
-
-        unexpected_keys.remove(
-            "vision_model.radio_model.input_conditioner.norm_mean")
-        unexpected_keys.remove(
-            "vision_model.radio_model.input_conditioner.norm_std")
-        for m in missing_keys:
-            if not m.startswith('vision_model.radio_model.model.blocks.'):
-                raise ValueError(f"Missing key: {m}")
-        for u in unexpected_keys:
-            if not u.startswith('vision_model.radio_model.model.blocks.'):
-                raise ValueError(f"Unexpected key: {u}")
-
-        if len(unexpected_keys) > 0 or len(missing_keys) > 0:
-            # Load weights for vision transformer module.
-            model_weights = {
-                k.replace('vision_model.radio_model.model.', ''): v
-                for k, v in weights.items()
-                if k.startswith('vision_model.radio_model.model.')
-            }
-            converted_weights = dict()
-            for name in model_weights:
-                # Handle with weights and bias for vision transformer's qkv projection.
-                if "attn.qkv." in name:
-                    q_name = name.replace("attn.qkv.", "attn.q_proj.")
-                    k_name = name.replace("attn.qkv.", "attn.k_proj.")
-                    v_name = name.replace("attn.qkv.", "attn.v_proj.")
-                    dim_shape = model_weights[name].shape[0] // 3
-                    converted_weights[q_name] = model_weights[name][:dim_shape]
-                    converted_weights[k_name] = model_weights[name][
-                        dim_shape:2 * dim_shape]
-                    converted_weights[v_name] = model_weights[name][2 *
-                                                                    dim_shape:]
-                else:
-                    converted_weights[name] = model_weights[name]
-            pattern_mapping = {
-                r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
-                r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
-                r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
-            }
-            modeling_utils._load_weights_impl(
-                self.vision_encoder.vision_model.radio_model.model,
-                converted_weights,
-                params_map=pattern_mapping)
+        # Load vision encoder weights.
+        self.vision_encoder.load_weights(weights)
 
         # Load language model weights.
         filtered_weights = {
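
Note: the deleted block above split RADIO's fused attn.qkv checkpoint tensors into separate q/k/v projections and remapped attn.proj/mlp.fc* names; with this commit that handling presumably moves behind self.vision_encoder.load_weights (see modeling_radio.py). A minimal standalone sketch of the fused-qkv split, with a hypothetical hidden size, not code from this commit:

import torch

hidden = 8
qkv_weight = torch.randn(3 * hidden, hidden)    # fused [3 * hidden, hidden] weight

dim_shape = qkv_weight.shape[0] // 3
q_w = qkv_weight[:dim_shape]                    # -> attn.q_proj.weight
k_w = qkv_weight[dim_shape:2 * dim_shape]       # -> attn.k_proj.weight
v_w = qkv_weight[2 * dim_shape:]                # -> attn.v_proj.weight

# Concatenating the splits reproduces the fused tensor.
assert torch.equal(torch.cat([q_w, k_w, v_w], dim=0), qkv_weight)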
