
Commit 88594b4

[TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend
* clean up code. Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent: ecaa7dd · commit: 88594b4

File tree

2 files changed: +122 / -686 lines changed


tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Lines changed: 27 additions & 38 deletions
@@ -21,6 +21,7 @@
 from ..model_config import ModelConfig
 from .modeling_auto import AutoModelForCausalLM
 from .modeling_multimodal_utils import fuse_input_embeds
+from .modeling_radio import RADIOVisionModel
 from .modeling_utils import register_auto_model
 
 
@@ -66,19 +67,12 @@ def __init__(self,
         self.num_image_token = int((self.image_size // self.patch_size)**2 *
                                    (config.downsample_ratio**2))
         self.downsample_ratio = config.downsample_ratio
-        self.ps_version = config.ps_version
-        # self.image_tag_type = config.image_tag_type
-
-        logger.info(f'num_image_token: {self.num_image_token}')
-        logger.info(f'ps_version: {self.ps_version}')
-
-        # self.drop_vision_class_token = True
+        self.ps_version = config.ps_version  # Pixel shuffle version.
 
         # Construct the vision projection.
         self.vit_hidden_size = config.vit_hidden_size
         self.vision_projection_hidden_size = config.projector_hidden_size
         self.llm_hidden_size = config.llm_config.hidden_size
-
         self.mlp1 = nn.Sequential(
             RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                     eps=1e-5),
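
As a quick check of the num_image_token arithmetic kept in the hunk above, with illustrative values (512/16/0.5 are assumptions, not the model's actual config):

    # Illustrative values only; the real image_size, patch_size and
    # downsample_ratio come from the pretrained config.
    image_size, patch_size, downsample_ratio = 512, 16, 0.5

    # (512 // 16)**2 * 0.5**2 = 32**2 * 0.25 = 256 tokens per image patch.
    num_image_token = int((image_size // patch_size)**2 * (downsample_ratio**2))
    assert num_image_token == 256
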
@@ -90,37 +84,19 @@ def __init__(self,
                       bias=False))
         self.mlp1 = self.mlp1.to(config.torch_dtype)
 
-        WITH_HF_CODES = False
-        if WITH_HF_CODES:
+        # Construct the vision encoder.
+        self.with_hf_codes = os.getenv("WITH_HF_CODES", "0") == "1"
+        if self.with_hf_codes:
             self.vision_model = transformers.AutoModel.from_config(
                 config.vision_config, trust_remote_code=True)
             # set input_condition as Identity module.
             self.vision_model.radio_model.make_preprocessor_external()
             self.vision_model.to(config.torch_dtype)
-
-            with open("hf_vision_encoder_arch.txt", "w") as f:
-                f.write(str(self.vision_model))
         else:
-            WITH_TRTLLM_CODES = True
-            if WITH_TRTLLM_CODES:
-                from .modeling_radio import RADIOVisionModel
-
-                vision_model_config = copy.deepcopy(model_config)
-                vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
-
-                self.vision_model = RADIOVisionModel(vision_model_config)
-                self.vision_model.to(config.torch_dtype)
-
-                with open("trtllm_vision_encoder_arch.txt", "w") as f:
-                    f.write(str(self.vision_model))
-            else:
-                # Update the vision model with customized one.
-                from .modeling_radio import RADIOModel
-                self.vision_model = RADIOModel(config.vision_config)
-                self.vision_model.to(config.torch_dtype)
-
-                with open("user_vision_encoder_arch.txt", "w") as f:
-                    f.write(str(self.vision_model))
+            vision_model_config = copy.deepcopy(model_config)
+            vision_model_config.pretrained_config = vision_model_config.pretrained_config.vision_config
+            self.vision_model = RADIOVisionModel(vision_model_config)
+            self.vision_model.to(config.torch_dtype)
 
     @torch.compile
     def pixel_shuffle(self, x, scale_factor=0.5):
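
The HuggingFace encoder path is now opt-in at runtime instead of an edit-the-source flag. A minimal sketch of exercising the toggle, independent of the module above (nothing here is part of the commit):

    import os

    # Select the HuggingFace AutoModel vision encoder, e.g. for parity checks;
    # leave WITH_HF_CODES unset (or "0") to use the TRT-LLM RADIOVisionModel.
    os.environ["WITH_HF_CODES"] = "1"

    # The constructor reads the flag once and caches it as self.with_hf_codes,
    # so extract_feature can branch later without re-reading the environment.
    with_hf_codes = os.getenv("WITH_HF_CODES", "0") == "1"
    assert with_hf_codes
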
@@ -141,8 +117,12 @@ def pixel_shuffle(self, x, scale_factor=0.5):
         return x
 
     def extract_feature(self, pixel_values):
-        vit_embeds = self.vision_model(pixel_values).features
+        if self.with_hf_codes:
+            vit_embeds = self.vision_model(pixel_values).features
+        else:
+            vit_embeds = self.vision_model(pixel_values)
         vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
+        # Down-sampling and projection.
         h = w = int(vit_embeds.shape[1]**0.5)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
         vit_embeds = self.pixel_shuffle(vit_embeds,
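
For reference, a shape-level sketch of what extract_feature does after the encoder, using torch.nn.functional.pixel_unshuffle as a stand-in for the model's pixel_shuffle (equivalent up to channel ordering; all sizes are made up):

    import torch
    import torch.nn.functional as F

    # Illustrative: 2 image patches, 32 x 32 = 1024 ViT tokens, hidden size 1280.
    vit_embeds = torch.randn(2, 1024, 1280)

    h = w = int(vit_embeds.shape[1]**0.5)                     # 32
    grid = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)  # [2, 32, 32, 1280]

    # scale_factor=0.5 halves each spatial dimension and widens channels 4x.
    down = F.pixel_unshuffle(grid.permute(0, 3, 1, 2), downscale_factor=2)
    down = down.permute(0, 2, 3, 1)                           # [2, 16, 16, 5120]

    # 1024 tokens shrink to 256 = (image_size // patch_size)**2 * downsample_ratio**2.
    tokens = down.reshape(down.shape[0], -1, down.shape[-1])  # [2, 256, 5120]
    assert tokens.shape == (2, 256, 5120)
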
@@ -163,12 +143,15 @@ def forward(self, multimodal_params: List[MultimodalParams]):
                 for multimodal_param in multimodal_params
             ],
             dim=0)
-        batched_num_patches = [
+        # -> [num_patches, channel, height, width]
+        batched_num_patches = torch.cat([
             multimodal_param.multimodal_data["num_patches"]
             for multimodal_param in multimodal_params
-        ]
-        # -> [num_patches, num_image_token, hidden_size]
+        ],
+                                        dim=0).tolist()
+        # -> list of [num_patches1, num_patches2, ...]
         batched_image_embeds = self.extract_feature(batched_pixel_values)
+        # -> [num_patches, num_image_token, hidden_size]
         mm_embedding = torch.split(batched_image_embeds,
                                    batched_num_patches,
                                    dim=0)
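
The forward pass concatenates the per-request num_patches tensors and converts them to a Python list because torch.split needs plain ints for uneven chunk sizes. A standalone illustration with made-up sizes:

    import torch

    # Three requests contributing 2, 1 and 3 image patches respectively.
    num_patches = [torch.tensor([2]), torch.tensor([1]), torch.tensor([3])]
    batched_num_patches = torch.cat(num_patches, dim=0).tolist()  # [2, 1, 3]

    num_image_token, hidden_size = 256, 4096  # illustrative
    batched_image_embeds = torch.randn(sum(batched_num_patches),
                                       num_image_token, hidden_size)

    # torch.split with a list of sizes yields one chunk per request along dim 0.
    mm_embedding = torch.split(batched_image_embeds, batched_num_patches, dim=0)
    assert [e.shape[0] for e in mm_embedding] == [2, 1, 3]
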
@@ -309,6 +292,8 @@ def __init__(self, model_config: ModelConfig):
         self.is_loaded = True
 
     def load_weights(self, weights):
+        # TODO: move vision encoder weights loading to vision encoder class.
+
         # Load vision encoder weights for pytorch modules.
         filter_weights = {
             k: v
@@ -317,7 +302,11 @@ def load_weights(self, weights):
         }
         missing_keys, unexpected_keys = self.vision_encoder.load_state_dict(
             filter_weights, strict=False)
-        missing_keys.remove("vision_model.radio_model.summary_idxs")
+        try:
+            missing_keys.remove("vision_model.radio_model.summary_idxs")
+        except ValueError:
+            pass
+
         unexpected_keys.remove(
             "vision_model.radio_model.input_conditioner.norm_mean")
         unexpected_keys.remove(
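
A self-contained sketch of the strict=False loading pattern this hunk relies on; the module and keys below are made up, only the try/except shape mirrors the change:

    import torch
    import torch.nn as nn

    module = nn.Linear(4, 4)
    state = {"weight": torch.zeros(4, 4), "extra_buffer": torch.zeros(1)}

    # strict=False reports mismatched keys instead of raising, so the caller can
    # inspect and prune the ones it knows are benign.
    missing_keys, unexpected_keys = module.load_state_dict(state, strict=False)
    # missing_keys == ['bias'], unexpected_keys == ['extra_buffer']

    # list.remove raises ValueError when the entry is absent, hence the guard
    # around the optional summary_idxs key above.
    try:
        missing_keys.remove("bias")
    except ValueError:
        pass
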
