
Commit 1e3ec79

[TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend
* Clean up code.

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent ecaa7dd commit 1e3ec79

File tree

2 files changed: +114 -683 lines


tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Lines changed: 19 additions & 35 deletions
@@ -21,6 +21,7 @@
 from ..model_config import ModelConfig
 from .modeling_auto import AutoModelForCausalLM
 from .modeling_multimodal_utils import fuse_input_embeds
+from .modeling_radio import RADIOVisionModel
 from .modeling_utils import register_auto_model
 
 
@@ -66,19 +67,12 @@ def __init__(self,
         self.num_image_token = int((self.image_size // self.patch_size)**2 *
                                    (config.downsample_ratio**2))
         self.downsample_ratio = config.downsample_ratio
-        self.ps_version = config.ps_version
-        # self.image_tag_type = config.image_tag_type
-
-        logger.info(f'num_image_token: {self.num_image_token}')
-        logger.info(f'ps_version: {self.ps_version}')
-
-        # self.drop_vision_class_token = True
+        self.ps_version = config.ps_version  # Pixel shuffle version.
 
         # Construct the vision projection.
         self.vit_hidden_size = config.vit_hidden_size
         self.vision_projection_hidden_size = config.projector_hidden_size
         self.llm_hidden_size = config.llm_config.hidden_size
-
         self.mlp1 = nn.Sequential(
             RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                     eps=1e-5),
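
The num_image_token formula above sets how many LLM-side tokens each image occupies. A quick worked example with illustrative values (not taken from this commit; the real numbers come from the model config):

    # Hypothetical values for illustration only.
    image_size, patch_size, downsample_ratio = 512, 16, 0.5

    # (512 // 16)**2 = 1024 ViT patch tokens; downsample_ratio**2 = 0.25
    # because pixel shuffle folds each 2x2 patch group into one token.
    num_image_token = int((image_size // patch_size)**2 * downsample_ratio**2)
    assert num_image_token == 256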
@@ -90,37 +84,19 @@ def __init__(self,
                       bias=False))
         self.mlp1 = self.mlp1.to(config.torch_dtype)
 
-        WITH_HF_CODES = False
-        if WITH_HF_CODES:
+        # Construct the vision encoder.
+        self.with_hf_codes = os.getenv("WITH_HF_CODES", "0") == "1"
+        if self.with_hf_codes:
             self.vision_model = transformers.AutoModel.from_config(
                 config.vision_config, trust_remote_code=True)
             # set input_condition as Identity module.
             self.vision_model.radio_model.make_preprocessor_external()
             self.vision_model.to(config.torch_dtype)
-
-            with open("hf_vision_encoder_arch.txt", "w") as f:
-                f.write(str(self.vision_model))
         else:
-            WITH_TRTLLM_CODES = True
-            if WITH_TRTLLM_CODES:
-                from .modeling_radio import RADIOVisionModel
-
-                vision_model_config = copy.deepcopy(model_config)
-                vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
-
-                self.vision_model = RADIOVisionModel(vision_model_config)
-                self.vision_model.to(config.torch_dtype)
-
-                with open("trtllm_vision_encoder_arch.txt", "w") as f:
-                    f.write(str(self.vision_model))
-            else:
-                # Update the vision model with customized one.
-                from .modeling_radio import RADIOModel
-                self.vision_model = RADIOModel(config.vision_config)
-                self.vision_model.to(config.torch_dtype)
-
-                with open("user_vision_encoder_arch.txt", "w") as f:
-                    f.write(str(self.vision_model))
+            vision_model_config = copy.deepcopy(model_config)
+            vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
+            self.vision_model = RADIOVisionModel(vision_model_config)
+            self.vision_model.to(config.torch_dtype)
 
     @torch.compile
     def pixel_shuffle(self, x, scale_factor=0.5):
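
With this hunk, the encoder is selected once at construction time from the WITH_HF_CODES environment variable: unset or "0" (the default) builds the TRT-LLM RADIOVisionModel, while "1" keeps the HuggingFace AutoModel path for debugging or parity checks. A minimal usage sketch, assuming the variable is set before the model is constructed:

    import os

    # Opt in to the HuggingFace vision encoder; any other value (or no
    # value) falls through to the TRT-LLM RADIOVisionModel branch.
    os.environ["WITH_HF_CODES"] = "1"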
@@ -141,8 +117,12 @@ def pixel_shuffle(self, x, scale_factor=0.5):
         return x
 
     def extract_feature(self, pixel_values):
-        vit_embeds = self.vision_model(pixel_values).features
+        if self.with_hf_codes:
+            vit_embeds = self.vision_model(pixel_values).features
+        else:
+            vit_embeds = self.vision_model(pixel_values)
         vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
+        # Down-sampling and projection.
         h = w = int(vit_embeds.shape[1]**0.5)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
         vit_embeds = self.pixel_shuffle(vit_embeds,
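
extract_feature reshapes the (batch, tokens, hidden) encoder output into a square grid before handing it to pixel_shuffle, which trades spatial resolution for channel width. The body of pixel_shuffle is not part of this diff; the sketch below is the common InternVL-style formulation and only an assumption about what it computes:

    import torch

    def pixel_shuffle_sketch(x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
        # x: (batch, h, w, c). With scale_factor=0.5 every 2x2 patch group
        # is folded into channels: (batch, h/2, w/2, 4c), i.e. 4x fewer tokens.
        b, h, w, c = x.shape
        x = x.reshape(b, h, int(w * scale_factor), int(c / scale_factor))
        x = x.permute(0, 2, 1, 3).contiguous()  # (b, w/2, h, 2c)
        x = x.reshape(b, int(w * scale_factor), int(h * scale_factor),
                      int(c / scale_factor**2))  # (b, w/2, h/2, 4c)
        return x.permute(0, 2, 1, 3).contiguous()  # (b, h/2, w/2, 4c)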
@@ -317,7 +297,11 @@ def load_weights(self, weights):
         }
         missing_keys, unexpected_keys = self.vision_encoder.load_state_dict(
             filter_weights, strict=False)
-        missing_keys.remove("vision_model.radio_model.summary_idxs")
+        try:
+            missing_keys.remove("vision_model.radio_model.summary_idxs")
+        except ValueError:
+            pass
+
         unexpected_keys.remove(
             "vision_model.radio_model.input_conditioner.norm_mean")
         unexpected_keys.remove(
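
The new try/except exists because load_state_dict returns plain Python lists of key names, and list.remove raises ValueError when the element is absent; wrapping the call makes the summary_idxs cleanup a no-op for checkpoints that do provide that buffer. A minimal illustration of the behavior:

    missing_keys = ["vision_model.radio_model.summary_idxs"]
    missing_keys.remove("vision_model.radio_model.summary_idxs")  # ok
    try:
        missing_keys.remove("vision_model.radio_model.summary_idxs")
    except ValueError:
        pass  # Already gone (or never missing): nothing to clean up.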
