Commit 034393e

[TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend
* Aligned output between the HF code path and the TRT-LLM code path.

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent: 100536e · commit: 034393e
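
The note about aligning outputs points at a numerical parity check between the HF implementation and the TRT-LLM port. A minimal sketch of how such a check could look, assuming two placeholder callables hf_encode and trtllm_encode that map the same preprocessed pixel values to vision embeddings (neither name comes from this repository):

import torch

def check_vision_parity(hf_encode, trtllm_encode, pixel_values: torch.Tensor,
                        rtol: float = 1e-3, atol: float = 1e-3) -> None:
    # Run both encoder paths on the same preprocessed input.
    with torch.inference_mode():
        hf_out = hf_encode(pixel_values)
        trtllm_out = trtllm_encode(pixel_values)
    # Compare in float32 so bf16/fp16 rounding noise does not dominate the tolerance.
    max_diff = (hf_out.float() - trtllm_out.float()).abs().max().item()
    print(f"max abs diff: {max_diff:.3e}")
    assert torch.allclose(hf_out.float(), trtllm_out.float(), rtol=rtol, atol=atol)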

File tree

2 files changed: +74 / -64 lines

tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Lines changed: 47 additions & 43 deletions
@@ -29,6 +29,9 @@ def _is_disagg() -> bool:
     return os.getenv("TLLM_MULTIMODAL_DISAGGREGATED", "0") == "1"
 
 
+IMAGE_TOKEN_ID = 131072
+
+
 class SquaredReLU(nn.Module):
 
     def forward(self, x):
@@ -77,25 +80,22 @@ def __init__(self,
         self.llm_hidden_size = config.llm_config.hidden_size
 
         self.mlp1 = nn.Sequential(
-            # nn.LayerNorm(self.vit_hidden_size *
-            #              int(1 / self.downsample_ratio)**2,
-            #              bias=False),
             RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                     eps=1e-5),
             nn.Linear(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                       self.vision_projection_hidden_size,
-                      bias=False),
-            SquaredReLU(),
+                      bias=False), SquaredReLU(),
             nn.Linear(self.vision_projection_hidden_size,
                       self.llm_hidden_size,
                       bias=False))
         self.mlp1 = self.mlp1.to(config.torch_dtype)
 
-        # self.img_context_token_id = None
         WITH_HF_CODES = False
         if WITH_HF_CODES:
             self.vision_model = transformers.AutoModel.from_config(
                 config.vision_config, trust_remote_code=True)
+            # set input_condition as Identity module.
+            self.vision_model.radio_model.make_preprocessor_external()
             self.vision_model.to(config.torch_dtype)
 
             with open("hf_vision_encoder_arch.txt", "w") as f:
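
For orientation, the mlp1 projector above consumes ViT features whose channel dimension is scaled by int(1 / downsample_ratio)**2, because spatial downsampling stacks neighboring patch features into the channel axis before projecting to the LLM hidden size. A standalone shape sketch with made-up sizes (the real values come from the model config; the repo uses its own RMSNorm, and SquaredReLU below assumes the common relu(x)**2 definition):

import torch
import torch.nn as nn

class SquaredReLU(nn.Module):
    # Assumed behavior: relu(x) squared; stands in for the repo's SquaredReLU.
    def forward(self, x):
        return torch.relu(x) ** 2

# Illustrative sizes only.
vit_hidden_size = 1280
downsample_ratio = 0.5                  # 2x2 patch grid folded into channels
scale = int(1 / downsample_ratio) ** 2  # -> 4
vision_projection_hidden_size = 4096
llm_hidden_size = 3072

# Each projector input token carries `scale` stacked ViT features.
tokens = torch.randn(2, 256, vit_hidden_size * scale)

projector = nn.Sequential(
    nn.RMSNorm(vit_hidden_size * scale, eps=1e-5),  # torch's RMSNorm as a stand-in
    nn.Linear(vit_hidden_size * scale, vision_projection_hidden_size, bias=False),
    SquaredReLU(),
    nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
)
print(projector(tokens).shape)  # torch.Size([2, 256, 3072])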
@@ -113,7 +113,6 @@ def __init__(self,
 
             with open("trtllm_vision_encoder_arch.txt", "w") as f:
                 f.write(str(self.vision_model))
-
         else:
             # Update the vision model with customized one.
             from .modeling_radio import RADIOModel
@@ -218,6 +217,7 @@ def __init__(self,
         self.img_context_token = "<image>"
         self.img_start_token = "<img>"
         self.img_end_token = "</img>"
+        self.dtype = model_config.torch_dtype
 
     @torch.inference_mode()
     def __call__(
@@ -258,7 +258,8 @@ def __call__(
 
         # Will package inputs for language model forward in AGGREGATE mode.
         multimodal_data = {}
-        multimodal_data['pixel_values'] = processed_images['pixel_values']
+        multimodal_data['pixel_values'] = processed_images['pixel_values'].to(
+            self.dtype)
         multimodal_data['num_patches'] = processed_images['num_patches']
         return input_ids[0].to(torch.int32).tolist(), {
             "multimodal_data": multimodal_data,
@@ -271,7 +272,7 @@ def __call__(
     model_type="NemotronH_Nano_VL_V2",
     placeholder_metadata=MultimodalPlaceholderMetadata(
         placeholder_map={
-            "image": "<image>",
+            "image": "<image>\n",
         },
         placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
         placeholders_separator="",
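
With BEFORE_TEXT placement, an empty separator, and the newline now folded into the placeholder itself, each "<image>\n" directly precedes the text in a multi-image prompt. A rough illustration of the resulting string (not the registry's actual rendering code):

placeholder = "<image>\n"
num_images = 2
text = "Describe both images."

# BEFORE_TEXT placement with an empty separator between placeholders.
prompt = placeholder * num_images + text
print(repr(prompt))  # '<image>\n<image>\nDescribe both images.'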
@@ -321,38 +322,44 @@ def load_weights(self, weights):
             ) and m != "vision_model.radio_model.summary_idxs":
                 raise ValueError(f"Missing key: {m}")
         for u in unexpected_keys:
-            if not u.startswith('vision_model.radio_model.model.blocks.'):
+            if not u.startswith(
+                    'vision_model.radio_model.model.blocks.') and u not in [
+                        "vision_model.radio_model.input_conditioner.norm_mean",
+                        "vision_model.radio_model.input_conditioner.norm_std",
+                    ]:
                 raise ValueError(f"Unexpected key: {u}")
 
-        # Load weights for vision transformer module.
-        model_weights = {
-            k.replace('vision_model.radio_model.model.', ''): v
-            for k, v in weights.items()
-            if k.startswith('vision_model.radio_model.model.')
-        }
-        converted_weights = dict()
-        for name in model_weights:
-            # Handle with weights and bias for vision transformer's qkv projection.
-            if "attn.qkv." in name:
-                q_name = name.replace("attn.qkv.", "attn.q_proj.")
-                k_name = name.replace("attn.qkv.", "attn.k_proj.")
-                v_name = name.replace("attn.qkv.", "attn.v_proj.")
-                dim_shape = model_weights[name].shape[0] // 3
-                converted_weights[q_name] = model_weights[name][:dim_shape]
-                converted_weights[k_name] = model_weights[name][dim_shape:2 *
-                                                                dim_shape]
-                converted_weights[v_name] = model_weights[name][2 * dim_shape:]
-            else:
-                converted_weights[name] = model_weights[name]
-        pattern_mapping = {
-            r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
-            r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
-            r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
-        }
-        modeling_utils._load_weights_impl(
-            self.vision_encoder.vision_model.radio_model.model,
-            converted_weights,
-            params_map=pattern_mapping)
+        if len(unexpected_keys) > 0 or len(missing_keys) > 1:
+            # Load weights for vision transformer module.
+            model_weights = {
+                k.replace('vision_model.radio_model.model.', ''): v
+                for k, v in weights.items()
+                if k.startswith('vision_model.radio_model.model.')
+            }
+            converted_weights = dict()
+            for name in model_weights:
+                # Handle with weights and bias for vision transformer's qkv projection.
+                if "attn.qkv." in name:
+                    q_name = name.replace("attn.qkv.", "attn.q_proj.")
+                    k_name = name.replace("attn.qkv.", "attn.k_proj.")
+                    v_name = name.replace("attn.qkv.", "attn.v_proj.")
+                    dim_shape = model_weights[name].shape[0] // 3
+                    converted_weights[q_name] = model_weights[name][:dim_shape]
+                    converted_weights[k_name] = model_weights[name][
+                        dim_shape:2 * dim_shape]
+                    converted_weights[v_name] = model_weights[name][2 *
+                                                                    dim_shape:]
+                else:
+                    converted_weights[name] = model_weights[name]
+            pattern_mapping = {
+                r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
+                r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
+                r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
+            }
+            modeling_utils._load_weights_impl(
+                self.vision_encoder.vision_model.radio_model.model,
+                converted_weights,
+                params_map=pattern_mapping)
 
         # Load language model weights.
         filtered_weights = {
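
The guarded block above splits each fused attn.qkv tensor into equal q/k/v slices along dim 0 and leaves the remaining renames (attn.proj to attn.o_proj, mlp.fc1/fc2 to up_proj/down_proj) to _load_weights_impl via params_map. A self-contained toy version of both steps, using made-up shapes and parameter names and plain re.sub in place of the loader:

import re
import torch

# Toy fused checkpoint: qkv weight stacked as [3 * head_dims, hidden].
weights = {
    "blocks.0.attn.qkv.weight": torch.randn(3 * 8, 16),
    "blocks.0.attn.proj.weight": torch.randn(16, 16),
    "blocks.0.mlp.fc1.weight": torch.randn(32, 16),
}

converted = {}
for name, tensor in weights.items():
    if "attn.qkv." in name:
        third = tensor.shape[0] // 3  # q, k, v are stacked along dim 0
        converted[name.replace("attn.qkv.", "attn.q_proj.")] = tensor[:third]
        converted[name.replace("attn.qkv.", "attn.k_proj.")] = tensor[third:2 * third]
        converted[name.replace("attn.qkv.", "attn.v_proj.")] = tensor[2 * third:]
    else:
        converted[name] = tensor

# Rename the remaining HF-style parameters to TRT-LLM module names.
pattern_mapping = {
    r"(.*?)attn.proj.(.*)": r"\1attn.o_proj.\2",
    r"(.*?)mlp.fc1.(.*)": r"\1mlp.up_proj.\2",
    r"(.*?)mlp.fc2.(.*)": r"\1mlp.down_proj.\2",
}
renamed = {}
for name, tensor in converted.items():
    for pattern, repl in pattern_mapping.items():
        if re.match(pattern, name):
            name = re.sub(pattern, repl, name)
            break
    renamed[name] = tensor
print(sorted(renamed))
# ['blocks.0.attn.k_proj.weight', 'blocks.0.attn.o_proj.weight',
#  'blocks.0.attn.q_proj.weight', 'blocks.0.attn.v_proj.weight',
#  'blocks.0.mlp.up_proj.weight']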
@@ -405,11 +412,8 @@ def forward(
             self.llm.model.embed_tokens,
             input_ids,
             mm_embedding,
-            mm_token_ids=torch.tensor([
-                131072
-            ], dtype=torch.int32),  # 131072 is the token id for the image token
+            mm_token_ids=torch.tensor([IMAGE_TOKEN_ID], dtype=torch.int32),
         )
-
         output_prob = self.llm.forward(
             attn_metadata=attn_metadata,
             input_ids=input_ids,
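
The call above hands the embedding layer, the token ids, the vision embeddings, and the image token id to the fusion helper; conceptually, text embeddings at IMAGE_TOKEN_ID positions are overwritten with the vision embeddings. A minimal standalone sketch of that masking idea (fuse_embeddings below is illustrative, not the helper used in the diff):

import torch
import torch.nn as nn

IMAGE_TOKEN_ID = 131072

def fuse_embeddings(embed_tokens: nn.Embedding, input_ids: torch.Tensor,
                    mm_embedding: torch.Tensor,
                    mm_token_ids: torch.Tensor) -> torch.Tensor:
    # Start from ordinary text embeddings, then overwrite the image-token slots.
    embeds = embed_tokens(input_ids).clone()
    mask = torch.isin(input_ids, mm_token_ids)
    # mm_embedding is expected to hold one row per image-token position.
    embeds[mask] = mm_embedding.to(embeds.dtype)
    return embeds

# Toy usage: a 6-token sequence with two image slots.
vocab = nn.Embedding(IMAGE_TOKEN_ID + 1, 8)
ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5, 6, 7])
vision = torch.randn(2, 8)  # one embedding per image token
fused = fuse_embeddings(vocab, ids, vision, torch.tensor([IMAGE_TOKEN_ID]))
print(fused.shape)  # torch.Size([6, 8])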

tensorrt_llm/_torch/models/modeling_radio.py

Lines changed: 27 additions & 21 deletions
@@ -85,27 +85,29 @@ def __init__(
         super().__init__(**kwargs)
 
 
-class InputConditioner(nn.Module):
+# class InputConditioner(nn.Module):
 
-    def __init__(
-        self,
-        input_scale: float,
-        norm_mean: norm_t,
-        norm_std: norm_t,
-        dtype: torch.dtype = None,
-    ):
-        super().__init__()
+#     def __init__(
+#         self,
+#         input_scale: float,
+#         norm_mean: norm_t,
+#         norm_std: norm_t,
+#         dtype: torch.dtype = None,
+#     ):
+#         super().__init__()
 
-        self.dtype = dtype
+#         self.dtype = dtype
 
-        self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
-        self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
+#         self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
+#         self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
 
-    def forward(self, x: torch.Tensor):
-        y = (x - self.norm_mean) / self.norm_std
-        if self.dtype is not None:
-            y = y.to(self.dtype)
-        return y
+#     def forward(self, x: torch.Tensor):
+#         y = (x - self.norm_mean) / self.norm_std
+#         if self.dtype is not None:
+#             y = y.to(self.dtype)
+#         return y
+
+InputConditioner = nn.Identity
 
 
 class ClsToken(nn.Module):
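
Swapping InputConditioner for nn.Identity assumes pixel normalization now happens in the preprocessor, outside the vision tower; the load_weights change in modeling_nanov2vlm.py correspondingly tolerates the now-unused norm_mean/norm_std buffers. A small sketch of why the two paths are equivalent, with made-up normalization stats:

import torch
import torch.nn as nn

# Made-up stats; the real values live in the checkpoint's
# input_conditioner.norm_mean / norm_std buffers.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

x = torch.rand(1, 3, 224, 224)

# Old path: the module normalizes internally.
internally_normalized = (x - norm_mean) / norm_std

# New path: the preprocessor normalizes, and the in-model
# InputConditioner (= nn.Identity) passes the result through untouched.
preprocessed = (x - norm_mean) / norm_std
externally_normalized = nn.Identity()(preprocessed)

assert torch.equal(internally_normalized, externally_normalized)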
@@ -727,8 +729,9 @@ def __init__(
         act_layer = get_act_layer(act_layer) or nn.GELU
 
         self.model_config = model_config
-        self.config = model_config.pretrained_config
-        self.config.num_key_value_heads = num_heads
+        if self.model_config is not None:
+            self.config = model_config.pretrained_config
+            self.config.num_key_value_heads = num_heads
 
         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -810,8 +813,11 @@ def __init__(
         self.patch_size = patch_size
         self.num_cls_tokens = num_cls_tokens
         self.num_registers = self.patch_generator.num_registers
-        self.metadata_cls = attention_utils.get_attention_backend(
-            model_config.attn_backend).Metadata
+        if self.model_config is not None:
+            self.metadata_cls = attention_utils.get_attention_backend(
+                model_config.attn_backend).Metadata
+        else:
+            self.metadata_cls = None
 
     def prepare_attn_metadata(self, batch_size: int, seq_lengths: List[int]):
         """
