
Commit d736895

feat: working code for phi3v
1 parent 0eff4e0 commit d736895

7 files changed, +95 -30 lines changed


examples/llm/configs/disagg.yaml

Lines changed: 0 additions & 1 deletion
@@ -26,7 +26,6 @@ Frontend:
 Processor:
   router: round-robin
   common-configs: [model, block-size]
-  prompt-template: "USER: <image>\n<prompt> ASSISTANT:"
 
 VllmWorker:
   remote-prefill: true

examples/multimodal/components/decode_worker.py

Lines changed: 3 additions & 4 deletions
@@ -25,7 +25,7 @@
 from components.encode_worker import VllmEncodeWorker
 from components.prefill_worker import VllmPrefillWorker
 from utils.logging import check_required_workers
-from utils.model import construct_mm_data, get_vision_embeddings_size
+from utils.model import construct_mm_data, get_vision_embeddings_info
 from utils.nixl import NixlMetadataStore
 from utils.prefill_queue import PrefillQueue
 from utils.protocol import (
@@ -117,7 +117,7 @@ async def async_init(self):
         )
 
         runtime = dynamo_context["runtime"]
-        embeddings_shape = get_vision_embeddings_size(
+        embeddings_shape, embeddings_dtype = get_vision_embeddings_info(
             self.engine_args.model, self.engine_args.num_patches
         )
         logger.debug(f"Embeddings shape: {embeddings_shape}")
@@ -139,7 +139,6 @@ async def async_init(self):
             else:
                 self.disaggregated_router = None
         else:
-            EMBEDDINGS_DTYPE = torch.float16
             EMBEDDINGS_DEVICE = "cuda"
 
             enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address()  # type: ignore
@@ -155,7 +154,7 @@ async def async_init(self):
 
             # Create a longer-lived buffer for receiving the image embeddings.
             embeddings = torch.empty(
-                embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
+                embeddings_shape, dtype=embeddings_dtype, device=EMBEDDINGS_DEVICE
             )
             descriptor = connect.Descriptor(embeddings)
             # Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
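Net effect: the decode worker now takes both the shape and the dtype of its NIXL receive buffer from the model's Hugging Face config instead of assuming torch.float16. A minimal sketch of the allocation pattern, assuming the examples/multimodal directory is on the Python path so utils.model resolves; the model id and patch count are illustrative values taken from the new Phi-3.5 config:

import torch

from utils.model import get_vision_embeddings_info

# Shape is (1, num_patches, hidden_size); dtype is the config's torch_dtype.
embeddings_shape, embeddings_dtype = get_vision_embeddings_info(
    "microsoft/Phi-3.5-vision-instruct", num_patches=757
)

# Long-lived buffer the encode worker writes image embeddings into; in the
# worker it is then wrapped in connect.Descriptor and registered with NIXL.
embeddings = torch.empty(embeddings_shape, dtype=embeddings_dtype, device="cuda")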

examples/multimodal/components/encode_worker.py

Lines changed: 16 additions & 11 deletions
@@ -165,27 +165,32 @@ async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
 
         logger.debug(f"Processing image for request: {{ id: {request_id} }}")
         image_embeds = self.image_processor(images=image, return_tensors="pt")
-        # Add a batch dimension to the pixel values
-        image_embeds["pixel_values"] = (
-            image_embeds["pixel_values"].unsqueeze(0).to(DEVICE)
-        )
+        # Add a batch dimension to everything
+        for item in image_embeds:
+            image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
         logger.debug(f"Image embeds: {image_embeds}")
-        image_grid_thw = None
-        if "image_grid_thw" in image_embeds:
-            image_grid_thw = image_embeds["image_grid_thw"].tolist()
-        image_sizes = [image.size]
+
+        image_grid_thw = (
+            image_embeds["image_grid_thw"].tolist()
+            if "image_grid_thw" in image_embeds
+            else None
+        )
+        image_sizes = (
+            image_embeds["image_sizes"].tolist()
+            if "image_sizes" in image_embeds
+            else [image.size]
+        )
         logger.debug(
             f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
         )
 
         with torch.no_grad():
             embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
-            if isinstance(embeddings, tuple):
-                # The result multimodal_embeddings is tuple of tensors, with each
+            if isinstance(embeddings, tuple) or isinstance(embeddings, list):
+                # The result multimodal_embeddings may be a list or tuple of tensors, with each
                 # tensor corresponding to a multimodal data item (image or video).
                 # TODO: for multi-image support, this result will contain multiple tensors.
                 embeddings = embeddings[0].unsqueeze(0)
-
         logger.debug(
             f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
         )
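Phi-3.5-vision's processor returns more tensors than just pixel_values (notably image_sizes), so every entry now gets the batch dimension and device move, and the per-model metadata (image_grid_thw for Qwen2-VL-style models, image_sizes for Phi-3.5-vision) is picked up when present. A standalone sketch of that post-processing using stand-in tensors, so nothing needs to be downloaded; shapes and values are illustrative:

import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for what the HF image processor returns for a single image.
image_embeds = {
    "pixel_values": torch.randn(17, 3, 336, 336),
    "image_sizes": torch.tensor([1344, 1008]),
}

# Add a batch dimension to every tensor and move it to the target device.
for key in image_embeds:
    image_embeds[key] = image_embeds[key].unsqueeze(0).to(DEVICE)

# Grid/size metadata is optional and model-dependent.
image_grid_thw = (
    image_embeds["image_grid_thw"].tolist() if "image_grid_thw" in image_embeds else None
)
image_sizes = (
    image_embeds["image_sizes"].tolist()
    if "image_sizes" in image_embeds
    else [(336, 336)]  # the real worker falls back to the PIL image size here
)

print(image_grid_thw)  # None for this Phi-3.5-style input
print(image_sizes)     # [[1344, 1008]]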

examples/multimodal/components/prefill_worker.py

Lines changed: 3 additions & 5 deletions
@@ -25,7 +25,7 @@
 from components.encode_worker import VllmEncodeWorker
 from pydantic import BaseModel
 from utils.logging import check_required_workers
-from utils.model import construct_mm_data, get_vision_embeddings_size
+from utils.model import construct_mm_data, get_vision_embeddings_info
 from utils.nixl import NixlMetadataStore
 from utils.prefill_queue import PrefillQueue
 from utils.protocol import EncodeRequest, EncodeResponse
@@ -40,8 +40,6 @@
 
 logger = logging.getLogger(__name__)
 
-# Constants for the dtype and device of the embeddings tensor.
-EMBEDDINGS_DTYPE = torch.float16
 EMBEDDINGS_DEVICE = "cuda"
 
 
@@ -113,12 +111,12 @@ async def async_init(self):
         await self._connector.initialize()
 
         # Create a longer-lived buffer for receiving the image embeddings.
-        embeddings_shape = get_vision_embeddings_size(
+        embeddings_shape, embeddings_dtype = get_vision_embeddings_info(
             self.engine_args.model, self.engine_args.num_patches
        )
         embeddings = torch.empty(
             embeddings_shape,
-            dtype=EMBEDDINGS_DTYPE,
+            dtype=embeddings_dtype,
             device=EMBEDDINGS_DEVICE,
         )
         descriptor = connect.Descriptor(embeddings)

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+Common:
+  model: microsoft/Phi-3.5-vision-instruct
+  block-size: 64
+  max-model-len: 4096
+  trust-remote-code: true
+
+Processor:
+  router: round-robin
+  prompt-template: "<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"
+  common-configs: [model, block-size, max-model-len, trust-remote-code]
+
+VllmDecodeWorker:
+  enforce-eager: true
+  max-num-batched-tokens: 16384
+  max-num-seqs: 2
+  mm-processor-kwargs:
+    num_crops: 16
+  enable-prefix-caching: true
+  image-token-id: 32000
+  num-patches: 757
+  router: random
+  tensor-parallel-size: 1
+  ServiceArgs:
+    workers: 1
+    resources:
+      gpu: '1'
+  common-configs: [model, block-size, max-model-len, trust-remote-code]
+
+VllmEncodeWorker:
+  tensor-parallel-size: 1
+  router: random
+  ServiceArgs:
+    workers: 1
+    resources:
+      gpu: '1'
+  common-configs: [model]
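The Processor's prompt-template wraps the user's text in Phi-3.5-vision's chat format, with <|image_1|> marking where the image goes. A quick illustration of how such a template expands; the helper below is hypothetical and only substitutes the "<prompt>" placeholder used by the config above:

PHI35V_TEMPLATE = "<|user|>\n<|image_1|>\n<prompt><|end|>\n<|assistant|>\n"

def apply_prompt_template(template: str, user_prompt: str) -> str:
    # Replace the "<prompt>" placeholder with the user's message.
    return template.replace("<prompt>", user_prompt)

print(apply_prompt_template(PHI35V_TEMPLATE, "Describe this image."))
# <|user|>
# <|image_1|>
# Describe this image.<|end|>
# <|assistant|>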

examples/multimodal/utils/model.py

Lines changed: 21 additions & 7 deletions
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
+import logging
+from typing import Any, Dict, Tuple
 
 import torch
 from transformers import AutoConfig
@@ -22,6 +23,8 @@
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.worker import Worker
 
+logger = logging.getLogger(__name__)
+
 
 def load_vision_model(model_id: str) -> torch.nn.Module:
     """
@@ -44,13 +47,24 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
     return worker.model_runner.model
 
 
-def get_vision_embeddings_size(model_id: str, num_patches: int) -> tuple[int, int, int]:
-    """Calculate vision embeddings size using model config and image processor
-    Returns a tuple of (batch_size, num_patches, hidden_dim).
+def get_vision_embeddings_info(
+    model_id: str, num_patches: int
+) -> Tuple[Tuple[int, int, int], torch.dtype]:
+    """Calculate vision embeddings size and dtype using model config
+    Returns a tuple of (batch_size, num_patches, hidden_dim), dtype.
     """
     config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
     assert num_patches > 0, "Number of patches must be positive"
-    return 1, num_patches, getattr(config, "hidden_size", 4096)
+    if not hasattr(config, "torch_dtype"):
+        raise ValueError("Model config missing required 'torch_dtype' attribute")
+    if not hasattr(config, "hidden_size"):
+        logger.warning(
+            "Model config missing required 'hidden_size' attribute, using 4096"
+        )
+        hidden_size = 4096
+    else:
+        hidden_size = config.hidden_size
+    return (1, num_patches, hidden_size), config.torch_dtype
 
 
 def construct_mm_data(
@@ -60,8 +74,8 @@ def construct_mm_data(
     if "Qwen2" in model:
         return {
             "image": {
-                "image_embeds": image_embeds.squeeze(0),
-                "image_grid_thw": torch.tensor(encode_output.image_grid_thw),
+                "image_embeds": image_embeds.squeeze(0).to(torch.float16),
+                "image_grid_thw": torch.tensor(encode_output.image_grid_thw).squeeze(0),
             }
         }
     elif "MiniCPM-V" in model:

examples/multimodal/utils/protocol.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 
 
 import json
-from typing import Any, List, Literal, Optional, Tuple, Union
+from typing import Any, List, Literal, Optional, Union
 
 import connect
 import msgspec
@@ -143,7 +143,7 @@ class EncodeResponse(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
     request_id: str
     image_grid_thw: Optional[List[Any]] = None
-    image_sizes: Optional[List[Tuple[int, int]]] = None
+    image_sizes: Optional[List[Any]] = None
 
 
 class MyRequestOutput(BaseModel):
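The EncodeResponse.image_sizes annotation is loosened from List[Tuple[int, int]] to List[Any] so that size entries can arrive in whatever nesting the encode worker produces (e.g. tensor.tolist() output) as well as PIL-style tuples. A minimal pydantic sketch; the class name is hypothetical and only the touched fields are reproduced:

from typing import Any, List, Optional

from pydantic import BaseModel, ConfigDict

class EncodeResponseSketch(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    request_id: str
    image_grid_thw: Optional[List[Any]] = None
    image_sizes: Optional[List[Any]] = None

# Both size formats validate against List[Any]:
print(EncodeResponseSketch(request_id="a", image_sizes=[(336, 336)]))    # PIL-style (width, height)
print(EncodeResponseSketch(request_id="b", image_sizes=[[1344, 1008]]))  # tensor.tolist() style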
