Commit f647f73

danielafrimi authored and root committed

Add RADIO Vision Encoder Support to vLLM (vllm-project#24595)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: root <root@cw-dfw-h100-001-305-026.cm.cluster>

1 parent 2e4ccd1 commit f647f73

5 files changed: +828 -58 lines changed

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+import torch.nn as nn
+from huggingface_hub import snapshot_download
+from transformers import AutoConfig, AutoModel, CLIPImageProcessor
+
+from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.model_executor.models.radio import RadioModel
+from vllm.transformers_utils.configs.radio import RadioConfig
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+
+from ....conftest import ImageTestAssets
+
+# we use snapshot_download to prevent conflicts between
+# dynamic_module and trust_remote_code for hf_runner
+DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
+
+
+@torch.inference_mode()
+def run_radio_test(
+    image_assets: ImageTestAssets,
+    model_id: str,
+    *,
+    dtype: str,
+):
+    model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
+    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
+
+    img_processor = CLIPImageProcessor.from_pretrained(model)
+    images = [asset.pil_image for asset in image_assets]
+    # Input resolution must be a multiple of `self.min_resolution_step`.
+    # Using `self.get_nearest_supported_resolution`, for assets 432x642 the
+    # nearest supported resolution is 432x640.
+    pixel_values = [
+        img_processor(
+            image,
+            return_tensors='pt').pixel_values.to(torch_dtype)[:, :, :, :640]
+        for image in images
+    ]
+
+    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+
+    hf_model = AutoModel.from_pretrained(
+        model_id,
+        config=config,
+        torch_dtype=torch_dtype,
+        trust_remote_code=True,
+    ).to("cuda")
+    hf_model.eval()
+
+    hf_outputs_per_image = [
+        hf_model(pixel_value.to("cuda")).features
+        for pixel_value in pixel_values
+    ]
+
+    radio_config = RadioConfig(model_name=config.args["model"],
+                               reg_tokens=config.args["register_multiple"])
+    vllm_model = RadioModel(radio_config)
+    vllm_model.load_weights(hf_model.state_dict())
+    vllm_model = vllm_model.to("cuda", torch_dtype)
+
+    vllm_outputs_per_image = [
+        vllm_model(pixel_values=pixel_value.to("cuda"))
+        for pixel_value in pixel_values
+    ]
+    del vllm_model, hf_model
+    cleanup_dist_env_and_memory()
+
+    cos_similar = nn.CosineSimilarity(dim=-1)
+    for vllm_output, hf_output in zip(vllm_outputs_per_image,
+                                      hf_outputs_per_image):
+        assert cos_similar(vllm_output, hf_output).mean() > 0.99
+
+
+@pytest.mark.parametrize("model_id", [
+    "nvidia/C-RADIOv2-H",
+])
+@pytest.mark.parametrize("dtype", ["half"])
+def test_radio(dist_init, image_assets, model_id, dtype: str) -> None:
+    run_radio_test(
+        image_assets,
+        model_id,
+        dtype=dtype,
+    )
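For orientation, here is a compact standalone sketch of driving the new encoder the same way this test does. The checkpoint id and the RadioConfig arguments are copied from the test above; the dummy 432x640 input follows the resolution note in the test, and nothing here should be read as a documented API.

# Sketch only (not part of this PR): mirrors the comparison test above.
# Assumes a CUDA device is available.
import torch
from transformers import AutoConfig, AutoModel
from vllm.model_executor.models.radio import RadioModel
from vllm.transformers_utils.configs.radio import RadioConfig

model_id = "nvidia/C-RADIOv2-H"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
hf_model = AutoModel.from_pretrained(model_id, config=config,
                                     torch_dtype=torch.float16,
                                     trust_remote_code=True)

# Build the vLLM-side encoder from the HF config and load the HF weights,
# exactly as the test does.
radio_config = RadioConfig(model_name=config.args["model"],
                           reg_tokens=config.args["register_multiple"])
vllm_model = RadioModel(radio_config)
vllm_model.load_weights(hf_model.state_dict())
vllm_model = vllm_model.to("cuda", torch.float16)

# A dummy batch at a supported resolution (see the test's 432x640 note).
pixel_values = torch.rand(1, 3, 432, 640, device="cuda", dtype=torch.float16)
features = vllm_model(pixel_values=pixel_values)  # per-patch features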

vllm/model_executor/models/nano_nemotron_vl.py

Lines changed: 60 additions & 58 deletions
@@ -18,8 +18,8 @@
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import (AutoModel, BatchEncoding, BatchFeature,
-                          PretrainedConfig, TensorType)
+from transformers import (BatchEncoding, BatchFeature, PretrainedConfig,
+                          TensorType)
 
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
@@ -32,6 +32,7 @@
     get_internvl_target_ratios)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
+from vllm.model_executor.models.radio import RadioModel
 from vllm.model_executor.models.utils import (flatten_bn,
                                               init_vllm_registered_model,
                                               maybe_prefix,
@@ -48,6 +49,7 @@
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.radio import RadioConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

@@ -122,11 +124,6 @@ class NanoNemotronVLVideoEmbeddingInputs(TensorSchema):
                                      NanoNemotronVLVideoEmbeddingInputs]
 
 
-def input_conditioner(x, norm_mean, norm_std):
-    y = (x - norm_mean) / norm_std
-    return y
-
-
 def dynamic_preprocess(image,
                        *,
                        image_size=512,
@@ -305,8 +302,7 @@ def _preprocess_image(
                 images, max_num_tiles)
         image_inputs: dict[str, NestedTensors] = {
             "pixel_values_flat":
-            input_conditioner(torch.cat(pixel_values_lst), self.norm_mean,
-                              self.norm_std),
+            torch.cat(pixel_values_lst),
             "image_num_patches":
             torch.tensor([len(item) for item in pixel_values_lst]),
         }
@@ -428,8 +424,7 @@ def _preprocess_video(
 
         video_inputs: dict[str, NestedTensors] = {
             "pixel_values_flat_video":
-            input_conditioner(torch.cat(pixel_values_lst_video),
-                              self.norm_mean, self.norm_std),
+            torch.cat(pixel_values_lst_video),
             "video_num_patches":
             torch.tensor([len(item) for item in pixel_values_lst_video]),
         }
@@ -905,18 +900,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             hf_config=config.text_config,
             prefix=maybe_prefix(prefix, "language_model"),
         )
-        self.vision_model = AutoModel.from_config(config.vision_config,
-                                                  trust_remote_code=True)
-        self.vision_model.model._initialize_weights = (
-            self.vision_model.model._init_weights)
-        # Move input normalization to processor to mirror original HF
-        # implementation where normalization is done in fp32
-        self.vision_model.radio_model.make_preprocessor_external()
-        self.vision_model = self.vision_model.to(
+        self.vision_model = self.get_vit_model_from_radio_config(config).to(
             self.language_model.config.torch_dtype)
 
-        self.drop_vision_class_token = True
-
         # Construct the vision projection.
         vit_hidden_size = config.vit_hidden_size
         vision_projection_hidden_size = config.projector_hidden_size
@@ -972,7 +958,7 @@ def pixel_shuffle(self, x, scale_factor=0.5):
         return x
 
     def extract_feature(self, pixel_values):
-        vit_embeds = self.vision_model(pixel_values).features
+        vit_embeds = self.vision_model(pixel_values)
         vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
         h = w = int(vit_embeds.shape[1]**0.5)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
@@ -1212,47 +1198,39 @@ def compute_logits(
                                        sampling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        adapter_dict = dict(self.mlp1.named_parameters())
 
-        def is_vision_model_weights(weight: tuple[str, torch.Tensor]):
-            return weight[0].startswith("vision_model")
+        def is_llm(name: str) -> bool:
+            return name.startswith("language_model")
 
         def is_adapter_weights(weight: tuple[str, torch.Tensor]):
             return weight[0].startswith("mlp1")
 
-        # Get references to parameters for direct loading
-        vision_model_dict = dict(self.vision_model.named_parameters())
-        vision_model_buffers = dict(self.vision_model.named_buffers())
-        adapter_dict = dict(self.mlp1.named_parameters())
-
-        def llm_weights_generator():
-            # Single pass over weights
-            for name, w in weights:
-                if is_vision_model_weights((name, w)):
-                    # Load vision encoder weights directly
-                    trimmed_name = ".".join(name.split(".")[1:])
-                    if "input_conditioner" in trimmed_name:
-                        continue
-                    if trimmed_name in vision_model_buffers:
-                        param = vision_model_buffers[trimmed_name]
-                    else:
-                        param = vision_model_dict[trimmed_name]
-                    with torch.no_grad():
-                        default_weight_loader(param, w)
-                elif is_adapter_weights((name, w)):
-                    # Load vision-language adapter weights directly
-                    trimmed_name = ".".join(name.split(".")[1:])
-                    param = adapter_dict[trimmed_name]
-                    with torch.no_grad():
-                        default_weight_loader(param, w)
-                else:
-                    # LLM weights: yield them to be loaded
-                    # by language_model.load_weights
-                    assert name.startswith("language_model")
-                    trimmed_name = ".".join(name.split(".")[1:])
-                    yield (trimmed_name, w)
-
-        # Now we call the language model load with the generator
-        self.language_model.load_weights(llm_weights_generator())
+        def is_vision_weights(name: str) -> bool:
+            return name.startswith("vision_model.radio_model.")
+
+        # Separate weights by component
+        llm_weights = []
+        vision_weights = []
+
+        for name, w in weights:
+            if is_llm(name):
+                # Strip 'language_model.' prefix for LLM weights
+                llm_weights.append((".".join(name.split(".")[1:]), w))
+            elif is_adapter_weights((name, w)):
+                # Load vision-language adapter weights directly
+                trimmed_name = ".".join(name.split(".")[1:])
+                param = adapter_dict[trimmed_name]
+                with torch.no_grad():
+                    default_weight_loader(param, w)
+            elif is_vision_weights(name):
+                # Convert: vision_model.radio_model.* → radio_model.*
+                hf_key = name[len(
+                    "vision_model."):]  # Remove "vision_model." prefix
+                vision_weights.append((hf_key, w))
+
+        self.language_model.load_weights(llm_weights)
+        self.vision_model.load_weights(vision_weights)
 
     def print_architecture(self,
                            detailed: bool = True,
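The rewritten load_weights above routes flat checkpoint keys to three components by prefix: language_model.* goes to the LLM loader with the prefix stripped, mlp1.* is loaded directly into the vision-language adapter, and vision_model.radio_model.* is renamed to radio_model.* and handed to RadioModel.load_weights. A toy sketch of that routing rule (illustrative key names only, not vLLM code):

# Toy sketch of the prefix routing used in load_weights above.
def route(name: str) -> tuple[str, str]:
    """Return (component, key-as-loaded) for a flat checkpoint key."""
    if name.startswith("language_model."):
        # Handed to language_model.load_weights with the prefix stripped.
        return "llm", name.split(".", 1)[1]
    if name.startswith("mlp1."):
        # Loaded directly into the adapter parameters.
        return "adapter", name.split(".", 1)[1]
    if name.startswith("vision_model.radio_model."):
        # vision_model.radio_model.* -> radio_model.*, handed to RadioModel.
        return "vision", name[len("vision_model."):]
    # The real code silently ignores anything else; raising here just makes
    # the toy explicit.
    raise KeyError(f"unexpected weight: {name}")

assert route("language_model.lm_head.weight") == ("llm", "lm_head.weight")
assert route("vision_model.radio_model.blocks.0.attn.qkv.weight") == (
    "vision", "radio_model.blocks.0.attn.qkv.weight")
assert route("mlp1.0.weight") == ("adapter", "0.weight")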
@@ -1370,6 +1348,30 @@ def get_model_info(self):
             },
         }
 
+    def get_vit_model_from_radio_config(self, hf_config):
+        hf_config_vision = hf_config.vision_config
+        model_name = hf_config_vision.args.get("model")
+        if model_name is None:
+            raise ValueError(f'Unsupported vit model type: {model_name}')
+
+        preferred_resolution = getattr(hf_config_vision,
+                                       "preferred_resolution", None)
+        image_size = preferred_resolution[0] if preferred_resolution else 224
+        patch_size = getattr(hf_config_vision, "patch_size", 16)
+
+        radio_config = RadioConfig(
+            model_name=model_name,
+            image_size=image_size,
+            patch_size=patch_size,
+            norm_mean=hf_config.norm_mean,
+            norm_std=hf_config.norm_std,
+            reg_tokens=(hf_config_vision.args.get("register_multiple")
+                        if hasattr(hf_config_vision, "args")
+                        and isinstance(hf_config_vision.args, dict) else None),
+        )
+
+        return RadioModel(config=radio_config)
+
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs(
             input_buffers, **kwargs)
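A behavioural note implied by this diff: the input_conditioner helper removed from the processor earlier computed (x - norm_mean) / norm_std, and norm_mean / norm_std are now passed into RadioConfig here, so that normalization presumably happens inside RadioModel rather than in the multimodal processor. For illustration only (the mean/std values below are the usual ImageNet statistics, not necessarily RADIO's):

# Illustration of the removed input_conditioner math; not vLLM code.
import torch

def input_conditioner(x: torch.Tensor, norm_mean: torch.Tensor,
                      norm_std: torch.Tensor) -> torch.Tensor:
    return (x - norm_mean) / norm_std

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)  # placeholder stats
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
x = torch.rand(1, 3, 432, 640)
y = input_conditioner(x, mean, std)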
