18 | 18 | import torch.nn as nn |
19 | 19 | import torchvision.transforms as T |
20 | 20 | from PIL import Image |
21 | | -from transformers import (AutoModel, BatchEncoding, BatchFeature, |
22 | | - PretrainedConfig, TensorType) |
| 21 | +from transformers import (BatchEncoding, BatchFeature, PretrainedConfig, |
| 22 | + TensorType) |
23 | 23 |
24 | 24 | from vllm.config import VllmConfig |
25 | 25 | from vllm.model_executor.layers.activation import ReLUSquaredActivation |
32 | 32 | get_internvl_target_ratios) |
33 | 33 | from vllm.model_executor.models.module_mapping import MultiModelKeys |
34 | 34 | from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM |
| 35 | +from vllm.model_executor.models.radio import RadioModel |
35 | 36 | from vllm.model_executor.models.utils import (flatten_bn, |
36 | 37 | init_vllm_registered_model, |
37 | 38 | maybe_prefix, |
48 | 49 | PromptUpdate, PromptUpdateDetails) |
49 | 50 | from vllm.multimodal.profiling import BaseDummyInputsBuilder |
50 | 51 | from vllm.sequence import IntermediateTensors |
| 52 | +from vllm.transformers_utils.configs.radio import RadioConfig |
51 | 53 | from vllm.transformers_utils.tokenizer import AnyTokenizer |
52 | 54 | from vllm.utils.tensor_schema import TensorSchema, TensorShape |
53 | 55 |
@@ -122,11 +124,6 @@ class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): |
122 | 124 | NanoNemotronVLVideoEmbeddingInputs] |
123 | 125 |
124 | 126 |
125 | | -def input_conditioner(x, norm_mean, norm_std): |
126 | | - y = (x - norm_mean) / norm_std |
127 | | - return y |
128 | | - |
129 | | - |
130 | 127 | def dynamic_preprocess(image, |
131 | 128 | *, |
132 | 129 | image_size=512, |
@@ -305,8 +302,7 @@ def _preprocess_image( |
305 | 302 | images, max_num_tiles) |
306 | 303 | image_inputs: dict[str, NestedTensors] = { |
307 | 304 | "pixel_values_flat": |
308 | | - input_conditioner(torch.cat(pixel_values_lst), self.norm_mean, |
309 | | - self.norm_std), |
| 305 | + torch.cat(pixel_values_lst), |
310 | 306 | "image_num_patches": |
311 | 307 | torch.tensor([len(item) for item in pixel_values_lst]), |
312 | 308 | } |
@@ -428,8 +424,7 @@ def _preprocess_video( |
428 | 424 |
429 | 425 | video_inputs: dict[str, NestedTensors] = { |
430 | 426 | "pixel_values_flat_video": |
431 | | - input_conditioner(torch.cat(pixel_values_lst_video), |
432 | | - self.norm_mean, self.norm_std), |
| 427 | + torch.cat(pixel_values_lst_video), |
433 | 428 | "video_num_patches": |
434 | 429 | torch.tensor([len(item) for item in pixel_values_lst_video]), |
435 | 430 | } |
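Note: with the two preprocessing hunks above, tiles are concatenated raw, and the (x - mean) / std standardization that input_conditioner used to apply in the processor is assumed to happen inside the Radio vision model, which receives norm_mean and norm_std through RadioConfig later in this diff. A minimal standalone sketch of the step that moved (the ImageNet-style statistics here are placeholders for the config's norm_mean / norm_std, and 512 matches the default tile size above):

    import torch

    norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)   # placeholder stats
    norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)    # placeholder stats
    tiles = torch.rand(4, 3, 512, 512)            # e.g. 4 raw tiles from dynamic_preprocess
    conditioned = (tiles - norm_mean) / norm_std  # previously done in the processor in fp32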
@@ -905,18 +900,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
905 | 900 | hf_config=config.text_config, |
906 | 901 | prefix=maybe_prefix(prefix, "language_model"), |
907 | 902 | ) |
908 | | - self.vision_model = AutoModel.from_config(config.vision_config, |
909 | | - trust_remote_code=True) |
910 | | - self.vision_model.model._initialize_weights = ( |
911 | | - self.vision_model.model._init_weights) |
912 | | - # Move input normalization to processor to mirror original HF |
913 | | - # implementation where normalization is done in fp32 |
914 | | - self.vision_model.radio_model.make_preprocessor_external() |
915 | | - self.vision_model = self.vision_model.to( |
| 903 | + self.vision_model = self.get_vit_model_from_radio_config(config).to( |
916 | 904 | self.language_model.config.torch_dtype) |
917 | 905 |
918 | | - self.drop_vision_class_token = True |
919 | | - |
920 | 906 | # Construct the vision projection. |
921 | 907 | vit_hidden_size = config.vit_hidden_size |
922 | 908 | vision_projection_hidden_size = config.projector_hidden_size |
@@ -972,7 +958,7 @@ def pixel_shuffle(self, x, scale_factor=0.5): |
972 | 958 | return x |
973 | 959 |
|
974 | 960 | def extract_feature(self, pixel_values): |
975 | | - vit_embeds = self.vision_model(pixel_values).features |
| 961 | + vit_embeds = self.vision_model(pixel_values) |
976 | 962 | vit_embeds = vit_embeds.to(dtype=torch.bfloat16) |
977 | 963 | h = w = int(vit_embeds.shape[1]**0.5) |
978 | 964 | vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) |
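For orientation, the shape arithmetic in extract_feature, assuming 512-pixel tiles with a 16-pixel patch (the defaults visible elsewhere in this diff) and the usual InternVL-style pixel shuffle; the numbers are illustrative, not read from the checkpoint:

    num_patches = (512 // 16) ** 2     # 1024 patch tokens per tile from self.vision_model
    h = w = int(num_patches ** 0.5)    # 32 x 32 spatial grid
    # reshape: (num_tiles, 1024, C) -> (num_tiles, 32, 32, C)
    # pixel_shuffle(scale_factor=0.5) then trades space for channels:
    # (num_tiles, 32, 32, C) -> (num_tiles, 16, 16, 4 * C), i.e. 256 visual tokens per tile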
@@ -1212,47 +1198,39 @@ def compute_logits( |
1212 | 1198 | sampling_metadata) |
1213 | 1199 |
1214 | 1200 | def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): |
| 1201 | + adapter_dict = dict(self.mlp1.named_parameters()) |
1215 | 1202 |
1216 | | - def is_vision_model_weights(weight: tuple[str, torch.Tensor]): |
1217 | | - return weight[0].startswith("vision_model") |
| 1203 | + def is_llm(name: str) -> bool: |
| 1204 | + return name.startswith("language_model") |
1218 | 1205 |
1219 | 1206 | def is_adapter_weights(weight: tuple[str, torch.Tensor]): |
1220 | 1207 | return weight[0].startswith("mlp1") |
1221 | 1208 |
1222 | | - # Get references to parameters for direct loading |
1223 | | - vision_model_dict = dict(self.vision_model.named_parameters()) |
1224 | | - vision_model_buffers = dict(self.vision_model.named_buffers()) |
1225 | | - adapter_dict = dict(self.mlp1.named_parameters()) |
1226 | | - |
1227 | | - def llm_weights_generator(): |
1228 | | - # Single pass over weights |
1229 | | - for name, w in weights: |
1230 | | - if is_vision_model_weights((name, w)): |
1231 | | - # Load vision encoder weights directly |
1232 | | - trimmed_name = ".".join(name.split(".")[1:]) |
1233 | | - if "input_conditioner" in trimmed_name: |
1234 | | - continue |
1235 | | - if trimmed_name in vision_model_buffers: |
1236 | | - param = vision_model_buffers[trimmed_name] |
1237 | | - else: |
1238 | | - param = vision_model_dict[trimmed_name] |
1239 | | - with torch.no_grad(): |
1240 | | - default_weight_loader(param, w) |
1241 | | - elif is_adapter_weights((name, w)): |
1242 | | - # Load vision-language adapter weights directly |
1243 | | - trimmed_name = ".".join(name.split(".")[1:]) |
1244 | | - param = adapter_dict[trimmed_name] |
1245 | | - with torch.no_grad(): |
1246 | | - default_weight_loader(param, w) |
1247 | | - else: |
1248 | | - # LLM weights: yield them to be loaded |
1249 | | - # by language_model.load_weights |
1250 | | - assert name.startswith("language_model") |
1251 | | - trimmed_name = ".".join(name.split(".")[1:]) |
1252 | | - yield (trimmed_name, w) |
1253 | | - |
1254 | | - # Now we call the language model load with the generator |
1255 | | - self.language_model.load_weights(llm_weights_generator()) |
| 1209 | + def is_vision_weights(name: str) -> bool: |
| 1210 | + return name.startswith("vision_model.radio_model.") |
| 1211 | + |
| 1212 | + # Separate weights by component |
| 1213 | + llm_weights = [] |
| 1214 | + vision_weights = [] |
| 1215 | + |
| 1216 | + for name, w in weights: |
| 1217 | + if is_llm(name): |
| 1218 | + # Strip 'language_model.' prefix for LLM weights |
| 1219 | + llm_weights.append((".".join(name.split(".")[1:]), w)) |
| 1220 | + elif is_adapter_weights((name, w)): |
| 1221 | + # Load vision-language adapter weights directly |
| 1222 | + trimmed_name = ".".join(name.split(".")[1:]) |
| 1223 | + param = adapter_dict[trimmed_name] |
| 1224 | + with torch.no_grad(): |
| 1225 | + default_weight_loader(param, w) |
| 1226 | + elif is_vision_weights(name): |
| 1227 | + # Convert: vision_model.radio_model.* → radio_model.* |
| 1228 | + hf_key = name[len( |
| 1229 | + "vision_model."):] # Remove "vision_model." prefix |
| 1230 | + vision_weights.append((hf_key, w)) |
| 1231 | + |
| 1232 | + self.language_model.load_weights(llm_weights) |
| 1233 | + self.vision_model.load_weights(vision_weights) |
1256 | 1234 |
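To make the new routing concrete, here is how a few checkpoint keys would flow through load_weights; the key names are hypothetical examples, not taken from a real checkpoint:

    # "language_model.backbone.embed_tokens.weight"
    #     -> prefix stripped, ("backbone.embed_tokens.weight", w) handed to
    #        self.language_model.load_weights
    # "mlp1.0.weight"
    #     -> looked up in adapter_dict as "0.weight" and copied in place via
    #        default_weight_loader
    # "vision_model.radio_model.model.blocks.0.attn.qkv.weight"
    #     -> "vision_model." removed, ("radio_model.model.blocks.0.attn.qkv.weight", w)
    #        handed to self.vision_model.load_weights
    # Names matching none of the three predicates (for example the input_conditioner
    # statistics the previous loader skipped explicitly) are silently dropped.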
1257 | 1235 | def print_architecture(self, |
1258 | 1236 | detailed: bool = True, |
@@ -1370,6 +1348,30 @@ def get_model_info(self): |
1370 | 1348 | }, |
1371 | 1349 | } |
1372 | 1350 |
| 1351 | + def get_vit_model_from_radio_config(self, hf_config): |
| 1352 | + hf_config_vision = hf_config.vision_config |
| 1353 | + model_name = hf_config_vision.args.get("model") |
| 1354 | + if model_name is None: |
| 1355 | + raise ValueError(f'Unsupported vit model type: {model_name}') |
| 1356 | + |
| 1357 | + preferred_resolution = getattr(hf_config_vision, |
| 1358 | + "preferred_resolution", None) |
| 1359 | + image_size = preferred_resolution[0] if preferred_resolution else 224 |
| 1360 | + patch_size = getattr(hf_config_vision, "patch_size", 16) |
| 1361 | + |
| 1362 | + radio_config = RadioConfig( |
| 1363 | + model_name=model_name, |
| 1364 | + image_size=image_size, |
| 1365 | + patch_size=patch_size, |
| 1366 | + norm_mean=hf_config.norm_mean, |
| 1367 | + norm_std=hf_config.norm_std, |
| 1368 | + reg_tokens=(hf_config_vision.args.get("register_multiple") |
| 1369 | + if hasattr(hf_config_vision, "args") |
| 1370 | + and isinstance(hf_config_vision.args, dict) else None), |
| 1371 | + ) |
| 1372 | + |
| 1373 | + return RadioModel(config=radio_config) |
| 1374 | + |
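As a usage sketch of the helper above, with made-up values standing in for the HF vision config (the keyword names follow the code; the model name, resolution, and statistics are purely illustrative):

    radio_config = RadioConfig(
        model_name="vit_huge_patch16_224",    # hypothetical args["model"] value
        image_size=512,                       # preferred_resolution[0]
        patch_size=16,
        norm_mean=(0.485, 0.456, 0.406),      # placeholder for hf_config.norm_mean
        norm_std=(0.229, 0.224, 0.225),       # placeholder for hf_config.norm_std
        reg_tokens=8,                         # hypothetical args["register_multiple"]
    )
    vision_model = RadioModel(config=radio_config)   # __init__ then casts it to the LLM dtype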
1373 | 1375 | def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): |
1374 | 1376 | return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs( |
1375 | 1377 | input_buffers, **kwargs) |