
Commit 1f10d0c

[TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend
* Clean up code.
* Add test_e2e for nano_v2 VLM.

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent d1058ef commit 1f10d0c

File tree

3 files changed (+83, -73 lines)

tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Lines changed: 5 additions & 0 deletions
@@ -279,6 +279,11 @@ def __call__(
                 Image.fromarray((image.permute(1, 2, 0) * 255).to(
                     torch.uint8).cpu().numpy()) for image in images
             ]
+        else:
+            input_ids = self.tokenizer.encode(text_prompt,
+                                              add_special_tokens=False,
+                                              return_tensors="pt")
+            return input_ids[0].to(torch.int32).tolist(), {}
 
         # Processing for multimodal data.
         processed_images = self.processor(images=images,
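
The new `else` branch gives the input processor a text-only path: when no images are attached, it bypasses the multimodal processor and tokenizes the prompt directly. A minimal sketch of the same pattern with a generic Hugging Face tokenizer (the "gpt2" checkpoint is only a stand-in, not the model's actual tokenizer):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
# encode(..., return_tensors="pt") yields a [1, seq_len] tensor; row 0 is
# cast to int32 and flattened into the plain token-id list the backend
# expects, paired with an empty dict since there is no multimodal data.
input_ids = tokenizer.encode("a text-only prompt",
                             add_special_tokens=False,
                             return_tensors="pt")
token_list, extra = input_ids[0].to(torch.int32).tolist(), {}
```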

tensorrt_llm/_torch/models/modeling_radio.py

Lines changed: 56 additions & 73 deletions
@@ -22,8 +22,13 @@
 
 input_dim_t = Union[int, Tuple[int, int]]
 
-# Need for model weight loading.
-NUM_ATTENTION_HEADS = 16
+# Model parameters that are not in config.json.
+# TODO: read from config.json when it is released.
+NUM_ATTENTION_HEADS_FOR_VIT = 16
+IMAGE_SIZE_FOR_VIT = 224
+PATCH_SIZE_FOR_VIT = 16
+EMBED_DIM_FOR_VIT = 1280
+DEPTH_FOR_VIT = 32
 
 
 class Resolution(NamedTuple):
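
These constants pin down the ViT geometry until a config.json ships; the derived quantities follow directly from them (a quick sanity check, not part of the diff):

```python
NUM_ATTENTION_HEADS_FOR_VIT = 16
IMAGE_SIZE_FOR_VIT = 224
PATCH_SIZE_FOR_VIT = 16
EMBED_DIM_FOR_VIT = 1280

# 224 / 16 = 14 patches per side -> 196 patch tokens per image.
num_patches = (IMAGE_SIZE_FOR_VIT // PATCH_SIZE_FOR_VIT) ** 2
# 1280 channels split over 16 heads -> 80 channels per attention head.
head_dim = EMBED_DIM_FOR_VIT // NUM_ATTENTION_HEADS_FOR_VIT
assert (num_patches, head_dim) == (196, 80)
```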
@@ -34,7 +39,7 @@ class Resolution(NamedTuple):
 class RADIOConfig(PretrainedConfig):
     """Pretrained Hugging Face configuration for RADIO models.
 
-    Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/hf_model.py.
+    Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/hf_model.py.
     """
 
     def __init__(
@@ -55,8 +60,7 @@ def __init__(
         for field in ["dtype", "amp_dtype"]:
             if self.args is not None and field in self.args:
                 # Convert to a string in order to make it serializable.
-                # For example for torch.float32 we will store "float32",
-                # for "bfloat16" we will store "bfloat16".
+                # For example for torch.float32 we will store "float32".
                 self.args[field] = str(args[field]).split(".")[-1]
         self.version = version
         self.patch_size = patch_size
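
The shortened comment still relies on one trick worth spelling out: `str(...).split(".")[-1]` normalizes both torch dtypes and plain strings, since only the former contain a dot:

```python
import torch

assert str(torch.float32) == "torch.float32"
assert str(torch.float32).split(".")[-1] == "float32"  # dtype object
assert str("bfloat16").split(".")[-1] == "bfloat16"    # already a string
```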
@@ -68,13 +72,13 @@ def __init__(
         self.vitdet_window_size = vitdet_window_size
         self.feature_normalizer_config = feature_normalizer_config
         self.inter_feature_normalizer_config = inter_feature_normalizer_config
-        self.num_key_value_heads = NUM_ATTENTION_HEADS
-        self.num_attention_heads = NUM_ATTENTION_HEADS
+        self.num_key_value_heads = NUM_ATTENTION_HEADS_FOR_VIT
+        self.num_attention_heads = NUM_ATTENTION_HEADS_FOR_VIT
         super().__init__(**kwargs)
 
 
 class ClsToken(nn.Module):
-    """Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/cls_token.py."""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/cls_token.py."""
 
     def __init__(
         self,
@@ -115,7 +119,7 @@ def forward(self, x: torch.Tensor):
 
 
 class ViTPatchGenerator(nn.Module):
-    """Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
 
     def __init__(
         self,
@@ -132,8 +136,6 @@ def __init__(
         register_multiple: Optional[int] = None,
         num_registers: Optional[int] = None,
         patch_bias: bool = False,
-        device=None,
-        dtype=None,
     ):
         super().__init__()
 
@@ -151,42 +153,31 @@ def __init__(
         self.cpe_mode = max_input_dims != input_dims
         self.pos_dropout = pos_dropout
         self.return_pos_enc = return_pos_enc
-
-        factory = dict(device=device, dtype=dtype)
-
         self.patch_size = patch_size
         self.abs_pos = abs_pos
         self.embed_dim = embed_dim
-
         self.num_rows = max_input_dims[0] // patch_size
         self.num_cols = max_input_dims[1] // patch_size
         self.input_dims = tuple(d // patch_size for d in input_dims)
         self.num_patches = self.num_rows * self.num_cols
         self.max_input_dims = max_input_dims
 
         self.im_to_patches = Im2Patches(patch_size)
-        self.embedder = ViTPatchLinear(patch_size,
-                                       embed_dim,
-                                       bias=patch_bias,
-                                       **factory)
-
+        self.embedder = ViTPatchLinear(patch_size, embed_dim, bias=patch_bias)
         if abs_pos:
             scale = embed_dim**-0.5
             self.pos_embed = nn.Parameter(
-                torch.randn(1, self.num_patches, embed_dim, **factory) * scale)
-
+                torch.randn(1, self.num_patches, embed_dim) * scale)
         self.cls_token = ClsToken(
             embed_dim,
             num_tokens=num_cls_tokens,
             enabled=cls_token,
             register_multiple=register_multiple,
             num_registers=num_registers,
         )
-
         self.patch_normalizer = nn.LayerNorm(
             embed_dim) if normalize_patches else nn.Identity()
 
-    @torch.compile
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         patches = self.embed_patches(x)
         patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
@@ -265,7 +256,6 @@ def window_select(pos_embed):
                 size=(max_dim, max_dim),
                 align_corners=True,
                 mode='bilinear').to(pos_embed.dtype)
-
             pos_embed = window_select(pos_embed)
         else:
             pos_embed = window_select(pos_embed)
@@ -277,12 +267,11 @@ def window_select(pos_embed):
                 mode='bilinear').to(pos_embed.dtype)
 
         pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
-
         return pos_embed
 
 
 class Im2Patches(nn.Module):
-    """Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
 
     def __init__(self, patch_size: int):
         super().__init__()
@@ -308,22 +297,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class ViTPatchLinear(nn.Linear):
-    """Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
-
-    def __init__(self,
-                 patch_size: int,
-                 embed_dim: int,
-                 bias: bool = False,
-                 **kwargs):
-        super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **kwargs)
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
+
+    def __init__(
+        self,
+        patch_size: int,
+        embed_dim: int,
+        bias: bool = False,
+    ):
+        super().__init__(3 * (patch_size**2), embed_dim, bias=bias)
         self.patch_size = patch_size
 
 
 class Block(nn.Module):
     """Transformer block with pre-normalization.
 
-    Copy from https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
-    and use trtllm_attn and trtllm_mlp to replace attn and mlp.
+    Modified from https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py.
+    Uses trtllm_attn and trtllm_mlp in place of the original attention and MLP layers.
     """
 
     def __init__(
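
`Im2Patches` plus `ViTPatchLinear` implement the standard non-overlapping patch embedding; the `3 * (patch_size**2)` input width comes from flattening each RGB patch. A self-contained sketch of the equivalent computation (shapes assumed, using `unfold` rather than the repo's implementation):

```python
import torch
import torch.nn.functional as F

patch_size, embed_dim = 16, 1280
x = torch.randn(2, 3, 224, 224)                       # [B, C, H, W]
patches = F.unfold(x, kernel_size=patch_size,
                   stride=patch_size)                 # [2, 3*16*16=768, 196]
patches = patches.transpose(1, 2)                     # [2, 196, 768]
embedder = torch.nn.Linear(3 * patch_size**2, embed_dim, bias=False)
tokens = embedder(patches)                            # [2, 196, 1280]
```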
@@ -378,16 +368,16 @@ def __init__(
             hidden_size=dim,
             num_attention_heads=num_heads,
             num_key_value_heads=num_heads,
+            max_position_embeddings=None,
             bias=qkv_bias,
-            dense_bias=proj_bias,
-            dtype=self.model_config.torch_dtype,
-            layer_idx=layer_idx,
             pos_embd_params=None,
             rope_fusion=None,
+            layer_idx=layer_idx,
+            dtype=self.model_config.torch_dtype,
+            dense_bias=proj_bias,
+            config=self.model_config,
             q_scaling=1.0,
             attention_chunk_size=None,
-            config=self.model_config,
-            max_position_embeddings=None,
         )
         if init_values:
             raise IOError(
@@ -399,8 +389,6 @@ def __init__(
                 "Limited RADIO model support: Block does not support DropPath for now."
             )
         self.drop_path1 = nn.Identity()
-
-        self.norm2 = norm_layer(dim)
         if scale_mlp_norm:
             raise IOError(
                 "Limited RADIO model support: Block does not support scale_mlp_norm for now."
@@ -409,6 +397,7 @@ def __init__(
             raise IOError(
                 "Limited RADIO model support: Block does not support proj_drop for now."
             )
+        self.norm2 = norm_layer(dim)
 
         self.mlp = trtllm_mlp.MLP(
             hidden_size=dim,
@@ -442,8 +431,7 @@ def forward(
             position_ids=None,
             hidden_states=x,
             attn_metadata=attn_metadata,
-            attention_mask=attention_interface.PredefinedAttentionMask.
-            FULL  # Always FULL for Vision
+            attention_mask=attention_interface.PredefinedAttentionMask.FULL,
         )
         x = self.ls1(x)
         x = self.drop_path1(x)
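
Collapsing the wrapped attribute access also keeps the intent explicit: vision tokens attend bidirectionally, so the block always requests the FULL predefined mask rather than the causal mask a decoder would use. The distinction, sketched directly in torch:

```python
import torch

seq_len = 4
full = torch.ones(seq_len, seq_len, dtype=torch.bool)  # every token sees all
causal = torch.tril(full)                              # decoder-style masking
```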
@@ -461,7 +449,7 @@ def forward(
 class VisionTransformer(nn.Module):
     """ Vision Transformer.
 
-    Copy from https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py.
+    Modified from https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py.
     """
 
     def __init__(
@@ -535,9 +523,11 @@ def __init__(
             **kwargs: Additional keyword arguments, to store unused arguments.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
-        assert class_token or global_pool != 'token'
-        assert pos_embed in ('', 'none', 'learn')
+        if not (class_token or global_pool != 'token'):
+            raise ValueError(
+                "Class token must be used with global_pool == 'token'")
+        if pos_embed not in ('', 'none', 'learn'):
+            raise ValueError(f"Invalid pos_embed: {pos_embed}")
         use_fc_norm = global_pool in ('avg', 'avgmax',
                                       'max') if fc_norm is None else fc_norm
 
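Swapping `assert` for explicit raises matters beyond style: `python -O` strips assert statements, while `ValueError` checks always run. (Note the old `global_pool` membership assert is dropped without a replacement.) The pattern, in isolation:

```python
def check_pos_embed(pos_embed: str) -> None:
    # Survives `python -O`, unlike an assert.
    if pos_embed not in ('', 'none', 'learn'):
        raise ValueError(f"Invalid pos_embed: {pos_embed}")

check_pos_embed('learn')     # passes silently
# check_pos_embed('sincos')  # would raise ValueError
```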
@@ -555,7 +545,7 @@ def __init__(
 
         self.num_classes = num_classes
         self.global_pool = global_pool
-        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
+        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
         self.num_prefix_tokens = 1 if class_token else 0
         self.num_prefix_tokens += reg_tokens
         self.num_reg_tokens = reg_tokens
@@ -565,7 +555,7 @@
         self.patch_drop = nn.Identity()
         self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
 
-        # stochastic depth decay rule
+        # Stochastic depth decay rule.
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
         self.blocks = nn.ModuleList([
             Block(
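
The reworded comment describes timm's usual rule: each of the `depth` blocks gets a linearly increasing drop-path rate, so deeper blocks are skipped more often during training. Concretely:

```python
import torch

drop_path_rate, depth = 0.1, 5
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# dpr == [0.0, 0.025, 0.05, 0.075, 0.1]; block i uses rate dpr[i]
```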
@@ -590,7 +580,7 @@ def __init__(
         self.norm = norm_layer(
             embed_dim) if final_norm and not use_fc_norm else nn.Identity()
 
-        # Classifier Head but not used for RADIO embedding models.
+        # Initialize the classifier head; it is not used for RADIO embedding models.
         self.attn_pool = None
         self.fc_norm = norm_layer(
             embed_dim) if final_norm and use_fc_norm else nn.Identity()
@@ -664,9 +654,8 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         if self.model_config is not None:
             seq_lengths = [seq_len] * batch_size
             attn_metadata = self.prepare_attn_metadata(batch_size, seq_lengths)
-            x = x.reshape(
-                batch_size * seq_len,
-                hidden_size)  # Need flatten batch/seq_len for trtllm attention.
+            # Need to flatten batch/seq_len for trtllm attention.
+            x = x.reshape(batch_size * seq_len, hidden_size)
         else:
             attn_metadata = None
         for block in self.blocks:
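
Moving the comment above the call makes the shape contract easier to see: trtllm attention consumes packed tokens, so `[batch, seq, hidden]` activations are flattened and the per-sequence lengths are carried in the metadata. A shape-only sketch (sizes assumed):

```python
import torch

batch_size, seq_len, hidden_size = 2, 196, 1280
x = torch.randn(batch_size, seq_len, hidden_size)
seq_lengths = [seq_len] * batch_size              # consumed via attn_metadata
x = x.reshape(batch_size * seq_len, hidden_size)  # [392, 1280] packed tokens
```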
@@ -678,7 +667,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class RADIOVisionModelBase(nn.Module):
-    """Copy and modify from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/radio_model.py"""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/radio_model.py."""
 
     def __init__(
         self,
@@ -783,17 +772,13 @@ def get_nearest_supported_resolution(self, height: int,
             round(height / self.min_resolution_step) * self.min_resolution_step)
         width = int(
             round(width / self.min_resolution_step) * self.min_resolution_step)
-
         height = max(height, self.min_resolution_step)
         width = max(width, self.min_resolution_step)
-
         return Resolution(height=height, width=width)
 
-    def forward(
-        self,
-        x: torch.Tensor,
-        feature_fmt: str = 'NLC'
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    def forward(self,
+                x: torch.Tensor,
+                feature_fmt: str = 'NLC') -> torch.Tensor:
         res_step = self.min_resolution_step
         if res_step is not None and (x.shape[-2] % res_step != 0
                                      or x.shape[-1] % res_step != 0):
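
The blank-line trims leave the rounding logic intact: each side is rounded to the nearest multiple of `min_resolution_step`, floored at one step. Standalone, with an assumed step of 16:

```python
MIN_RESOLUTION_STEP = 16  # assumed value for illustration

def nearest_supported(side: int) -> int:
    side = int(round(side / MIN_RESOLUTION_STEP) * MIN_RESOLUTION_STEP)
    return max(side, MIN_RESOLUTION_STEP)

assert nearest_supported(250) == 256  # rounds to the nearest multiple
assert nearest_supported(7) == 16     # floored at one step
```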
@@ -807,7 +792,6 @@ def forward(
         ret = self._extract_final(x, y, feature_fmt=feature_fmt)
         return ret
 
-    @torch.compile
     def _extract_final(self,
                        x: torch.Tensor,
                        y: torch.Tensor,
@@ -836,12 +820,11 @@ def _extract_final(self,
             raise ValueError(
                 f'Unsupported feature_fmt: {feature_fmt}. Must be one of ["NLC", "NCHW"]'
             )
-
         return fmt_feat
 
 
 class RADIOVisionModel(PreTrainedModel):
-    """Copy and modify from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/hf_model.py."""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/hf_model.py."""
 
     def __init__(self, model_config: model_config_lib.ModelConfig):
         """
@@ -863,11 +846,11 @@ def __init__(self, model_config: model_config_lib.ModelConfig):
         elif args.input_size is not None:
             in_chans = args.input_size[0]
         vit_model = VisionTransformer(
-            img_size=224,
-            patch_size=16,
-            embed_dim=1280,
-            depth=32,
-            num_heads=NUM_ATTENTION_HEADS,
+            img_size=IMAGE_SIZE_FOR_VIT,
+            patch_size=PATCH_SIZE_FOR_VIT,
+            embed_dim=EMBED_DIM_FOR_VIT,
+            depth=DEPTH_FOR_VIT,
+            num_heads=NUM_ATTENTION_HEADS_FOR_VIT,
             in_chans=in_chans,
             drop_rate=args.drop,
             special_args=args,
@@ -920,11 +903,11 @@ def load_weights(self, weights):
         }
         missing_keys, unexpected_keys = self.radio_model.load_state_dict(
             filter_weights, strict=False)
-
         # Check missing and unexpected keys.
         # The input conditioner is not initialized in current implementation.
         unexpected_keys.remove("input_conditioner.norm_mean")
         unexpected_keys.remove("input_conditioner.norm_std")
+        # Partial model.blocks weights will be loaded in a following step.
         for m in missing_keys:
             if not m.startswith('model.blocks.'):
                 raise ValueError(f"Missing key: {m}")
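
The `load_weights` pattern is worth naming: load with `strict=False`, then audit the reported keys so only known-benign mismatches pass and anything else still fails loudly. A generic sketch of the same idea (the prefix and conditioner keys mirror the code above; the `load_checked` helper itself is hypothetical):

```python
import torch.nn as nn

def load_checked(model: nn.Module, state_dict: dict) -> None:
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    # Keys known to be absent on purpose (the input conditioner is not built).
    for known in ("input_conditioner.norm_mean", "input_conditioner.norm_std"):
        if known in unexpected:
            unexpected.remove(known)
    # Remaining model.blocks.* weights are loaded in a later step.
    for key in missing:
        if not key.startswith("model.blocks."):
            raise ValueError(f"Missing key: {key}")
    if unexpected:
        raise ValueError(f"Unexpected keys: {unexpected}")
```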
