Merged
23 commits
8d67b7d
feat: add more robust handling for MM prompt
hhzhang16 Jun 4, 2025
b65efb5
feat: [WIP] generalize workers
hhzhang16 Jun 4, 2025
40c2154
Merge branch 'main' of github.com:ai-dynamo/dynamo into hannahz/dep-1…
hhzhang16 Jun 4, 2025
e13f827
feat: remove cls token
hhzhang16 Jun 4, 2025
a866d73
Merge branch 'main' of github.com:ai-dynamo/dynamo into hannahz/dep-1…
hhzhang16 Jun 4, 2025
0adb7e6
Merge branch 'main' of github.com:ai-dynamo/dynamo into hannahz/dep-1…
hhzhang16 Jun 5, 2025
86c6135
Merge branch 'main' of github.com:ai-dynamo/dynamo into hannahz/dep-1…
hhzhang16 Jun 5, 2025
19f2158
Merge branch 'main' of github.com:ai-dynamo/dynamo into hannahz/dep-1…
hhzhang16 Jun 6, 2025
a766509
feat: working multimodal agg for multiple vision models
hhzhang16 Jun 7, 2025
17aecda
Merge branch 'main' of github.com:ai-dynamo/dynamo into hannahz/dep-1…
hhzhang16 Jun 7, 2025
496ee57
feat: addressing ci comments
hhzhang16 Jun 9, 2025
820c7e3
feat: addressing ci comments
hhzhang16 Jun 9, 2025
bb4f95e
Merge branch 'main' of github.com:ai-dynamo/dynamo into hannahz/dep-1…
hhzhang16 Jun 9, 2025
027341a
Update examples/multimodal/README.md
hhzhang16 Jun 9, 2025
0eff4e0
feat: trust remote code when loading autoconfig
hhzhang16 Jun 9, 2025
d736895
feat: working code for phi3v
hhzhang16 Jun 10, 2025
36eacb9
docs: add phi3v to multimodal readme
hhzhang16 Jun 10, 2025
d586343
feat: working for Qwen 2.5 VL
hhzhang16 Jun 11, 2025
d5025a7
docs: fixing dash issue
hhzhang16 Jun 11, 2025
1b0efc0
Merge branch 'main' into hannahz/dep-114-generalize-vlm-embedding-ext…
hhzhang16 Jun 11, 2025
843d586
docs: add readme note about disagg support
hhzhang16 Jun 11, 2025
d12e86d
Merge branch 'hannahz/dep-114-generalize-vlm-embedding-extraction' of…
hhzhang16 Jun 11, 2025
073ad67
feat: remove pynvml from this MR
hhzhang16 Jun 11, 2025
26 changes: 21 additions & 5 deletions examples/multimodal/README.md
@@ -18,7 +18,6 @@ limitations under the License.
# Multimodal Deployment Examples

This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo.
The examples are based on the [llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model.

## Use the Latest Release

@@ -59,11 +58,15 @@ flowchart LR
decode_worker --image_url--> encode_worker
encode_worker --embeddings--> decode_worker
```
```

```bash
cd $DYNAMO_HOME/examples/multimodal
dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
# Serve a LLaVA 1.5 7B model:
dynamo serve graphs.agg:Frontend -f ./configs/agg-llava.yaml
# Serve a Qwen2.5-VL model:
# dynamo serve graphs.agg:Frontend -f ./configs/agg-qwen.yaml
# Serve a Phi3V model:
# dynamo serve graphs.agg:Frontend -f ./configs/agg-phi3v.yaml
```

### Client
@@ -92,10 +95,13 @@ curl http://localhost:8000/v1/chat/completions \
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```

If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
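
For example, a request for the Qwen example changes only the `"model"` field; the prompt text and image URL below are placeholders, so reuse the same `messages` payload as the LLaVA request above:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
      {
        "role": "user",
        "content": [
          {"type": "text", "text": "Describe this image."},
          {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
        ]
      }
    ],
    "max_tokens": 300,
    "temperature": 0.0,
    "stream": false
  }'
```

The example response below was produced with the default LLaVA model.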

You should see a response similar to this:
```json
{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
@@ -162,6 +168,7 @@ curl http://localhost:8000/v1/chat/completions \
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```
@@ -171,6 +178,8 @@ You should see a response similar to this:
{"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
```

***Note***: Disaggregated serving is currently only confirmed to work with LLaVA. Qwen2.5-VL and Phi3V are not yet confirmed to be supported.

## Deployment with Dynamo Operator

These multimodal examples can be deployed to a Kubernetes cluster using [Dynamo Cloud](../../docs/guides/dynamo_deploy/dynamo_cloud.md) and the Dynamo CLI.
@@ -206,8 +215,12 @@ DYNAMO_TAG=$(dynamo build graphs.agg:Frontend | grep "Successfully built" | awk

# Deploy to Kubernetes
export DEPLOYMENT_NAME=multimodal-agg
# For aggregated serving:
dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg.yaml
# For aggregated serving with LLaVA:
dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-llava.yaml
# For aggregated serving with Qwen2.5-VL:
# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-qwen.yaml
# For aggregated serving with Phi3V:
# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/agg-phi3v.yaml
# For disaggregated serving:
# export DEPLOYMENT_NAME=multimodal-disagg
# dynamo deploy $DYNAMO_TAG -n $DEPLOYMENT_NAME -f ./configs/disagg.yaml
@@ -244,8 +257,11 @@ curl localhost:8000/v1/chat/completions \
}
],
"max_tokens": 300,
"temperature": 0.0,
"stream": false
}'
```

If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.

For more details on managing deployments, testing, and troubleshooting, please refer to the [Operator Deployment Guide](../../docs/guides/dynamo_deploy/operator_deployment.md).
37 changes: 17 additions & 20 deletions examples/multimodal/components/decode_worker.py
@@ -24,8 +24,8 @@
from components.disagg_router import PyDisaggregatedRouter
from components.encode_worker import VllmEncodeWorker
from components.prefill_worker import VllmPrefillWorker
from transformers import LlavaForConditionalGeneration
from utils.logging import check_required_workers
from utils.model import construct_mm_data, get_vision_embeddings_info
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import (
@@ -117,6 +117,11 @@ async def async_init(self):
)

runtime = dynamo_context["runtime"]
embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
self.engine_args.model, self.engine_args.num_patches
)
logger.debug(f"Embeddings shape: {embeddings_shape}")
self.embedding_size = embeddings_shape[1]

if self.do_remote_prefill:
metadata = self.engine_client.nixl_metadata
@@ -133,18 +138,7 @@ async def async_init(self):
await self.disaggregated_router.async_init()
else:
self.disaggregated_router = None

model = LlavaForConditionalGeneration.from_pretrained(
self.engine_args.model,
device_map="auto",
torch_dtype=torch.bfloat16,
).eval()
vision_tower = model.vision_tower
self.embedding_size = (
vision_tower.vision_model.embeddings.position_embedding.num_embeddings
)
else:
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"

@@ -161,7 +155,7 @@ async def async_init(self):

# Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty(
EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional; if not done here, the connect subsystem will take care of it automatically).
@@ -206,13 +200,15 @@ async def generate(self, request: vLLMMultimodalRequest):
multi_modal_data,
remote_prefill_params,
) = await self.remote_prefill(request)

else:
(
prompt_ids,
multi_modal_data,
remote_prefill_params,
) = await self.local_prefill(request)
logger.debug(f"Prompt ids: {prompt_ids}")
logger.debug(f"Multi modal data: {multi_modal_data}")
logger.debug(f"Remote prefill params: {remote_prefill_params}")

# rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA
@@ -227,7 +223,7 @@ async def generate(self, request: vLLMMultimodalRequest):
remote_prefill_params=remote_prefill_params,
):
logger.debug(
f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
f"Yielding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
)
yield MyRequestOutput(
request_id=response.request_id,
@@ -294,7 +290,9 @@ async def local_prefill(self, request: vLLMMultimodalRequest) -> tuple:
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
# When using disaggregated serving, the key-value cache updates will have been provided via the encode worker.
multi_modal_data = {"image": embeddings}
multi_modal_data = construct_mm_data(
self.engine_args.model, encode_output, embeddings, self.embeddings_dtype
)

return prompt_ids, multi_modal_data, remote_prefill_params

@@ -353,17 +351,16 @@ async def remote_prefill(self, request: vLLMMultimodalRequest) -> tuple:
# As a workaround, here we manually insert some placeholder dummy tokens based on the embedding size
# so that decode worker can pre-allocate the memory with the correct size.
# The structure of the prompt will be like: "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:".
# Since the "<image>" token is included in the prompt, only need to insert (embedding_size - 1) dummy tokens after the image token.
IMAGE_TOKEN_ID = 32000
# Since the "<image>" token is included in the prompt, only need to insert embedding_size dummy tokens after the image token.
DUMMY_TOKEN_ID = 0
# Find the index of the image token in the prompt token ids
image_token_index = request.engine_prompt["prompt_token_ids"].index(
IMAGE_TOKEN_ID
self.engine_args.image_token_id
)
dummy_token_index = image_token_index + 1
prompt_ids = (
request.engine_prompt["prompt_token_ids"][:dummy_token_index]
+ [DUMMY_TOKEN_ID] * (self.embedding_size - 1)
+ [DUMMY_TOKEN_ID] * self.embedding_size
+ request.engine_prompt["prompt_token_ids"][dummy_token_index:]
)
logger.debug(
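
The `construct_mm_data` and `get_vision_embeddings_info` helpers used above live in `utils/model.py`, which is not part of this diff. A rough, hypothetical sketch of what `construct_mm_data` might return per model family is shown below; the per-model `multi_modal_data` layouts are assumptions inferred from the call sites, not the confirmed implementation.

```python
import torch

# Hypothetical sketch only: the real helper lives in utils/model.py (not shown in
# this diff), and the per-model dict layouts below are assumptions.
def construct_mm_data(model: str, encode_output, embeddings: torch.Tensor, dtype: torch.dtype) -> dict:
    embeddings = embeddings.to(dtype)
    name = model.lower()
    if "qwen" in name:
        # Qwen2.5-VL takes precomputed embeddings together with the image grid layout.
        return {"image": {"image_embeds": embeddings.squeeze(0),
                          "image_grid_thw": torch.tensor(encode_output.image_grid_thw)}}
    if "phi-3" in name or "phi3" in name:
        # Phi3V additionally needs the original image sizes.
        return {"image": {"image_embeds": embeddings,
                          "image_sizes": encode_output.image_sizes}}
    # LLaVA-style models take the embeddings tensor directly.
    return {"image": embeddings}
```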
43 changes: 29 additions & 14 deletions examples/multimodal/components/encode_worker.py
@@ -26,7 +26,8 @@
import httpx
import torch
from PIL import Image
from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from transformers import AutoImageProcessor
from utils.model import load_vision_model
from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args

@@ -66,10 +67,7 @@ def __init__(self) -> None:
self.image_processor = AutoImageProcessor.from_pretrained(
self.MODEL_ID, trust_remote_code=True
)

self.vision_model = LlavaForConditionalGeneration.from_pretrained(
self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
).eval()
self.vision_model = load_vision_model(self.MODEL_ID)

self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
@@ -167,17 +165,32 @@ async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:

logger.debug(f"Processing image for request: {{ id: {request_id} }}")
image_embeds = self.image_processor(images=image, return_tensors="pt")
# Add a batch dimension to everything
for item in image_embeds:
image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE)
logger.debug(f"Image embeds: {image_embeds}")

image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
image_sizes = (
image_embeds["image_sizes"].tolist()
if "image_sizes" in image_embeds
else [image.size]
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)

with torch.no_grad():
logger.debug(f"Vision model device: {self.vision_model.device}")
vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device)
)
logger.debug("Vision model completed.")

embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)

embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds)
if isinstance(embeddings, tuple) or isinstance(embeddings, list):
# The result multimodal_embeddings may be a list or tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
# TODO: for multi-image support, this result will contain multiple tensors.
embeddings = embeddings[0].unsqueeze(0)
logger.debug(
f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
)
@@ -201,6 +214,8 @@ async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:

yield EncodeResponse(
request_id=request.request_id,
image_grid_thw=image_grid_thw,
image_sizes=image_sizes,
).model_dump_json()
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
25 changes: 16 additions & 9 deletions examples/multimodal/components/prefill_worker.py
@@ -25,6 +25,7 @@
from components.encode_worker import VllmEncodeWorker
from pydantic import BaseModel
from utils.logging import check_required_workers
from utils.model import construct_mm_data, get_vision_embeddings_info
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import EncodeRequest, EncodeResponse
@@ -39,9 +40,6 @@

logger = logging.getLogger(__name__)

# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"


@@ -113,9 +111,12 @@ async def async_init(self):
await self._connector.initialize()

# Create a longer-lived buffer for receiving the image embeddings.
embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info(
self.engine_args.model, self.engine_args.num_patches
)
embeddings = torch.empty(
EMBEDDINGS_SHAPE,
dtype=EMBEDDINGS_DTYPE,
embeddings_shape,
dtype=self.embeddings_dtype,
device=EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
@@ -248,10 +249,11 @@ async def generate(self, request: RemotePrefillRequest):
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens are inserted based on the embedding size in the worker.py.
# TODO: make this more flexible/model-dependent
IMAGE_TOKEN_ID = 32000
embedding_size = embeddings.shape[1]
padding_size = embedding_size - 1
image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
padding_size = embedding_size
image_token_index = request.prompt_token_ids.index(
self.engine_args.image_token_id
)
dummy_token_index = image_token_index + 1
prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index]
@@ -262,7 +264,12 @@
request_id=request_id,
prompt=TokensPrompt(
prompt_token_ids=prompt_token_ids,
multi_modal_data={"image": embeddings},
multi_modal_data=construct_mm_data(
self.engine_args.model,
encode_output,
embeddings,
self.embeddings_dtype,
),
),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
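
To make the placeholder padding concrete, here is a small self-contained illustration of the token-id splice performed by both the decode and prefill workers. The token ids are made up; `32000` stands in for `image_token_id` and `0` for `DUMMY_TOKEN_ID`.

```python
# Illustrative only: made-up token ids showing how dummy tokens are spliced in
# after the <image> token so the decode worker pre-allocates enough KV-cache space.
prompt_token_ids = [1, 3148, 1001, 29901, 32000, 13, 3529]  # e.g. "USER: <image>\n..."
embedding_size = 576  # second dimension of the vision-embedding tensor
image_token_index = prompt_token_ids.index(32000)
dummy_token_index = image_token_index + 1
padded_ids = (
    prompt_token_ids[:dummy_token_index]
    + [0] * embedding_size
    + prompt_token_ids[dummy_token_index:]
)
assert len(padded_ids) == len(prompt_token_ids) + embedding_size
```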
17 changes: 14 additions & 3 deletions examples/multimodal/components/processor.py
@@ -188,9 +188,19 @@ async def _generate_responses(
# The generate endpoint will be used by the frontend to handle incoming requests.
@endpoint()
async def generate(self, raw_request: MultiModalRequest):
prompt = str(self.engine_args.prompt_template).replace(
"<prompt>", raw_request.messages[0].content[0].text
)
# Ensure the configured template includes the placeholder
template = self.engine_args.prompt_template
if "<prompt>" not in template:
raise ValueError("prompt_template must contain '<prompt>' placeholder")

# Safely extract user text
try:
user_text = raw_request.messages[0].content[0].text
except (IndexError, AttributeError) as e:
raise ValueError(f"Invalid message structure: {e}")

prompt = template.replace("<prompt>", user_text)

msg = {
"role": "user",
"content": prompt,
@@ -201,6 +211,7 @@ async def generate(self, raw_request: MultiModalRequest):
messages=[msg],
stream=raw_request.stream,
max_tokens=raw_request.max_tokens,
temperature=raw_request.temperature,
request_id=str(uuid.uuid4()),
)
image_url = None
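
The template substitution above is a plain string replacement. A minimal sketch, assuming a hypothetical LLaVA-style `prompt_template` (the actual template string comes from the example configs and may differ):

```python
# Hypothetical template; the real value is configured per model.
prompt_template = "USER: <image>\n<prompt> ASSISTANT:"
user_text = "What is in this image?"
prompt = prompt_template.replace("<prompt>", user_text)
# -> "USER: <image>\nWhat is in this image? ASSISTANT:"
```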
@@ -26,6 +26,8 @@ VllmDecodeWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
image-token-id: 32000
num-patches: 576
router: random
tensor-parallel-size: 1
ServiceArgs:
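
The two new values are model-specific. Assuming this config targets LLaVA 1.5, whose CLIP ViT-L/14 vision tower runs at 336x336 resolution with 14-pixel patches and whose tokenizer maps `<image>` to id 32000, the patch count works out as follows:

```python
# Sanity check for num-patches: 576, assuming a 336x336 input and 14-pixel patches.
image_size, patch_size = 336, 14
num_patches = (image_size // patch_size) ** 2
assert num_patches == 576
```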