Commit 100536e

[TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend
* Use trtllm primitives

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 1926ddc commit 100536e

3 files changed, 581 insertions(+), 104 deletions(-)

cpp/kernels/fmha_v2/setup.py

Lines changed: 2 additions & 2 deletions
@@ -1981,8 +1981,8 @@ def selected_mask_types(kspec):
         custom_mask = '0'
     # encoder models (head_size = 32 / 64 / 128) need packed_qkv input layout + padding mask.
     elif kspec.input_layout == InputLayout.PACKED_QKV:
-        # NOTE: 72 is added for vision transformer
-        if kspec.head_size not in [32, 64, 72, 128]:
+        # NOTE: 72/80 are added for vision transformer
+        if kspec.head_size not in [32, 64, 72, 80, 128]:
             padding_mask = '0'
     # only cross attention (head_size = 32/64/128) needs contiguous_q_kv input layout + padding mask / custom_mask.
     elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV:
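For context on the whitelist change: with the packed-QKV input layout, selected_mask_types keeps the padding-mask kernel variant only for head sizes in this list, and the commit adds 80 to it. Below is a minimal standalone sketch of that gating logic; the 1280 // 16 geometry (a ViT-style hidden size and head count that yields head_size 80) is an illustrative assumption, not something stated in the diff.

# Illustrative sketch only; not the fmha_v2 kspec API.
# Head sizes for which packed-QKV encoder kernels keep the padding-mask variant:
SUPPORTED_PACKED_QKV_HEAD_SIZES = [32, 64, 72, 80, 128]


def keeps_padding_mask(hidden_size: int, num_heads: int) -> bool:
    """True if a ViT with this geometry gets the padding-mask kernel variant."""
    head_size = hidden_size // num_heads
    return head_size in SUPPORTED_PACKED_QKV_HEAD_SIZES


print(keeps_padding_mask(1280, 16))  # head_size 80 -> True after this commit
print(keeps_padding_mask(1536, 16))  # head_size 96 -> False, padding_mask stays '0'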

tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Lines changed: 87 additions & 18 deletions
@@ -7,6 +7,7 @@
 import transformers
 from PIL import Image
 
+from tensorrt_llm._torch.models import modeling_utils
 from tensorrt_llm._torch.models.checkpoints import NemotronHHfWeightMapper
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 

@@ -34,10 +35,27 @@ def forward(self, x):
         return torch.pow(torch.nn.functional.relu(x), 2)
 
 
+class RMSNorm(nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-5):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
+
+
 class NanoV2VLVisionEncoder(transformers.PreTrainedModel,
                             transformers.generation.GenerationMixin):
 
-    def __init__(self, config: transformers.PretrainedConfig):
+    def __init__(self,
+                 model_config: ModelConfig[transformers.PretrainedConfig]):
+        config = model_config.pretrained_config
         super().__init__(config)
         self.image_size = config.force_image_size
         self.patch_size = config.patch_size
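As a quick sanity check of the RMSNorm added above (illustrative, not part of the commit): for float32 inputs it should agree with torch.nn.functional.rms_norm, which assumes PyTorch 2.4 or newer. The hidden size below is arbitrary, and the snippet assumes the RMSNorm class from this hunk is in scope.

import torch
import torch.nn.functional as F

hidden = 5120                     # arbitrary stand-in for vit_hidden_size * int(1 / downsample_ratio)**2
norm = RMSNorm(hidden, eps=1e-5)  # the module defined in the hunk above
x = torch.randn(2, 16, hidden)    # (batch, tokens, hidden)

# The hand-rolled norm upcasts to float32 internally; for float32 inputs the
# result should match the reference implementation within tolerance.
ref = F.rms_norm(x, (hidden,), weight=norm.weight, eps=1e-5)
torch.testing.assert_close(norm(x), ref, rtol=1e-5, atol=1e-5)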
@@ -59,12 +77,15 @@ def __init__(self, config: transformers.PretrainedConfig):
         self.llm_hidden_size = config.llm_config.hidden_size
 
         self.mlp1 = nn.Sequential(
-            nn.LayerNorm(self.vit_hidden_size *
-                         int(1 / self.downsample_ratio)**2,
-                         bias=False),
+            # nn.LayerNorm(self.vit_hidden_size *
+            #              int(1 / self.downsample_ratio)**2,
+            #              bias=False),
+            RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
+                    eps=1e-5),
             nn.Linear(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                       self.vision_projection_hidden_size,
-                      bias=False), SquaredReLU(),
+                      bias=False),
+            SquaredReLU(),
             nn.Linear(self.vision_projection_hidden_size,
                       self.llm_hidden_size,
                       bias=False))
@@ -80,13 +101,27 @@ def __init__(self, config: transformers.PretrainedConfig):
             with open("hf_vision_encoder_arch.txt", "w") as f:
                 f.write(str(self.vision_model))
         else:
-            # Update the vision model with customized one.
-            from .modeling_radio import RADIOModel
-            self.vision_model = RADIOModel(config.vision_config)
-            self.vision_model.to(config.torch_dtype)
+            WITH_TRTLLM_CODES = True
+            if WITH_TRTLLM_CODES:
+                from .modeling_radio import RADIOVisionModel
 
-            with open("user_vision_encoder_arch.txt", "w") as f:
-                f.write(str(self.vision_model))
+                vision_model_config = copy.deepcopy(model_config)
+                vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
+
+                self.vision_model = RADIOVisionModel(vision_model_config)
+                self.vision_model.to(config.torch_dtype)
+
+                with open("trtllm_vision_encoder_arch.txt", "w") as f:
+                    f.write(str(self.vision_model))
+
+            else:
+                # Update the vision model with customized one.
+                from .modeling_radio import RADIOModel
+                self.vision_model = RADIOModel(config.vision_config)
+                self.vision_model.to(config.torch_dtype)
+
+                with open("user_vision_encoder_arch.txt", "w") as f:
+                    f.write(str(self.vision_model))
 
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
@@ -258,7 +293,7 @@ def __init__(self, model_config: ModelConfig):
             return
 
         if not _is_disagg():
-            self.vision_encoder = NanoV2VLVisionEncoder(config).eval()
+            self.vision_encoder = NanoV2VLVisionEncoder(model_config).eval()
             self.vision_encoder.to(config.torch_dtype)
 
         llm_model_config = copy.deepcopy(model_config)
@@ -272,19 +307,53 @@ def __init__(self, model_config: ModelConfig):
         self.is_loaded = True
 
     def load_weights(self, weights):
-        # Load vision encoder weights.
+        # Load vision encoder weights for pytorch modules.
         filter_weights = {
             k: v
             for k, v in weights.items()
             if k.startswith('vision') or k.startswith('mlp1')
         }
         missing_keys, unexpected_keys = self.vision_encoder.load_state_dict(
             filter_weights, strict=False)
-        if len(unexpected_keys) > 0:
-            raise ValueError(f"Unexpected keys: {unexpected_keys}")
-        if len(missing_keys) > 1 and missing_keys[
-                0] != 'vision_model.radio_model.summary_idxs':
-            raise ValueError(f"Missing keys: {missing_keys}")
+        for m in missing_keys:
+            if not m.startswith(
+                    'vision_model.radio_model.model.blocks.'
+            ) and m != "vision_model.radio_model.summary_idxs":
+                raise ValueError(f"Missing key: {m}")
+        for u in unexpected_keys:
+            if not u.startswith('vision_model.radio_model.model.blocks.'):
+                raise ValueError(f"Unexpected key: {u}")
+
+        # Load weights for vision transformer module.
+        model_weights = {
+            k.replace('vision_model.radio_model.model.', ''): v
+            for k, v in weights.items()
+            if k.startswith('vision_model.radio_model.model.')
+        }
+        converted_weights = dict()
+        for name in model_weights:
+            # Handle with weights and bias for vision transformer's qkv projection.
+            if "attn.qkv." in name:
+                q_name = name.replace("attn.qkv.", "attn.q_proj.")
+                k_name = name.replace("attn.qkv.", "attn.k_proj.")
+                v_name = name.replace("attn.qkv.", "attn.v_proj.")
+                dim_shape = model_weights[name].shape[0] // 3
+                converted_weights[q_name] = model_weights[name][:dim_shape]
+                converted_weights[k_name] = model_weights[name][dim_shape:2 *
+                                                                dim_shape]
+                converted_weights[v_name] = model_weights[name][2 * dim_shape:]
+            else:
+                converted_weights[name] = model_weights[name]
+        pattern_mapping = {
+            r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
+            r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
+            r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
+        }
+        modeling_utils._load_weights_impl(
+            self.vision_encoder.vision_model.radio_model.model,
+            converted_weights,
+            params_map=pattern_mapping)
+
         # Load language model weights.
         filtered_weights = {
             k.replace('language_model.', ''): v
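The weight conversion in load_weights can be exercised on its own: fused attn.qkv.* tensors are split into equal q/k/v thirds along dim 0, and the remaining HF-style names are rewritten to TRT-LLM module names (o_proj, up_proj, down_proj). Below is a standalone sketch with made-up shapes and key names; the regex step re-implements what the params_map passed to _load_weights_impl is presumably used for, since that helper's internals are not shown in this diff.

import re

import torch

# Hypothetical shapes/keys, mirroring the conversion in load_weights above.
hidden = 1280
model_weights = {
    "blocks.0.attn.qkv.weight": torch.randn(3 * hidden, hidden),
    "blocks.0.attn.proj.weight": torch.randn(hidden, hidden),
    "blocks.0.mlp.fc1.weight": torch.randn(4 * hidden, hidden),
}

converted = {}
for name, w in model_weights.items():
    if "attn.qkv." in name:
        third = w.shape[0] // 3  # equal q/k/v partitions, as in the diff
        converted[name.replace("attn.qkv.", "attn.q_proj.")] = w[:third]
        converted[name.replace("attn.qkv.", "attn.k_proj.")] = w[third:2 * third]
        converted[name.replace("attn.qkv.", "attn.v_proj.")] = w[2 * third:]
    else:
        converted[name] = w

# Regex renaming analogous to pattern_mapping above (assumed semantics).
pattern_mapping = {
    r'(.*?)attn.proj.(.*)': r'\1attn.o_proj.\2',
    r'(.*?)mlp.fc1.(.*)': r'\1mlp.up_proj.\2',
    r'(.*?)mlp.fc2.(.*)': r'\1mlp.down_proj.\2',
}
renamed = {}
for key, value in converted.items():
    for pattern, repl in pattern_mapping.items():
        new_key, hits = re.subn(pattern, repl, key)
        if hits:
            key = new_key
            break
    renamed[key] = value

print(sorted(renamed))
# ['blocks.0.attn.k_proj.weight', 'blocks.0.attn.o_proj.weight',
#  'blocks.0.attn.q_proj.weight', 'blocks.0.attn.v_proj.weight',
#  'blocks.0.mlp.up_proj.weight']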
