diff --git a/container/Dockerfile.vllm b/container/Dockerfile.vllm index bde2200daf..687fd5bb95 100644 --- a/container/Dockerfile.vllm +++ b/container/Dockerfile.vllm @@ -12,7 +12,7 @@ ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda" ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04" # Make sure to update the dependency version in pyproject.toml when updating this -ARG VLLM_REF="f4135232b9a8c4845f8961fb1cd17581c56ae2ce" +ARG VLLM_REF="ba81acbdc1eec643ba815a76628ae3e4b2263b76" ARG TORCH_BACKEND="cu128" # Match 0.10.0 vLLM release @@ -186,6 +186,7 @@ RUN if [ "$ARCH" = "arm64" ]; then \ # Install vllm - keep this early in Dockerfile to avoid # rebuilds from unrelated source code changes ARG VLLM_REF +ARG VLLM_GIT_URL ARG DEEPGEMM_REF ARG FLASHINF_REF diff --git a/container/deps/vllm/install_vllm.sh b/container/deps/vllm/install_vllm.sh index baacebdd15..2f95a0d599 100755 --- a/container/deps/vllm/install_vllm.sh +++ b/container/deps/vllm/install_vllm.sh @@ -20,7 +20,8 @@ set -euo pipefail # Parse arguments EDITABLE=true -VLLM_REF="f4135232b9a8c4845f8961fb1cd17581c56ae2ce" +VLLM_REF="ba81acbdc1eec643ba815a76628ae3e4b2263b76" +VLLM_GIT_URL="https://github.com/vllm-project/vllm.git" MAX_JOBS=16 INSTALLATION_DIR=/tmp ARCH=$(uname -m) @@ -49,6 +50,10 @@ while [[ $# -gt 0 ]]; do VLLM_REF="$2" shift 2 ;; + --vllm-git-url) + VLLM_GIT_URL="$2" + shift 2 + ;; --max-jobs) MAX_JOBS="$2" shift 2 @@ -113,7 +118,7 @@ uv pip install lmcache # Create vllm directory and clone mkdir -p $INSTALLATION_DIR cd $INSTALLATION_DIR -git clone https://github.com/vllm-project/vllm.git +git clone $VLLM_GIT_URL vllm cd vllm git checkout $VLLM_REF @@ -148,7 +153,7 @@ fi # Install ep_kernels and DeepGEMM echo "Installing ep_kernels and DeepGEMM" cd tools/ep_kernels -bash install_python_libraries.sh # These libraries aren't pinned. +TORCH_CUDA_ARCH_LIST="9.0;10.0" bash install_python_libraries.sh # These libraries aren't pinned. cd ep_kernels_workspace git clone https://github.com/deepseek-ai/DeepGEMM.git cd DeepGEMM diff --git a/examples/deployments/router_standalone/worker.py b/examples/deployments/router_standalone/worker.py index 5b73893b6e..488830c5f3 100644 --- a/examples/deployments/router_standalone/worker.py +++ b/examples/deployments/router_standalone/worker.py @@ -42,7 +42,10 @@ def __init__(self, port: int) -> None: logger.info(f"ZMQ publisher initialized on port {port}") def record( - self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats] + self, + scheduler_stats: SchedulerStats, + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, ): # Send metrics over ZMQ metrics_data = { diff --git a/examples/multimodal_v1/README.md b/examples/multimodal_v1/README.md new file mode 100644 index 0000000000..ba6bd448e2 --- /dev/null +++ b/examples/multimodal_v1/README.md @@ -0,0 +1,328 @@ + + +# Multimodal Deployment Examples + +This directory provides example workflows and reference implementations for deploying a multimodal model using Dynamo and vLLM v1. 
+
+## Use the Latest Release
+
+We recommend using the latest stable release of dynamo to avoid breaking changes:
+
+[![GitHub Release](https://img.shields.io/github/v/release/ai-dynamo/dynamo)](https://github.com/ai-dynamo/dynamo/releases/latest)
+
+You can find the latest release [here](https://github.com/ai-dynamo/dynamo/releases/latest) and check out the corresponding branch with:
+
+```bash
+git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
+```
+
+## Multimodal Aggregated Serving
+
+### Components
+
+- workers: For aggregated serving, we have two workers, [VllmEncodeWorker](components/encode_worker.py) for encoding and [VllmPDWorker](components/worker.py) for prefilling and decoding.
+- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
+- frontend: HTTP endpoint to handle incoming requests.
+
+### Graph
+
+In this graph, we have two workers, [VllmEncodeWorker](components/encode_worker.py) and [VllmPDWorker](components/worker.py).
+The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the VllmPDWorker via a combination of NATS and RDMA.
+The work-complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
+The VllmPDWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
+By separating the encode stage from the prefill and decode stages, the deployment becomes more flexible, and the
+VllmEncodeWorker can be scaled independently of the prefill and decode workers if needed.
+
+This figure shows the flow of the graph:
+```mermaid
+flowchart LR
+  HTTP --> processor
+  processor --> HTTP
+  processor --image_url--> encode_worker
+  encode_worker --> processor
+  encode_worker --embeddings--> pd_worker
+  pd_worker --> encode_worker
+```
+
+```bash
+cd $DYNAMO_HOME/examples/multimodal_v1
+# Serve a LLaVA 1.5 7B model:
+bash launch/agg.sh --model llava-hf/llava-1.5-7b-hf
+# Serve a Qwen2.5-VL model:
+# bash launch/agg.sh --model Qwen/Qwen2.5-VL-7B-Instruct
+# Serve a Phi3V model:
+# bash launch/agg.sh --model microsoft/Phi-3.5-vision-instruct
+```
+
+### Client
+
+In another terminal:
+```bash
+curl http://localhost:8080/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "llava-hf/llava-1.5-7b-hf",
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {
+            "type": "text",
+            "text": "What is in this image?"
+          },
+          {
+            "type": "image_url",
+            "image_url": {
+              "url": "http://images.cocodataset.org/test2017/000000155781.jpg"
+            }
+          }
+        ]
+      }
+    ],
+    "max_tokens": 300,
+    "temperature": 0.0,
+    "stream": false
+  }'
+```
+
+If serving the example Qwen model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"Qwen/Qwen2.5-VL-7B-Instruct"`. If serving the example Phi3V model, replace `"llava-hf/llava-1.5-7b-hf"` in the `"model"` field with `"microsoft/Phi-3.5-vision-instruct"`.
+
+You should see a response similar to this:
+```json
+{"id": "c37b946e-9e58-4d54-88c8-2dbd92c47b0c", "object": "chat.completion", "created": 1747725277, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " In the image, there is a city bus parked on a street, with a street sign nearby on the right side. The bus appears to be stopped out of service. The setting is in a foggy city, giving it a slightly moody atmosphere."}, "finish_reason": "stop"}]}
+```
+
+## Multimodal Disaggregated Serving
+
+### Components
+
+- workers: For disaggregated serving, we have three workers, [VllmEncodeWorker](components/encode_worker.py) for encoding, [VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for prefilling.
+- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
+- frontend: HTTP endpoint to handle incoming requests.
+
+### Graph
+
+In this graph, we have three workers, [VllmEncodeWorker](components/encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
+For the LLaVA model, embeddings are only required during the prefill stage. As such, the VllmEncodeWorker is connected directly to the prefill worker.
+The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the prefill worker via a combination of NATS and RDMA.
+The work-complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
+The prefill worker performs the prefilling step and forwards the KV cache to the decode worker for decoding.
+For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example.
+
+This figure shows the flow of the graph:
+```mermaid
+flowchart LR
+  HTTP --> processor
+  processor --> HTTP
+  processor --image_url--> encode_worker
+  encode_worker --> processor
+  encode_worker --embeddings--> prefill_worker
+  prefill_worker --> encode_worker
+  prefill_worker --> decode_worker
+  decode_worker --> prefill_worker
+```
+
+```bash
+cd $DYNAMO_HOME/examples/multimodal_v1
+bash launch/disagg.sh --model llava-hf/llava-1.5-7b-hf
+```
+
+### Client
+
+In another terminal:
+```bash
+curl http://localhost:8080/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "llava-hf/llava-1.5-7b-hf",
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {
+            "type": "text",
+            "text": "What is in this image?"
+          },
+          {
+            "type": "image_url",
+            "image_url": {
+              "url": "http://images.cocodataset.org/test2017/000000155781.jpg"
+            }
+          }
+        ]
+      }
+    ],
+    "max_tokens": 300,
+    "temperature": 0.0,
+    "stream": false
+  }'
+```
+
+You should see a response similar to this:
+```json
+{"id": "c1774d61-3299-4aa3-bea1-a0af6c055ba8", "object": "chat.completion", "created": 1747725645, "model": "llava-hf/llava-1.5-7b-hf", "choices": [{"index": 0, "message": {"role": "assistant", "content": " This image shows a passenger bus traveling down the road near power lines and trees. The bus displays a sign that says \"OUT OF SERVICE\" on its front."}, "finish_reason": "stop"}]}
+```
+
+***Note***: disaggregation is currently only confirmed to work with LLaVA. Qwen2.5-VL and Phi3V are not confirmed to be supported.
+
+## Llama 4 Family Serving
+
+The Llama 4 family of models is natively multimodal; however, unlike LLaVA,
+these models do not directly consume image embeddings as input
+(see the [support matrix](https://docs.vllm.ai/en/latest/models/supported_models.html#text-generation_1)
+from vLLM for the types of multi-modal inputs supported by each model).
+Therefore, the encode worker is not used in the following example, and the
+encoding is done alongside prefill.
+
+`meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8` will be used as an example
+for the content below.
And the system will be H100x8 which can hold one instance +of the model per node. + +### Multimodal Aggregated Serving + +#### Components + +- workers: For aggregated serving, we have one worker, [VllmPDWorker](components/worker.py) for prefilling and decoding. +- processor: Tokenizes the prompt and passes it to the VllmPDWorker. +- frontend: HTTP endpoint to handle incoming requests. + +#### Graph + +In this graph, we have [VllmPDWorker](components/worker.py) which will encode the image, prefill and decode the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example. + +This figure shows the flow of the graph: +```mermaid +flowchart LR + HTTP --> processor + processor --> HTTP + processor --image_url--> pd_worker + pd_worker --> processor +``` + +```bash +cd $DYNAMO_HOME/examples/multimodal_v1 +bash launch/agg_llama.sh +``` + +#### Client + +In another terminal: +```bash +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "http://images.cocodataset.org/test2017/000000155781.jpg" + } + } + ] + } + ], + "max_tokens": 300, + "temperature": 0.0, + "stream": false + }' +``` + +You should see a response similar to this: +```json +{"id": "b8f060fa95584e34b9204eaba7b105cc", "object": "chat.completion", "created": 1752706281, "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "choices": [{"index": 0, "message": {"role": "assistant", "content": "The image depicts a street scene with a trolley bus as the central focus. The trolley bus is positioned on the left side of the road, facing the camera, and features a white and yellow color scheme. A prominent sign on the front of the bus reads \"OUT OF SERVICE\" in orange letters.\n\n**Key Elements:**\n\n* **Trolley Bus:** The bus is the main subject of the image, showcasing its distinctive design and color.\n* **Sign:** The \"OUT OF SERVICE\" sign is clearly visible on the front of the bus, indicating its current status.\n* **Street Scene:** The surrounding environment includes trees, buildings, and power lines, creating a sense of context and atmosphere.\n* **Lighting:** The image is characterized by a misty or foggy quality, with soft lighting that adds to the overall ambiance.\n\n**Overall Impression:**\n\nThe image presents a serene and somewhat melancholic scene, with the out-of-service trolley bus serving as a focal point. The misty atmosphere and soft lighting contribute to a dreamy or nostalgic feel, inviting the viewer to reflect on the scene."}, "finish_reason": "stop"}]} +``` + +### Multimodal Disaggregated Serving + +#### Components + +- workers: For disaggregated serving, we have two workers, [VllmDecodeWorker](components/worker.py) for decoding, and [VllmPDWorker](components/worker.py) for encoding and prefilling. +- processor: Tokenizes the prompt and passes it to the VllmPDWorker. +- frontend: HTTP endpoint to handle incoming requests. + +#### Graph + +In this graph, we have two workers, [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py). +The prefill worker performs the encoding and prefilling steps and forwards the KV cache to the decode worker for decoding. 
+For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example. + +This figure shows the flow of the graph: +```mermaid +flowchart LR + HTTP --> processor + processor --> HTTP + processor --image_url--> prefill_worker + prefill_worker --> processor + prefill_worker --> decode_worker + decode_worker --> prefill_worker +``` + +```bash +cd $DYNAMO_HOME/examples/multimodal_v1 +bash launch/disagg_llama.sh --head-node + +# On a separate node that has finished standard dynamo setup, i.e. +# the worker node needs NATS_SERVER and ETCD_ENDPOINTS environment variables +# pointing to the head node's external IP address for distributed coordination +cd $DYNAMO_HOME/examples/multimodal_v1 +bash launch/disagg_llama.sh +``` + +#### Client + +In another terminal: +```bash +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "http://images.cocodataset.org/test2017/000000155781.jpg" + } + } + ] + } + ], + "max_tokens": 300, + "temperature": 0.0, + "stream": false + }' +``` + +You should see a response similar to this: +```json +{"id": "6cc99123ad6948d685b8695428238d4b", "object": "chat.completion", "created": 1752708043, "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "choices": [{"index": 0, "message": {"role": "assistant", "content": "The image depicts a street scene with a trolley bus as the central focus. The trolley bus is positioned on the left side of the road, facing the camera, and features a white and yellow color scheme. A prominent sign on the front of the bus reads \"OUT OF SERVICE\" in orange letters.\n\n**Key Elements:**\n\n* **Trolley Bus:** The bus is the main subject of the image, showcasing its distinctive design and color.\n* **Sign:** The \"OUT OF SERVICE\" sign is clearly visible on the front of the bus, indicating its current status.\n* **Street Scene:** The surrounding environment includes trees, buildings, and power lines, creating a sense of context and atmosphere.\n* **Lighting:** The image is characterized by a misty or foggy quality, with soft lighting that adds to the overall mood.\n\n**Overall Impression:**\n\nThe image presents a serene and somewhat melancholic scene, with the out-of-service trolley bus serving as a focal point. The misty atmosphere and soft lighting contribute to a contemplative ambiance, inviting the viewer to reflect on the situation."}, "finish_reason": "stop"}]} +``` diff --git a/examples/multimodal_v1/components/encode_worker.py b/examples/multimodal_v1/components/encode_worker.py new file mode 100644 index 0000000000..7294df9408 --- /dev/null +++ b/examples/multimodal_v1/components/encode_worker.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import asyncio
+import logging
+import os
+import signal
+import sys
+from typing import AsyncIterator, Tuple
+
+import torch
+import uvloop
+from transformers import AutoImageProcessor, LlavaForConditionalGeneration
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+from dynamo.runtime import DistributedRuntime, dynamo_worker
+from dynamo.runtime.logging import configure_dynamo_logging
+
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
+import connect
+from utils.args import Config, base_parse_args, parse_endpoint
+from utils.image_loader import ImageLoader
+from utils.protocol import MyRequestOutput, vLLMMultimodalRequest
+
+configure_dynamo_logging()
+logger = logging.getLogger(__name__)
+
+try:
+    import cupy as array_module
+
+    if not array_module.cuda.is_available():
+        raise ImportError("CUDA is not available.")
+    DEVICE = "cuda"
+    logger.info("Using cupy for array operations (GPU mode).")
+except ImportError as e:
+    logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
+    import numpy as array_module
+
+    DEVICE = "cpu"
+
+CACHE_SIZE_MAXIMUM = 8
+
+
+class VllmEncodeWorker:
+    def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs) -> None:
+        self.downstream_endpoint = args.downstream_endpoint
+        self.engine_args = engine_args
+        self.model = self.engine_args.model
+
+        self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
+        self.image_processor = AutoImageProcessor.from_pretrained(
+            self.model, trust_remote_code=True
+        )
+        # self.vision_model = load_vision_model(self.model)
+        self.vision_model = LlavaForConditionalGeneration.from_pretrained(
+            self.model, device_map="auto", torch_dtype=torch.float16
+        ).eval()
+
+        self.min_workers = 1
+
+    def cleanup(self):
+        pass
+
+    async def generate(
+        self, request: vLLMMultimodalRequest
+    ) -> AsyncIterator[MyRequestOutput]:
+        logger.debug(f"Got raw request: {request}")
+        if not isinstance(request, vLLMMultimodalRequest):
+            if isinstance(request, str):
+                request = vLLMMultimodalRequest.model_validate_json(request)
+            else:
+                request = vLLMMultimodalRequest.model_validate(request)
+        logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")
+
+        request_id = request.request_id
+
+        # The following steps encode the requested image and provide the embeddings.
+        # 1. Open the image from the provided URL.
+        # 2. Process the image using the image processor.
+        # 3. Run the image through the vision model's vision tower.
+        # 4. Run the results of the vision tower through the multi-modal projector.
+        # 5. Create a descriptor for the embeddings.
+        # 6. Create a readable operation from the descriptor and attach its serialized metadata to the request.
+        # 7. Wait for the readable operation to complete (the downstream worker has read the embeddings).
+        # 8. Yield the encode response.
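+        #
+        # Steps 5-7 use the NIXL-backed connect library: the embeddings stay in
+        # local memory and are exposed as a readable buffer that the downstream
+        # worker pulls over RDMA, while only the small serialized request and
+        # completion signaling travel over NATS.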
+ + try: + image = await self.image_loader.load_image(request.image_url) + + logger.debug(f"Processing image for request: {{ id: {request_id} }}") + image_embeds = self.image_processor(images=image, return_tensors="pt") + # [gluo NOTE] The commented section is for VLM generalization support, + # will use more generic approach once utils/model.py is fixed, + # see utils/models.py for details. + # # Add a batch dimension to everything + # for item in image_embeds: + # image_embeds[item] = image_embeds[item].unsqueeze(0).to(DEVICE) + # logger.debug(f"Image embeds: {image_embeds}") + + # image_grid_thw = ( + # image_embeds["image_grid_thw"].tolist() + # if "image_grid_thw" in image_embeds + # else None + # ) + # image_sizes = ( + # image_embeds["image_sizes"].tolist() + # if "image_sizes" in image_embeds + # else [image.size] + # ) + # logger.debug( + # f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}" + # ) + + # with torch.no_grad(): + # embeddings = self.vision_model.get_multimodal_embeddings(**image_embeds) + # if isinstance(embeddings, tuple) or isinstance(embeddings, list): + # # The result multimodal_embeddings may be a list or tuple of tensors, with each + # # tensor corresponding to a multimodal data item (image or video). + # # TODO: for multi-image support, this result will contain multiple tensors. + # embeddings = embeddings[0].unsqueeze(0) + # logger.debug( + # f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}." + # ) + + with torch.no_grad(): + logger.debug(f"Vision model device: {self.vision_model.device}") + vision_outputs = self.vision_model.vision_tower( + image_embeds["pixel_values"].to(self.vision_model.device) + ) + logger.debug("Vision model completed.") + + embeddings = vision_outputs.last_hidden_state + embeddings = self.vision_model.multi_modal_projector(embeddings) + + descriptor = connect.Descriptor(embeddings) + + with self._connector.create_readable(descriptor) as readable: + request.serialized_request = readable.to_serialized() + # Clear the image URL as hint that the image is passed as embeddings. + request.image_url = None + + logger.debug(f"Request: {request.model_dump_json()}") + + # Get the response generator + response_generator = await self.pd_worker_client.round_robin( + request.model_dump_json() + ) + await readable.wait_for_completion() + + async for response in response_generator: + output = MyRequestOutput.model_validate_json(response.data()) + yield MyRequestOutput( + request_id=output.request_id, + prompt=output.prompt, + prompt_token_ids=output.prompt_token_ids, + prompt_logprobs=output.prompt_logprobs, + outputs=output.outputs, + finished=output.finished, + ).model_dump_json() + + except Exception as e: + logger.error(f"Error processing request {request_id}: {e}") + raise + + async def async_init(self, runtime: DistributedRuntime): + logger.info("Startup started.") + parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint( + self.downstream_endpoint + ) + self.pd_worker_client = ( + await runtime.namespace(parsed_namespace) + .component(parsed_component_name) + .endpoint(parsed_endpoint_name) + .client() + ) + + # Create and initialize a dynamo connector for this worker. 
+ # We'll needs this to move data between this worker and remote workers efficiently. + self._connector = connect.Connector(runtime=runtime, namespace=parsed_namespace) + await self._connector.initialize() + + logger.info("Startup completed.") + + @classmethod + def parse_args(cls) -> Tuple[argparse.Namespace, Config]: + DEFAULT_ENDPOINT = "dyn://dynamo.encoder.generate" + DEFAULT_DOWNSTREAM_ENDPOINT = "dyn://dynamo.llm.generate" + + parser = FlexibleArgumentParser( + description="vLLM based encoder for Dynamo LLM." + ) + parser.add_argument( + "--endpoint", + type=str, + default=DEFAULT_ENDPOINT, + help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'", + ) + parser.add_argument( + "--downstream-endpoint", + type=str, + default=DEFAULT_DOWNSTREAM_ENDPOINT, + help=f"The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'", + ) + + args, config = base_parse_args(parser) + + return args, config + + +async def graceful_shutdown(runtime): + """ + By calling `runtime.shutdown()`, the endpoints will immediately be unavailable. + However, in-flight requests will still be processed until they are finished. + After all in-flight requests are finished, the `serve_endpoint` functions will return + and the engine will be shutdown by Python's garbage collector. + """ + logging.info("Received shutdown signal, shutting down DistributedRuntime") + runtime.shutdown() + logging.info("DistributedRuntime shutdown complete") + + +@dynamo_worker(static=False) +async def worker(runtime: DistributedRuntime): + # Runtime setup + # Set up signal handler for graceful shutdown + loop = asyncio.get_running_loop() + + def signal_handler(): + asyncio.create_task(graceful_shutdown(runtime)) + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + logging.info("Signal handlers set up for graceful shutdown") + + # worker setup + args, config = VllmEncodeWorker.parse_args() + await init(runtime, args, config) + + +async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config): + """ + Instantiate and serve + """ + + component = runtime.namespace(config.namespace).component(config.component) + await component.create_service() + + generate_endpoint = component.endpoint(config.endpoint) + + handler = VllmEncodeWorker(args, config.engine_args) + await handler.async_init(runtime) + + logger.info(f"Starting to serve the {args.endpoint} endpoint...") + + try: + await asyncio.gather( + generate_endpoint.serve_endpoint(handler.generate), + ) + except Exception as e: + logger.error(f"Failed to serve endpoints: {e}") + raise + finally: + handler.cleanup() + + +if __name__ == "__main__": + uvloop.install() + asyncio.run(worker()) diff --git a/examples/multimodal_v1/components/processor.py b/examples/multimodal_v1/components/processor.py new file mode 100644 index 0000000000..12de2a97b6 --- /dev/null +++ b/examples/multimodal_v1/components/processor.py @@ -0,0 +1,330 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import asyncio
+import json
+import logging
+import os
+import signal
+import sys
+import uuid
+from enum import Enum
+from typing import AsyncIterator, Tuple, Union
+
+import uvloop
+from transformers import AutoTokenizer
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
+from vllm.outputs import RequestOutput
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import FlexibleArgumentParser
+
+from dynamo.llm import ModelType, register_llm
+from dynamo.runtime import DistributedRuntime, dynamo_worker
+from dynamo.runtime.logging import configure_dynamo_logging
+
+# To import example local module
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
+from utils.args import Config, base_parse_args, parse_endpoint
+from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
+from utils.protocol import MultiModalRequest, MyRequestOutput, vLLMMultimodalRequest
+
+configure_dynamo_logging()
+logger = logging.getLogger(__name__)
+
+prompt_template = "USER: <image>\n<prompt> ASSISTANT:"
+
+
+class RequestType(Enum):
+    CHAT = "chat"
+    COMPLETION = "completion"
+
+
+class Processor(ProcessMixIn):
+    """
+    vLLM pre and post processing
+    """
+
+    @classmethod
+    def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
+        DEFAULT_ENDPOINT = "dyn://dynamo.processor.generate"
+        DEFAULT_DOWNSTREAM_ENDPOINT = "dyn://dynamo.encoder.generate"
+
+        parser = FlexibleArgumentParser(
+            description="vLLM based processor for Dynamo LLM."
+        )
+        parser.add_argument(
+            "--prompt-template",
+            type=str,
+            required=True,
+            help=(
+                "Different multi-modal models expect the prompt to contain different special media prompts. "
+                "The processor will use this argument to construct the final prompt. "
+                "The user prompt will replace '<prompt>' in the provided template. "
+                "For example, if the user prompt is 'please describe the image' and the prompt template is "
+                "'USER: <prompt> ASSISTANT:', the resulting prompt is "
+                "'USER: please describe the image ASSISTANT:'."
+            ),
+        )
+        parser.add_argument(
+            "--endpoint",
+            type=str,
+            default=DEFAULT_ENDPOINT,
+            help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_ENDPOINT}'",
+        )
+        parser.add_argument(
+            "--downstream-endpoint",
+            type=str,
+            default=DEFAULT_DOWNSTREAM_ENDPOINT,
+            help=f"The endpoint string of the downstream encoder in 'dyn://namespace.component.endpoint' format. Default: '{DEFAULT_DOWNSTREAM_ENDPOINT}'",
+        )
+
+        args, config = base_parse_args(parser)
+
+        return args, config
+
+    def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs):
+        self.prompt_template = args.prompt_template
+        self.downstream_endpoint = args.downstream_endpoint
+        self.engine_args = engine_args
+        self.model_config = self.engine_args.create_model_config()
+        self.default_sampling_params = self.model_config.get_diff_sampling_param()
+        self.tokenizer = self._create_tokenizer(self.engine_args)
+        self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
+        self.completions_processor = CompletionsProcessor(
+            self.tokenizer, self.model_config
+        )
+
+    def cleanup(self):
+        pass
+
+    def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
+        """Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
+        model_path = engine_args.model
+
+        # Create the base tokenizer with VLLM's typical settings
+        base_tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            padding_side="left",
+            truncation_side="left",
+            use_fast=True,  # VLLM might use the fast tokenizer for efficiency
+        )
+        return base_tokenizer
+
+    async def async_init(self, runtime: DistributedRuntime):
+        parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
+            self.downstream_endpoint
+        )
+        self.encode_worker_client = (
+            await runtime.namespace(parsed_namespace)
+            .component(parsed_component_name)
+            .endpoint(parsed_endpoint_name)
+            .client()
+        )
+
+    # Main method to parse the request and send it to the vLLM worker.
+    async def _generate(
+        self,
+        raw_request: Union[CompletionRequest, ChatCompletionRequest],
+        image: str,
+        request_type: RequestType,
+    ):
+        request_id = str(uuid.uuid4().hex)
+        logger.debug(f"Got raw request: {raw_request}")
+        (
+            request,
+            conversation,
+            prompt,
+            engine_prompt,
+            sampling_params,
+        ) = await self._parse_raw_request(raw_request)
+
+        worker_request = vLLMMultimodalRequest(
+            engine_prompt=engine_prompt,
+            sampling_params=sampling_params,
+            request_id=request_id,
+            image_url=image,
+        )
+
+        # model_dump_json() serializes the request to a JSON string.
+        # This API could accept a Pydantic class, but SamplingParams
+        # in vLLMMultimodalRequest is not a Pydantic class and would
+        # cause "TypeError: unsupported type SamplingParams".
+        response_generator = await self.encode_worker_client.round_robin(
+            worker_request.model_dump_json()
+        )
+
+        output = self._generate_responses(response_generator, request_type)
+
+        # Stream the processed responses
+        async for response in await self._stream_response(
+            request, output, request_id, conversation
+        ):
+            yield response
+
+    # This method is used to process the responses from the engine generator.
+    async def _generate_responses(
+        self,
+        response_generator: AsyncIterator[RequestOutput],
+        request_type: RequestType,
+    ):
+        async for resp in response_generator:
+            # Deserialize the response from the engine.
+            # Creates correct vLLM objects for each field.
+            output = MyRequestOutput.model_validate_json(resp.data())
+
+            # OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
+            request_output = RequestOutput(
+                request_id=output.request_id,
+                prompt=output.prompt,
+                prompt_token_ids=output.prompt_token_ids,
+                prompt_logprobs=output.prompt_logprobs,
+                outputs=output.outputs,
+                finished=output.finished,
+                metrics=output.metrics,
+            )
+
+            if request_type == RequestType.CHAT:
+                # For chat requests, yield the request_output directly.
+                yield request_output
+            else:
+                raise NotImplementedError(
+                    f"Request type {request_type} not implemented"
+                )
+
+    # The generate endpoint will be used by the frontend to handle incoming requests.
+    async def generate(self, raw_request: MultiModalRequest):
+        logger.debug(f"Got raw request: {raw_request}")
+        if not isinstance(raw_request, MultiModalRequest):
+            # If the request is not a MultiModalRequest, convert it to one
+            raw_request = MultiModalRequest.model_validate(raw_request)
+
+        # Ensure the configured template includes the placeholder
+        template = self.prompt_template
+        if "<prompt>" not in template:
+            raise ValueError("prompt_template must contain '<prompt>' placeholder")
+
+        # Safely extract user text
+        try:
+            user_text = raw_request.messages[0].content[0].text
+        except (IndexError, AttributeError) as e:
+            raise ValueError(f"Invalid message structure: {e}")
+
+        prompt = template.replace("<prompt>", user_text)
+
+        msg = {
+            "role": "user",
+            "content": prompt,
+        }
+
+        chat_request = ChatCompletionRequest(
+            model=raw_request.model,
+            messages=[msg],
+            stream=raw_request.stream,
+            max_tokens=raw_request.max_tokens,
+            temperature=raw_request.temperature,
+            request_id=str(uuid.uuid4()),
+        )
+        image_url = None
+
+        for message in raw_request.messages:
+            for item in message.content:
+                if item.type == "image_url":
+                    image_url = item.image_url.url
+        if image_url is None:
+            raise ValueError("Image URL is required")
+
+        async for response in self._generate(chat_request, image_url, RequestType.CHAT):
+            logger.debug(
+                f"Generated response type {type(response)}, content: {response}"
+            )
+            # Reconstruct the OpenAI chat response, as dynamo egress expects it
+            if response.startswith("data: [DONE]"):
+                break
+            # Strip the SSE "data: " prefix before parsing the JSON payload
+            response = json.loads(response.removeprefix("data: "))
+            yield response
+
+
+async def graceful_shutdown(runtime):
+    """
+    By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
+    However, in-flight requests will still be processed until they are finished.
+    After all in-flight requests are finished, the `serve_endpoint` functions will return
+    and the engine will be shut down by Python's garbage collector.
+ """ + logging.info("Received shutdown signal, shutting down DistributedRuntime") + runtime.shutdown() + logging.info("DistributedRuntime shutdown complete") + + +@dynamo_worker(static=False) +async def worker(runtime: DistributedRuntime): + # Runtime setup + # Set up signal handler for graceful shutdown + loop = asyncio.get_running_loop() + + def signal_handler(): + asyncio.create_task(graceful_shutdown(runtime)) + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + logging.info("Signal handlers set up for graceful shutdown") + + # worker setup + args, config = Processor.parse_args() + await init(runtime, args, config) + + +async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config): + """ + Instantiate and serve + """ + + component = runtime.namespace(config.namespace).component(config.component) + await component.create_service() + + generate_endpoint = component.endpoint(config.endpoint) + + handler = Processor(args, config.engine_args) + await handler.async_init(runtime) + + # Register the endpoint as entrypoint to a model + await register_llm( + ModelType.Chat, # Custom processor is used and this type bypasses SDK processor + generate_endpoint, + config.model, + config.served_model_name, + kv_cache_block_size=config.engine_args.block_size, + ) + + logger.info(f"Starting to serve the {args.endpoint} endpoint...") + + try: + await asyncio.gather( + generate_endpoint.serve_endpoint(handler.generate), + ) + except Exception as e: + logger.error(f"Failed to serve endpoints: {e}") + raise + finally: + handler.cleanup() + + +if __name__ == "__main__": + uvloop.install() + asyncio.run(worker()) diff --git a/examples/multimodal_v1/components/publisher.py b/examples/multimodal_v1/components/publisher.py new file mode 100644 index 0000000000..4a7e5eee36 --- /dev/null +++ b/examples/multimodal_v1/components/publisher.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from vllm.config import VllmConfig +from vllm.v1.metrics.loggers import StatLoggerBase +from vllm.v1.metrics.stats import IterationStats, SchedulerStats + +from dynamo.llm import ( + ForwardPassMetrics, + KvStats, + SpecDecodeStats, + WorkerMetricsPublisher, + WorkerStats, +) +from dynamo.runtime import Component + + +class NullStatLogger(StatLoggerBase): + def __init__(self): + pass + + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ): + pass + + def log_engine_initialized(self): + pass + + +class DynamoStatLoggerPublisher(StatLoggerBase): + """Stat logger publisher. 
Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface.""" + + def __init__(self, component: Component, dp_rank: int) -> None: + self.inner = WorkerMetricsPublisher() + self.inner.create_endpoint(component) + self.dp_rank = dp_rank + self.num_gpu_block = 1 + self.request_total_slots = 1 + + # TODO: Remove this and pass as metadata through etcd + def set_num_gpu_block(self, num_blocks): + self.num_gpu_block = num_blocks + + # TODO: Remove this and pass as metadata through etcd + def set_num_request_total_slots(self, request_total_slots): + self.request_total_slots = request_total_slots + + def record( + self, + scheduler_stats: SchedulerStats, + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ): + # request_total_slots and kv_total_blocks are properties of model + gpu + # we should only publish them once, not every metric update + # they should be part of some runtime metadata tied to MDC or put in etcd ? + hit_rate = 0 + if scheduler_stats.prefix_cache_stats.queries > 0: + hit_rate = ( + scheduler_stats.prefix_cache_stats.hits + / scheduler_stats.prefix_cache_stats.queries + ) + + worker_stats = WorkerStats( + request_active_slots=scheduler_stats.num_running_reqs, + request_total_slots=self.request_total_slots, + num_requests_waiting=scheduler_stats.num_waiting_reqs, + data_parallel_rank=self.dp_rank, + ) + + kv_stats = KvStats( + kv_active_blocks=int(self.num_gpu_block * scheduler_stats.kv_cache_usage), + kv_total_blocks=self.num_gpu_block, + gpu_cache_usage_perc=scheduler_stats.kv_cache_usage, + gpu_prefix_cache_hit_rate=hit_rate, # TODO: This is a point in time update, not cumulative. Will be problematic on router side if we try to use it. + ) + + spec_dec_stats = scheduler_stats.spec_decoding_stats + if spec_dec_stats: + spec_dec_stats = SpecDecodeStats( + num_spec_tokens=spec_dec_stats.num_spec_tokens, + num_drafts=spec_dec_stats.num_drafts, + num_draft_tokens=spec_dec_stats.num_draft_tokens, + num_accepted_tokens=spec_dec_stats.num_accepted_tokens, + num_accepted_tokens_per_pos=spec_dec_stats.num_accepted_tokens_per_pos, + ) + + metrics = ForwardPassMetrics( + worker_stats=worker_stats, + kv_stats=kv_stats, + spec_decode_stats=spec_dec_stats, + ) + + self.inner.publish(metrics) + + def init_publish(self): + worker_stats = WorkerStats( + request_active_slots=0, + request_total_slots=self.request_total_slots, + num_requests_waiting=0, + data_parallel_rank=self.dp_rank, + ) + + kv_stats = KvStats( + kv_active_blocks=0, + kv_total_blocks=self.num_gpu_block, + gpu_cache_usage_perc=0, + gpu_prefix_cache_hit_rate=0, + ) + + metrics = ForwardPassMetrics( + worker_stats=worker_stats, + kv_stats=kv_stats, + spec_decode_stats=None, + ) + + self.inner.publish(metrics) + + def log_engine_initialized(self) -> None: + pass + + +class StatLoggerFactory: + """Factory for creating stat logger publishers. 
Required by vLLM.""" + + def __init__(self, component: Component, dp_rank: int = 0) -> None: + self.component = component + self.created_logger: Optional[DynamoStatLoggerPublisher] = None + self.dp_rank = dp_rank + + def create_stat_logger(self, dp_rank: int) -> StatLoggerBase: + if self.dp_rank != dp_rank: + return NullStatLogger() + logger = DynamoStatLoggerPublisher(self.component, dp_rank) + self.created_logger = logger + + return logger + + def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase: + return self.create_stat_logger(dp_rank=dp_rank) + + # TODO Remove once we publish metadata to etcd + def set_num_gpu_blocks_all(self, num_blocks): + if self.created_logger: + self.created_logger.set_num_gpu_block(num_blocks) + + def set_request_total_slots_all(self, request_total_slots): + if self.created_logger: + self.created_logger.set_num_request_total_slots(request_total_slots) + + def init_publish(self): + if self.created_logger: + self.created_logger.init_publish() diff --git a/examples/multimodal_v1/components/worker.py b/examples/multimodal_v1/components/worker.py new file mode 100644 index 0000000000..b58082e89f --- /dev/null +++ b/examples/multimodal_v1/components/worker.py @@ -0,0 +1,461 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +import copy +import logging +import os +import signal +import sys +from typing import Tuple + +import torch +import uvloop +from transformers import AutoImageProcessor +from vllm.distributed.kv_events import ZmqEventPublisher +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.inputs.data import TokensPrompt +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser +from vllm.v1.engine.async_llm import AsyncLLM + +from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig +from dynamo.runtime import Component, DistributedRuntime, Endpoint, dynamo_worker +from dynamo.runtime.logging import configure_dynamo_logging + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) +import connect +from publisher import StatLoggerFactory +from utils.args import ( + Config, + base_parse_args, + configure_ports_with_etcd, + overwrite_args, + parse_endpoint, +) +from utils.image_loader import ImageLoader +from utils.protocol import MyRequestOutput, vLLMMultimodalRequest + +configure_dynamo_logging() +logger = logging.getLogger(__name__) + + +class VllmBaseWorker: + @classmethod + def parse_args(cls) -> Tuple[argparse.Namespace, Config]: + parser = FlexibleArgumentParser( + description="vLLM based encoder for Dynamo LLM." + ) + parser.add_argument( + "--endpoint", + type=str, + help="Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. 
Default value will vary based on the worker type, see --worker-type for details.", + ) + parser.add_argument( + "--downstream-endpoint", + type=str, + help="The endpoint string of the downstream LLM in 'dyn://namespace.component.endpoint' format. Default value will vary based on the worker type, see --worker-type for details.", + ) + parser.add_argument( + "--worker-type", + type=str, + choices=["prefill", "decode", "encode_prefill"], + required=True, + help="Specify the type of worker. Must be one of: 'prefill', 'decode', 'encode_prefill'", + ) + parser.add_argument( + "--enable-disagg", + action="store_true", + help="Enable disaggregated mode, where prefill and decode are handled by separate workers." + " If not set, the '*prefill' worker type will handle both prefill and decode.", + ) + + # use endpoint_overwrite to set the default endpoint based on worker type + def endpoint_overwrite(args): + # default endpoint for this worker + if args.worker_type == "prefill": + args.endpoint = args.endpoint or "dyn://dynamo.llm.generate" + elif args.worker_type == "decode": + args.endpoint = args.endpoint or "dyn://dynamo.decoder.generate" + elif args.worker_type == "encode_prefill": + args.endpoint = args.endpoint or "dyn://dynamo.encoder.generate" + # set downstream endpoint for disaggregated workers + if args.enable_disagg: + args.downstream_endpoint = ( + args.downstream_endpoint or "dyn://dynamo.decoder.generate" + ) + + return args + + args, config = base_parse_args(parser, endpoint_overwrite) + + return args, config + + def __init__( + self, + args: argparse.Namespace, + engine_args: AsyncEngineArgs, + component: Component, + endpoint: Endpoint, + ): + self.enable_disagg = args.enable_disagg + self.endpoint = args.endpoint + self.downstream_endpoint = args.downstream_endpoint + self.engine_args = engine_args + self.setup_vllm_engine(component, endpoint) + + async def async_init(self, runtime: DistributedRuntime): + pass + + def setup_vllm_engine(self, component: Component, endpoint: Endpoint): + """Initialize the vLLM engine. + This method sets up the vLLM engine client, and configures the dynamo-aware KV + event publisher and metrics stats logger based on component and endpoint. + """ + + os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + # Load default sampling params from `generation_config.json` + self.default_sampling_params = ( + self.engine_args.create_model_config().get_diff_sampling_param() + ) + + # Taken from build_async_engine_client_from_engine_args() + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = self.engine_args.create_engine_config(usage_context=usage_context) + + # Create vLLM engine with metrics logger and KV event publisher attached + self.stats_logger = StatLoggerFactory( + component, self.engine_args.data_parallel_rank or 0 + ) + self.engine_client = AsyncLLM.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + stat_loggers=[self.stats_logger], + disable_log_requests=self.engine_args.disable_log_requests, + disable_log_stats=self.engine_args.disable_log_stats, + ) + + # TODO Hack to get data, move this to registering in ETCD + self.stats_logger.set_num_gpu_blocks_all( + vllm_config.cache_config.num_gpu_blocks + ) + self.stats_logger.set_request_total_slots_all( + vllm_config.scheduler_config.max_num_seqs + ) + self.stats_logger.init_publish() + + # TODO: We start off with a valid endpoint, then we increment it by dp_rank + # May no longer be valid. 
Lets remove the increment behavior from vLLM and here + zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( + self.engine_args.kv_events_config.endpoint, + data_parallel_rank=self.engine_args.data_parallel_rank or 0, + ).replace("*", "127.0.0.1") + + zmq_config = ZmqKvEventPublisherConfig( + worker_id=endpoint.lease_id(), + kv_block_size=vllm_config.cache_config.block_size, + zmq_endpoint=zmq_endpoint, + ) + self.kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config) + + logger.info(f"Reading Events from {zmq_endpoint}") + + logger.info(f"VllmWorker for {self.engine_args.model} has been initialized") + + async def generate(self, request: vLLMMultimodalRequest): + raise NotImplementedError( + "This method should be implemented in subclasses to handle the generation logic." + ) + + async def clear_kv_blocks(self, request=None): + try: + await self.engine_client.reset_prefix_cache() + yield {"status": "success", "message": "KV cache cleared"} + except Exception as e: + yield {"status": "error", "message": str(e)} + + def cleanup(self): + """Override in subclasses if cleanup is needed.""" + pass + + +class VllmDecodeWorker(VllmBaseWorker): + async def generate(self, request: vLLMMultimodalRequest): + logger.debug(f"Got raw request: {request}") + if not isinstance(request, vLLMMultimodalRequest): + if isinstance(request, str): + request = vLLMMultimodalRequest.model_validate_json(request) + else: + request = vLLMMultimodalRequest.model_validate(request) + logger.debug(f"Received decode request: {{ id: {request.request_id} }}.") + + # Decode worker doesn't process embeddings, so we pass None or empty tensor + gen = self.engine_client.generate( + prompt=TokensPrompt( + prompt_token_ids=request.engine_prompt["prompt_token_ids"], + ), + sampling_params=request.sampling_params, + request_id=request.request_id, + ) + + async for response in gen: + logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}") + yield MyRequestOutput( + request_id=response.request_id, + prompt=response.prompt, + prompt_token_ids=response.prompt_token_ids, + prompt_logprobs=response.prompt_logprobs, + outputs=response.outputs, + finished=response.finished, + metrics=response.metrics, + kv_transfer_params=response.kv_transfer_params, + ).model_dump_json() + + +class VllmPDWorker(VllmBaseWorker): + async def async_init(self, runtime: DistributedRuntime): + logger.info("Startup started.") + + if self.enable_disagg: + ( + parsed_namespace, + parsed_component_name, + parsed_endpoint_name, + ) = parse_endpoint(self.downstream_endpoint) + self.decode_worker_client = ( + await runtime.namespace(parsed_namespace) + .component(parsed_component_name) + .endpoint(parsed_endpoint_name) + .client() + ) + + EMBEDDINGS_DTYPE = torch.float16 + EMBEDDINGS_DEVICE = "cpu" + # Create and initialize a dynamo connector for this worker. + # We'll needs this to move data between this worker and remote workers efficiently. + parsed_namespace, _, _ = parse_endpoint(self.endpoint) + self._connector = connect.Connector(runtime=runtime, namespace=parsed_namespace) + await self._connector.initialize() + + # embeddings_shape, self.embeddings_dtype = get_vision_embeddings_info( + # self.engine_args.model, self.engine_args.num_patches + # ) + # [gluo NOTE] Hardcoded for now, will use more generic approach once utils/model.py + # is fixed, see utils/models.py for details. 
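+        # The hardcoded shape assumes llava-hf/llava-1.5-7b-hf: 577 = 576 image
+        # patch tokens (a 24x24 grid from a 336x336 input with 14x14 patches)
+        # plus 1 CLS token from the CLIP vision tower, and 4096 is the language
+        # model's hidden size.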
+ embeddings_shape = (1, 577, 4096) + logger.debug(f"Embeddings shape: {embeddings_shape}") + self.embedding_size = embeddings_shape[1] + + embeddings = torch.empty( + embeddings_shape, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE + ) + + descriptor = connect.Descriptor(embeddings) + + # Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically). + # descriptor.register_memory(self._connector) + self._embeddings_descriptor = (embeddings, descriptor) + + self.image_loader = ImageLoader() + self.image_processor = AutoImageProcessor.from_pretrained( + self.engine_args.model, trust_remote_code=True + ) + + logger.info("VllmPDWorker has been initialized") + + async def generate(self, request: vLLMMultimodalRequest): + logger.debug(f"Got raw request: {request}") + if type(request) is not vLLMMultimodalRequest: + if type(request) is str: + request = vLLMMultimodalRequest.model_validate_json(request) + else: + request = vLLMMultimodalRequest.model_validate(request) + logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") + + if request.image_url is None: + # Process embeddings using the connector + embeddings, descriptor = self._embeddings_descriptor + + if descriptor is None: + raise RuntimeError( + "Descriptor is None in PD worker - cannot process embeddings" + ) + + read_op = await self._connector.begin_read( + request.serialized_request, descriptor + ) + await read_op.wait_for_completion() + logger.debug(f"in PD worker, image features: {embeddings}") + multi_modal_data = embeddings + else: + # Use PIL image instead of image embeddings + multi_modal_data = await self.image_loader.load_image(request.image_url) + # multi_modal_data = self.image_processor(images=image, return_tensors="pt")["pixel_values"].to(dtype=torch.float16) + # image input is expected to be (image_num, channel, height, width) + # logger.info(f"Image features shape: {multi_modal_data.shape}") + # multi_modal_data = multi_modal_data.unsqueeze(0) + + # Remove the image features from the request as they are not required + request.image_url = None + request.serialized_request = None + + pd_request = copy.deepcopy(request) + # Do prefill and remote decode if enable_disagg is true + if self.enable_disagg: + extra_args = pd_request.sampling_params.extra_args or {} + extra_args["kv_transfer_params"] = { + "do_remote_decode": True, + } + pd_request.sampling_params.extra_args = extra_args + pd_request.sampling_params.max_tokens = 1 + pd_request.sampling_params.min_tokens = 1 + + logger.debug("Prefill request: %s", pd_request) + + gen = self.engine_client.generate( + prompt=TokensPrompt( + prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"], + multi_modal_data={"image": multi_modal_data}, + ), + sampling_params=pd_request.sampling_params, + request_id=pd_request.request_id, + ) + + if self.enable_disagg: + decode_request = copy.deepcopy(request) + async for prefill_response in gen: + # Update the prompt token id in the decode request to the one + # in response, which has image templated filled in. So that + # the decode worker will fetch correct amount of KV blocks. 
+ decode_request.engine_prompt[ + "prompt_token_ids" + ] = prefill_response.prompt_token_ids + logger.debug( + f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}" + ) + extra_args = decode_request.sampling_params.extra_args or {} + extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params + extra_args.pop("serialized_request", None) + decode_request.sampling_params.extra_args = extra_args + logger.debug("Decode request: %s", decode_request) + async for decode_response in await self.decode_worker_client.round_robin( + decode_request.model_dump_json() + ): + output = MyRequestOutput.model_validate_json(decode_response.data()) + yield MyRequestOutput( + request_id=output.request_id, + prompt=output.prompt, + prompt_token_ids=output.prompt_token_ids, + prompt_logprobs=output.prompt_logprobs, + outputs=output.outputs, + finished=output.finished, + metrics=output.metrics, + kv_transfer_params=output.kv_transfer_params, + ).model_dump_json() + + else: + async for response in gen: + logger.debug( + f"Response kv_transfer_params: {response.kv_transfer_params}" + ) + yield MyRequestOutput( + request_id=response.request_id, + prompt=response.prompt, + prompt_token_ids=response.prompt_token_ids, + prompt_logprobs=response.prompt_logprobs, + outputs=response.outputs, + finished=response.finished, + metrics=response.metrics, + kv_transfer_params=response.kv_transfer_params, + ).model_dump_json() + + +async def graceful_shutdown(runtime): + """ + By calling `runtime.shutdown()`, the endpoints will immediately be unavailable. + However, in-flight requests will still be processed until they are finished. + After all in-flight requests are finished, the `serve_endpoint` functions will return + and the engine will be shutdown by Python's garbage collector. 
+    """
+    logging.info("Received shutdown signal, shutting down DistributedRuntime")
+    runtime.shutdown()
+    logging.info("DistributedRuntime shutdown complete")
+
+
+@dynamo_worker(static=False)
+async def worker(runtime: DistributedRuntime):
+    # Runtime setup
+    # Set up signal handlers for graceful shutdown
+    loop = asyncio.get_running_loop()
+
+    def signal_handler():
+        asyncio.create_task(graceful_shutdown(runtime))
+
+    for sig in (signal.SIGTERM, signal.SIGINT):
+        loop.add_signal_handler(sig, signal_handler)
+
+    logging.info("Signal handlers set up for graceful shutdown")
+
+    # Worker setup
+    args, config = VllmBaseWorker.parse_args()
+
+    # vLLM config overwrites
+    etcd_client = runtime.etcd_client()
+    await configure_ports_with_etcd(config, etcd_client)
+    overwrite_args(config)
+    await init(runtime, args, config)
+
+
+async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
+    """
+    Instantiate and serve
+    """
+
+    component = runtime.namespace(config.namespace).component(config.component)
+    await component.create_service()
+
+    generate_endpoint = component.endpoint(config.endpoint)
+    clear_endpoint = component.endpoint("clear_kv_blocks")
+
+    if args.worker_type in ["prefill", "encode_prefill"]:
+        handler: VllmBaseWorker = VllmPDWorker(
+            args, config.engine_args, component, generate_endpoint
+        )
+    elif args.worker_type == "decode":
+        handler = VllmDecodeWorker(
+            args, config.engine_args, component, generate_endpoint
+        )
+    await handler.async_init(runtime)
+
+    logger.info(f"Starting to serve the {args.endpoint} endpoint...")
+
+    try:
+        await asyncio.gather(
+            generate_endpoint.serve_endpoint(handler.generate),
+            clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
+        )
+    except Exception as e:
+        logger.error(f"Failed to serve endpoints: {e}")
+        raise
+    finally:
+        handler.cleanup()
+
+
+if __name__ == "__main__":
+    uvloop.install()
+    asyncio.run(worker())
diff --git a/examples/multimodal_v1/connect/README.md b/examples/multimodal_v1/connect/README.md
new file mode 100644
index 0000000000..e8df057d76
--- /dev/null
+++ b/examples/multimodal_v1/connect/README.md
@@ -0,0 +1,342 @@
+
+# Dynamo Connect
+
+Dynamo Connect provides a Pythonic interface to the NIXL-based RDMA subsystem via a set of Python classes.
+The primary goal of this library is to simplify the integration of NIXL-based RDMA into inference applications.
+
+All operations using the Connect library begin with the [`Connector`](#connector) class and the type of operation required.
+There are four types of supported operations:
+
+  - **Register local readable memory**:
+
+    Register local memory buffer(s) with the RDMA subsystem so that a remote worker can read from them.
+
+  - **Register local writable memory**:
+
+    Register local memory buffer(s) with the RDMA subsystem so that a remote worker can write to them.
+
+  - **Read from registered, remote memory**:
+
+    Read remote memory buffer(s), registered by a remote worker to be readable, into local memory buffer(s).
+
+  - **Write to registered, remote memory**:
+
+    Write local memory buffer(s) to remote memory buffer(s) registered by a remote worker to be writable.
+
+By connecting correctly paired operations, high-throughput GPU Direct RDMA data transfers can be completed.
+Given the list above, the correct pairings of operations are 1 & 3 or 2 & 4,
+where one side is a "(read|write)-able operation" and the other is its correctly paired "(read|write) operation".
+Specifically, a read operation must be paired with a readable operation, and a write operation must be paired with a writable operation.
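+
+As a concrete illustration, the following minimal sketch pairs operations 1 & 3: one worker registers a tensor as readable, and another reads it.
+It assumes both workers already hold an initialized `Connector`, and that `send_metadata` / `recv_metadata` are placeholders for a side channel
+(e.g., HTTP or NATS) that you provide; the shapes and dtypes are arbitrary.
+
+```python
+import torch
+
+import connect
+
+
+async def source_worker(connector: connect.Connector, send_metadata) -> None:
+    # Operation 1: register local memory as readable by a remote worker.
+    data = torch.ones((577, 4096), dtype=torch.float16, device="cuda")
+    with connector.create_readable(connect.Descriptor(data)) as readable_op:
+        # Ship the RDMA metadata to the peer over the side channel.
+        await send_metadata(readable_op.to_serialized())
+        # Block until the remote worker signals that its read has completed.
+        await readable_op.wait_for_completion()
+
+
+async def sink_worker(connector: connect.Connector, recv_metadata) -> None:
+    # Operation 3: read the registered remote memory into a matching local buffer.
+    target = torch.empty((577, 4096), dtype=torch.float16, device="cuda")
+    read_op = await connector.begin_read(await recv_metadata(), connect.Descriptor(target))
+    await read_op.wait_for_completion()  # `target` now holds the remote data
+```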
+
+## Examples
+
+### Generic Example
+
+In the diagram below, Local creates a [`WritableOperation`](#writableoperation) intended to receive data from Remote.
+Local then sends metadata about the requested RDMA operation to Remote.
+Remote then uses the metadata to create a [`WriteOperation`](#writeoperation) which will perform the GPU Direct RDMA memory transfer from Remote's GPU memory to Local's GPU memory.
+
+```mermaid
+---
+title: Write Operation Between Two Workers
+---
+flowchart LR
+  c1[Remote] --"3: .begin_write()"--- WriteOperation
+  WriteOperation e1@=="4: GPU Direct RDMA"==> WritableOperation
+  WritableOperation --"1: .create_writable()"--- c2[Local]
+  c2 e2@--"2: RDMA Metadata via HTTP"--> c1
+  e1@{ animate: true; }
+  e2@{ animate: true; }
+```
+
+### Multimodal Example
+
+In the case of the [Dynamo Multimodal Disaggregated Example](../README.md):
+
+  1. The HTTP frontend accepts a text prompt and a URL to an image.
+
+  2. The prompt and URL are then enqueued with the Processor before being dispatched to the first available Decode Worker.
+
+  3. Decode Worker then requests a Prefill Worker to provide key-value data for the LLM powering the Decode Worker.
+
+  4. Prefill Worker then requests that the image be processed and provided as embeddings by the Encode Worker.
+
+  5. Encode Worker acquires the image, processes it, performs inference on the image using a specialized vision model, and finally provides the embeddings to Prefill Worker.
+
+  6. Prefill Worker receives the embeddings from Encode Worker, generates a key-value cache (KV$) update for Decode Worker's LLM, and writes the update directly to the GPU memory reserved for the data.
+
+  7. Finally, Decode Worker performs the requested inference.
+
+```mermaid
+---
+title: Multimodal Disaggregated Workflow
+---
+flowchart LR
+  p0[HTTP Frontend] i0@--"text prompt"-->p1[Processor]
+  p0 i1@--"url"-->p1
+  p1 i2@--"prompt"-->dw[Decode Worker]
+  p1 i3@--"url"-->dw
+  dw i4@--"prompt"-->pw[Prefill Worker]
+  dw i5@--"url"-->pw
+  pw i6@--"url"-->ew[Encode Worker]
+  ew o0@=="image embeddings"==>pw
+  pw o1@=="kv_cache updates"==>dw
+  dw o2@--"inference results"-->p0
+
+  i0@{ animate: true; }
+  i1@{ animate: true; }
+  i2@{ animate: true; }
+  i3@{ animate: true; }
+  i4@{ animate: true; }
+  i5@{ animate: true; }
+  i6@{ animate: true; }
+  o0@{ animate: true; }
+  o1@{ animate: true; }
+  o2@{ animate: true; }
+```
+
+  _Note: In this example, it is the data transfer between the Prefill Worker and the Encode Worker that utilizes the Dynamo Connect library. The KV Cache transfer between Decode Worker and Prefill Worker utilizes the NIXL-based RDMA subsystem directly, without using the Dynamo Connect library._
+
+#### Code Examples
+
+See [prefill_worker](../components/prefill_worker.py#L199) or [decode_worker](../components/decode_worker.py#L239)
+for how they coordinate directly with the Encode Worker by creating a [`WritableOperation`](#writableoperation),
+sending the operation's metadata via Dynamo's round-robin dispatcher, and awaiting the operation's completion before making use of the transferred data.
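+
+The snippet below condenses that coordination into a single hedged sketch.
+The `encode_client` object and the request payload are stand-ins for this example's round-robin dispatch;
+only the `Connector`, `Descriptor`, and `WritableOperation` usage mirrors the components linked above.
+
+```python
+import torch
+
+import connect
+
+
+async def request_embeddings(connector: connect.Connector, encode_client) -> torch.Tensor:
+    # Reserve GPU memory that the Encode Worker will fill via RDMA.
+    embeddings = torch.empty((1, 577, 4096), dtype=torch.float16, device="cuda")
+
+    # Expose the buffer as a writable target for the remote worker.
+    with connector.create_writable(connect.Descriptor(embeddings)) as writable_op:
+        serialized = writable_op.to_serialized().model_dump_json()
+        # Placeholder dispatch: send the RDMA metadata alongside the request.
+        await encode_client.round_robin(serialized)
+        # Wait for the Encode Worker to signal that the write has completed.
+        await writable_op.wait_for_completion()
+
+    return embeddings
+```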
+
+See [encode_worker](../components/encode_worker.py#L190)
+for how the resulting embeddings are registered with the RDMA subsystem by creating a [`Descriptor`](#descriptor),
+how a [`WriteOperation`](#writeoperation) is created using the metadata provided by the requesting worker,
+and how the worker awaits the completion of the data transfer before yielding a response.
+
+## Python Classes
+
+### Connector
+
+Core class for managing the connection between workers in a distributed environment.
+Use this class to create readable and writable operations, or to read and write data to remote workers.
+
+This class is responsible for interfacing with the NIXL-based RDMA subsystem and providing a "Pythonic" interface
+with which to utilize GPU Direct RDMA accelerated data transfers between models hosted by different workers in a Dynamo pipeline.
+The connector provides two methods of moving data between workers:
+
+  - Preparing local memory to be written to by a remote worker.
+
+  - Preparing local memory to be read by a remote worker.
+
+In both cases, local memory is registered with the NIXL-based RDMA subsystem via the [`Descriptor`](#descriptor) class and provided to the connector.
+The connector then configures the RDMA subsystem to expose the memory for the requested operation and returns an operation control object.
+The operation control object, either a [`ReadableOperation`](#readableoperation) or a [`WritableOperation`](#writableoperation),
+provides RDMA metadata via its [`.to_serialized()`](#to_serialized) method, as well as functionality to detect when the operation has completed or to cancel it prior to completion.
+
+The RDMA metadata must be provided to the remote worker expected to complete the operation.
+The metadata contains the required information (identifiers, keys, etc.) which enables the remote worker to interact with the provided memory.
+
+#### Methods
+
+##### `begin_read`
+
+> Creates a [`ReadOperation`](#readoperation) for transferring data from a remote worker.
+>
+> To create the operation, the serialized request from a remote worker's [`ReadableOperation`](#readableoperation),
+> along with a matching set of local memory descriptors which reference memory intended to receive data from the remote worker,
+> must be provided.
+> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
+>
+> Once created, the operation will begin reading immediately.
+> Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
+> therefore the operation should be awaited until complete, or deleted prior to completion when cancellation is intended.
+
+##### `begin_write`
+
+> Creates a write operation for transferring data to a remote worker.
+>
+> To create the operation, the serialized request from a remote worker's [`WritableOperation`](#writableoperation),
+> along with a matching set of local memory descriptors which reference memory to be transferred to the remote worker,
+> must be provided.
+> The serialized request must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
+>
+> Once created, the operation will begin writing immediately.
+> Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
+> therefore the operation should be awaited until complete, or deleted prior to completion when cancellation is intended.
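+
+As a usage illustration, here is a hedged sketch of the active side: given the serialized request received from a remote worker's
+[`WritableOperation`](#writableoperation), `begin_write` starts the transfer immediately and is then awaited.
+The embedding shape shown is arbitrary.
+
+```python
+import torch
+
+import connect
+
+
+async def push_embeddings(connector: connect.Connector, serialized_request) -> None:
+    # Embeddings produced locally; shape, dtype, and device must match the
+    # buffers registered by the remote worker's WritableOperation.
+    embeddings = torch.randn((1, 577, 4096), dtype=torch.float16, device="cuda")
+
+    # Pair the remote metadata with matching local descriptors; the RDMA
+    # write begins as soon as the operation is created.
+    write_op = await connector.begin_write(connect.Descriptor(embeddings), serialized_request)
+
+    # Await completion; deleting the operation early would cancel the write.
+    await write_op.wait_for_completion()
+```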
+
+##### `create_readable`
+
+> Creates a [`ReadableOperation`](#readableoperation) for transferring data to a remote worker.
+>
+> To create the operation, a set of local memory descriptors must be provided that reference memory intended to be transferred to
+> a remote worker.
+> Once created, the memory referenced by the provided descriptors becomes immediately readable by any remote worker with the necessary metadata.
+> The metadata required to access the memory referenced by the provided descriptors is accessible via the operation's `.to_serialized()` method.
+> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
+>
+> Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
+> therefore the operation should be awaited until complete, or deleted prior to completion when cancellation is intended.
+
+##### `create_writable`
+
+> Creates a [`WritableOperation`](#writableoperation) for transferring data from a remote worker.
+>
+> To create the operation, a set of local memory descriptors must be provided which reference memory intended to receive data from
+> a remote worker.
+> Once created, the memory referenced by the provided descriptors becomes immediately writable by any remote worker with the necessary metadata.
+> The metadata required to access the memory referenced by the provided descriptors is accessible via the operation's `.to_serialized()` method.
+> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
+>
+> Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
+> therefore the operation should be awaited until complete, or deleted prior to completion when cancellation is intended.
+
+
+### Descriptor
+
+Memory descriptor that ensures memory is registered with the NIXL-based RDMA subsystem.
+Memory must be registered with the RDMA subsystem before RDMA operations can interact with it.
+
+Descriptor objects are administrative and do not copy, move, or otherwise modify the registered memory.
+
+There are four ways to create a descriptor:
+
+  1. From a `torch.Tensor` object. Device information will be derived from the provided object.
+
+  2. From a `tuple` containing either a NumPy or CuPy `ndarray` and information describing where the memory resides (host/CPU vs. GPU).
+
+  3. From a Python `bytes` object. Memory is assumed to reside in CPU-addressable host memory.
+
+  4. From a `tuple` comprising the address of the memory, its size in bytes, and device information.
+     An optional reference to a Python object can be provided to avoid garbage collection issues.
+
+
+### Device
+
+Device describes the device, or kind of memory, a given allocation resides in.
+Usually host (`"cpu"`) or GPU (`"cuda"`) memory.
+
+When a system contains multiple GPU devices, specific GPU devices can be identified by including their ordinal index number.
+For example, to reference the second GPU in a system, `"cuda:1"` can be used.
+
+By default, when `"cuda"` is provided, it is assumed to be `"cuda:0"`, or the first GPU enumerated by the system.
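+
+The hedged sketch below touches all four construction paths described above; the buffer shapes and device strings are arbitrary,
+and `connect` refers to this example's library.
+
+```python
+import numpy as np
+import torch
+
+import connect
+
+# 1. From a `torch.Tensor`; device information is derived from the tensor itself.
+tensor = torch.zeros((577, 4096), dtype=torch.float16, device="cuda:0")
+desc_from_tensor = connect.Descriptor(tensor)
+
+# 2. From a NumPy (or CuPy) ndarray plus the device the memory resides on.
+array = np.zeros((1024,), dtype=np.float32)
+desc_from_array = connect.Descriptor((array, "cpu"))
+
+# 3. From a Python `bytes` object; assumed to reside in host memory.
+desc_from_bytes = connect.Descriptor(b"hello, RDMA")
+
+# 4. From (address, size-in-bytes, device, optional keep-alive reference).
+desc_from_ptr = connect.Descriptor(
+    (tensor.data_ptr(), tensor.numel() * tensor.element_size(), "cuda:0", tensor)
+)
+```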
+
+### ReadOperation
+
+An operation which transfers data from a remote worker to the local worker.
+
+To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`ReadableOperation`](#readableoperation),
+along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory intended to receive data from the remote worker, must be provided.
+The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
+
+Once created, the operation will begin reading immediately.
+Disposal of the object reference will instruct the RDMA subsystem to cancel the read operation,
+therefore the operation should be awaited until complete, or deleted prior to completion when cancellation is intended.
+
+#### Methods
+
+##### `cancel`
+
+> Instructs the RDMA subsystem to cancel the operation.
+> Completed operations cannot be cancelled.
+
+##### `wait_for_completion`
+
+> Blocks the caller until the memory from the remote worker has been transferred to the provided buffers.
+
+
+### ReadableOperation
+
+An operation which enables a remote worker to read data from the local worker.
+
+To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided that reference memory intended to be transferred to a remote worker.
+Once created, the memory referenced by the provided descriptors becomes immediately readable by any remote worker with the necessary metadata.
+The metadata required to access the memory referenced by the provided descriptors is accessible via the operation's `.to_serialized()` method.
+Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
+
+Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
+therefore the operation should be awaited until complete, or deleted prior to completion when cancellation is intended.
+
+#### Methods
+
+##### `to_serialized`
+
+> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to read from the operation.
+> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
+
+##### `wait_for_completion`
+
+> Blocks the caller until the operation has received a completion signal from a remote worker.
+
+
+### WriteOperation
+
+An operation which transfers data from the local worker to a remote worker.
+
+To create the operation, RDMA metadata ([`SerializedRequest`](#serializedrequest)) from a remote worker's [`WritableOperation`](#writableoperation),
+along with a matching set of local [`Descriptor`](#descriptor) objects which reference memory to be transferred to the remote worker, must be provided.
+The RDMA metadata must be transferred from the remote to the local worker via a secondary channel, most likely HTTP or TCP+NATS.
+
+Once created, the operation will begin writing immediately.
+Disposal of the object reference will instruct the RDMA subsystem to cancel the write operation,
+therefore the operation should be awaited until complete, or deleted prior to completion when cancellation is intended.
+
+#### Methods
+
+##### `cancel`
+
+> Instructs the RDMA subsystem to cancel the operation.
+> Completed operations cannot be cancelled.
+
+##### `wait_for_completion`
+
+> Blocks the caller until all provided buffers have been transferred to the remote worker.
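+
+Because every operation in this example's implementation also supports the context manager protocol,
+cancellation-on-error can be expressed with a `with` block. A hedged sketch, with the timeout value arbitrary:
+
+```python
+import asyncio
+
+import connect
+
+
+async def guarded_write(connector: connect.Connector, serialized_request, descriptor: connect.Descriptor) -> None:
+    # Disposing of the operation at the end of the `with` block releases its
+    # NIXL resources; cancelling stops a transfer that has not yet completed.
+    with await connector.begin_write(descriptor, serialized_request) as write_op:
+        try:
+            await asyncio.wait_for(write_op.wait_for_completion(), timeout=30.0)
+        except asyncio.TimeoutError:
+            write_op.cancel()  # completed operations cannot be cancelled
+            raise
+```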
+
+### WritableOperation
+
+An operation which enables a remote worker to write data to the local worker.
+
+To create the operation, a set of local [`Descriptor`](#descriptor) objects must be provided which reference memory intended to receive data from a remote worker.
+Once created, the memory referenced by the provided descriptors becomes immediately writable by any remote worker with the necessary metadata.
+The metadata required to access the memory referenced by the provided descriptors is accessible via the operation's `.to_serialized()` method.
+Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
+
+Disposal of the operation's object reference will instruct the RDMA subsystem to cancel the operation,
+therefore the operation should be awaited until complete, or deleted prior to completion when cancellation is intended.
+
+#### Methods
+
+##### `to_serialized`
+
+> Generates and returns the RDMA metadata ([`SerializedRequest`](#serializedrequest)) required for a remote worker to write to the operation.
+> Once acquired, the metadata needs to be provided to a remote worker via a secondary channel, most likely HTTP or TCP+NATS.
+
+##### `wait_for_completion`
+
+> Blocks the caller until the operation has received a completion signal from a remote worker.
+
+
+### SerializedRequest
+
+A Pydantic type intended to provide JSON-serialized RDMA metadata about a [`ReadableOperation`](#readableoperation) or [`WritableOperation`](#writableoperation) object.
+
+Use the [`.to_serialized()`](#to_serialized) method on either of the above types to generate a `SerializedRequest` object for an operation.
+
+## References
+
+  - [NVIDIA Dynamo](https://developer.nvidia.com/dynamo) @ [GitHub](https://github.com/ai-dynamo/dynamo)
+  - [NVIDIA Inference Transfer Library (NIXL)](https://developer.nvidia.com/blog/introducing-nvidia-dynamo-a-low-latency-distributed-inference-framework-for-scaling-reasoning-ai-models/#nvidia_inference_transfer_library_nixl_low-latency_hardware-agnostic_communication%C2%A0) @ [GitHub](https://github.com/ai-dynamo/nixl)
+  - [Dynamo Multimodal Example](https://github.com/ai-dynamo/dynamo/tree/main/examples/multimodal)
+  - [NVIDIA GPU Direct](https://developer.nvidia.com/gpudirect)
diff --git a/examples/multimodal_v1/connect/__init__.py b/examples/multimodal_v1/connect/__init__.py
new file mode 100644
index 0000000000..17fdb55bbd
--- /dev/null
+++ b/examples/multimodal_v1/connect/__init__.py
@@ -0,0 +1,1472 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import asyncio
+import ctypes
+import logging
+import socket
+import uuid
+import zlib
+from abc import ABC, abstractmethod
+from enum import IntEnum
+from functools import cached_property
+from typing import Any, List, Optional
+
+import nixl._api as nixl_api
+import nixl._bindings as nixl_bindings
+import torch
+from pydantic import BaseModel, ConfigDict, field_validator
+
+from dynamo.runtime import DistributedRuntime
+
+logger = logging.getLogger(__name__)
+
+try:
+    import cupy as array_module
+    from cupy_backends.cuda.api.runtime import CUDARuntimeError
+
+    logger.info("Utilizing cupy to enable GPU acceleration.")
+except ImportError:
+    try:
+        import numpy as array_module
+
+        logger.warning(
+            "Failed to load cupy for GPU acceleration, utilizing numpy to provide CPU based operations."
+        )
+    except ImportError as e:
+        raise ImportError("NumPy or CuPy must be installed to use this module.") from e
+
+
+class AbstractOperation(ABC):
+    """
+    Abstract base class for awaitable NIXL-based RDMA operations.
+    """
+
+    def __init__(
+        self,
+        connector: Connector,
+        operation_kind: OperationKind,
+        local_descriptors: Descriptor | list[Descriptor],
+        remote_descriptors: Optional[Descriptor | list[Descriptor]],
+        notification_key: Optional[str],
+    ) -> None:
+        if not isinstance(connector, Connector):
+            raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
+        if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE:
+            raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.")
+        if not (
+            isinstance(local_descriptors, Descriptor)
+            or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
+        ):
+            raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
+        if (
+            remote_descriptors is not None
+            and not (
+                isinstance(remote_descriptors, Descriptor)
+                or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors))
+            )
+        ):
+            raise TypeError("Argument `remote_descriptors` must be `dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`, or `None`.")
+
+        if isinstance(local_descriptors, list) and len(local_descriptors) == 0:
+            raise ValueError("Argument `local_descriptors` must not be an empty list.")
+        if (
+            remote_descriptors is not None
+            and isinstance(remote_descriptors, list)
+            and len(remote_descriptors) == 0
+        ):
+            raise ValueError("Argument `remote_descriptors` must not be an empty list.")
+
+        notification_key = str(uuid.uuid4()) if notification_key is None else notification_key
+        if not isinstance(notification_key, str):
+            raise TypeError("Argument `notification_key` must be `str` or `None`.")
+        if len(notification_key) == 0:
+            raise ValueError("Argument `notification_key` must not be an empty string.")
+
+        self._notification_key: str = notification_key
+        self._connector: Connector = connector
+        self._operation_kind: OperationKind = operation_kind
+        self._local_descriptors: Descriptor | list[Descriptor] = local_descriptors
+        self._local_dlist: Optional[list[tuple[int, int, int]]] = None
+        self._local_memtype: DeviceKind = DeviceKind.UNSPECIFIED
+        self._remote_descriptors: Optional[Descriptor | list[Descriptor]] = None if remote_descriptors is None else remote_descriptors
+        self._remote_dlist: Optional[list[tuple[int, int, int]]] = None
+        self._remote_memtype: DeviceKind = DeviceKind.UNSPECIFIED
+
+        # Register local descriptors with NIXL.
+        # Note: Only local descriptors should be registered with NIXL;
+        # remote descriptors do not reference valid local memory and registering them would fault.
+        if isinstance(local_descriptors, list):
+            for d in local_descriptors:
+                d.register_memory(self._connector)
+        else:
+            local_descriptors.register_memory(self._connector)
+
+        # Record local descriptors.
+        memtype, dlist = self._create_dlist(local_descriptors)
+        self._local_dlist = dlist
+        self._local_memtype = memtype
+
+        # Record remote descriptors when provided.
+        if remote_descriptors is not None:
+            memtype, dlist = self._create_dlist(remote_descriptors)
+            self._remote_dlist = dlist
+            self._remote_memtype = memtype
+
+    def __del__(self) -> None:
+        self._release()
+
+    def __enter__(self) -> AbstractOperation:
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self._release()
+
+    def _release(self) -> None:
+        """
+        Private method to release resources. Only to be called by `self`.
+        """
+        pass
+
+    @property
+    def connector(self) -> Connector:
+        """
+        Gets the connector associated with this operation.
+        """
+        return self._connector
+
+    @property
+    def operation_kind(self) -> OperationKind:
+        """
+        Gets the kind of operation.
+        """
+        return self._operation_kind
+
+    @abstractmethod
+    async def wait_for_completion(self) -> None:
+        """
+        Blocks the caller asynchronously until the operation has completed.
+        """
+        raise NotImplementedError("Abstract method not implemented by derived class.")
+
+    # Private Methods
+
+    def _create_dlist(
+        self,
+        descriptors: Descriptor | list[Descriptor],
+    ) -> tuple[DeviceKind, list[tuple[int, int, int]]]:
+        """
+        Helper function to create a list of tuples (ptr, size, device) from descriptors.
+        """
+        dlist: list[tuple[int, int, int]] = []
+        memtype: DeviceKind = DeviceKind.UNSPECIFIED
+        if isinstance(descriptors, list):
+            memtype = descriptors[0].device.kind
+            for desc in descriptors:
+                if memtype != desc.device.kind:
+                    raise ValueError("All descriptors in a list must have the same memory type.")
+                dlist.append((desc.ptr, desc.size, desc.device.id))
+        else:
+            memtype = descriptors.device.kind
+            dlist.append((descriptors.ptr, descriptors.size, descriptors.device.id))
+        return (memtype, dlist)
+
+
+class ActiveOperation(AbstractOperation):
+    """
+    Abstract class for active operations that initiate a NIXL-based RDMA transfer based on the `SerializedRequest`
+    provided by the remote worker's corresponding `PassiveOperation`.
+ """ + + def __init__( + self, + remote: Remote, + operation_kind: OperationKind, + local_descriptors: Descriptor | list[Descriptor], + remote_descriptors: Descriptor | list[Descriptor], + notification_key: str, + ) -> None: + if not isinstance(remote, Remote) or remote._connector is None: + raise TypeError("Argument `remote` must be valid `dynamo.connect.Remote`.") + if not isinstance(operation_kind, OperationKind): + raise TypeError("Argument `operation_kind` must `dynamo.connect.OperationKind`.") + if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE: + raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.") + if not ( + isinstance(local_descriptors, Descriptor) + or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors)) + ): + raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.") + if not ( + isinstance(remote_descriptors, Descriptor) + or (isinstance(remote_descriptors, list) and all(isinstance(d, Descriptor) for d in remote_descriptors)) + ): + raise TypeError("Argument `remote_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.") + + # Unpack single descriptors from lists if they are provided as single descriptors. + if isinstance(local_descriptors, list) and len(local_descriptors) == 1: + local_descriptors = local_descriptors[0] + if isinstance(remote_descriptors, list) and len(remote_descriptors) == 1: + remote_descriptors = remote_descriptors[0] + + if (isinstance(local_descriptors, list) and isinstance(remote_descriptors, list) and len(local_descriptors) != len(remote_descriptors)): + raise ValueError("When `local_descriptors` and `remote_descriptors` are lists, they must have the same length.") + elif isinstance(local_descriptors, list) != isinstance(remote_descriptors, list): + raise ValueError("Both `local_descriptors` and `remote_descriptors` must be either lists or single descriptors.") + if not isinstance(notification_key, str): + raise TypeError("Argument `notification_key` must be `str`.") + if len(notification_key) == 0: + raise ValueError("Argument `notification_key` must not be an empty string.") + + self._remote = remote + self._status = OperationStatus.UNINTIALIZED + + super().__init__(remote.connector, operation_kind, local_descriptors, remote_descriptors, notification_key) + # Quick check to ensure remote descriptors are not None to make static analysis happy. 
+ if self._local_dlist is None or self._remote_dlist is None: + raise RuntimeError("NIXL descriptor list(s) not bound to operation.") + + self._local_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None + self._remote_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None + self._xfer_hndl: Optional[nixl_api.nixl_xfer_handle] = None + + self._local_xfer_descs = self._connector._nixl.get_xfer_descs( + descs=self._local_dlist, + mem_type=str(self._local_memtype), + ) + logger.debug(f"Created local NIXL xfer descs: {self._local_xfer_descs}") + self._remote_xfer_descs = self._connector._nixl.get_xfer_descs( + descs=self._remote_dlist, + mem_type=str(self._remote_memtype), + ) + logger.debug(f"Created remote NIXL xfer descs: {self._remote_xfer_descs}") + self._xfer_hndl = self._connector._nixl.initialize_xfer( + operation=str(operation_kind), + local_descs=self._local_xfer_descs, + remote_descs=self._remote_xfer_descs, + remote_agent=self._remote.name, + notif_msg=self._notification_key.encode("utf-8"), + ) + logger.debug(f"Created NIXL transfer handle: {self._xfer_hndl}") + + def __del__(self) -> None: + super().__del__() + self._release() + + def __enter__(self) -> ActiveOperation: + super().__enter__() + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + match self.status: + case OperationStatus.IN_PROGRESS | OperationStatus.INITIALIZED: + self._status = OperationStatus.CANCELLED + + self._release() + + def __repr__(self) -> str: + return str( + f"{self.__class__.__name__}(" + f"operation_kind={self._operation_kind}, " + f"local_descriptors={self._local_descriptors}, " + f"remote_descriptors={self._remote_descriptors}, " + f"notification_key='{self._notification_key}', " + f"remote='{self._remote.name}', " + f"status='{self._status}'" + f")" + ) + + def _release(self) -> None: + """ + Private method to release resources. + """ + error: Optional[Exception] = None + + if self._xfer_hndl is not None: + try: + logger.debug(f"NIXL transfer handle {self._xfer_hndl} released.") + self._connector._nixl.release_xfer_handle(self._xfer_hndl) + except Exception as e: + logger.error(f"Failed to release resources: {e}") + error = e + finally: + self._xfer_hndl = None + + try: + super()._release() + except Exception as e: + logger.error(f"Failed to release WaitableOperation resources: {e}") + if error is not None: + e.__cause__ = error + error = e + + if error is not None: + raise error + + def _cancel_(self) -> None: + if self._xfer_hndl is None: + return + if self.status == OperationStatus.ERRORED: + raise RuntimeError("Operation is errored, unable to cancel the operation.") + + logger.info(f"Cancellation requested for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', status={self._status} }}.") + + # NIXL will cancel the transfer if it is in progress when the handle is released. + self._connector._nixl.release_xfer_handle(self._xfer_hndl) + self._status = OperationStatus.CANCELLED + self._xfer_hndl = None + + async def _wait_for_completion_(self) -> None: + # Loop until the operation is no longer in progress (or "initalized"), + # yielding control to the event loop to allow other operations to run. + iteration_count = 0 + while True: + if iteration_count % 10 == 0: + logger.debug(f"Waiting for operation {{ kind={self._operation_kind}, remote='{self._remote.name}', duration={iteration_count / 10}s }}.") + match self.status: + # "in progress" or "initialized" means the operation is ongoing. 
+                case OperationStatus.INITIALIZED:
+                    await asyncio.sleep(0.1)
+                case OperationStatus.IN_PROGRESS:
+                    await asyncio.sleep(0.1)
+                # Any other state indicates completion or error.
+                case _:
+                    return
+            iteration_count += 1
+
+    @abstractmethod
+    def cancel(self) -> None:
+        """
+        Cancels the operation.
+        No effect if the operation has already completed, errored, or been cancelled.
+        """
+        raise NotImplementedError("Abstract method not implemented by derived class.")
+
+    @property
+    def remote(self) -> Remote:
+        """
+        Gets the remote worker associated with this operation.
+        """
+        return self._remote
+
+    @property
+    def status(self) -> OperationStatus:
+        """
+        Gets the status of the operation.
+        """
+        # Early return if the operation is already complete, errored, or cancelled.
+        match self._status:
+            case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED:
+                return self._status
+
+        if self._xfer_hndl is None:
+            raise RuntimeError("NIXL transfer handle is invalid.")
+
+        old_status = self._status
+
+        if self._status == OperationStatus.UNINTIALIZED:
+            state = self._connector._nixl.transfer(self._xfer_hndl, self._notification_key.encode("utf-8"))
+            logger.debug(f"NIXL reported transfer state: {state}")
+            if state == "ERR":
+                self._status = OperationStatus.ERRORED
+            elif state == "DONE":
+                self._status = OperationStatus.COMPLETE
+            else:
+                self._status = OperationStatus.INITIALIZED
+        else:
+            state = self._connector._nixl.check_xfer_state(self._xfer_hndl)
+            logger.debug(f"NIXL reported transfer state: {state}")
+            if state == "ERR":
+                self._status = OperationStatus.ERRORED
+            elif state == "DONE":
+                self._status = OperationStatus.COMPLETE
+            else:
+                self._status = OperationStatus.IN_PROGRESS
+
+        if self._status != old_status:
+            logger.debug(f"{self.__class__.__name__} {{ remote: '{self._remote.name}' status: '{old_status}' => '{self._status}' }}.")
+
+        return self._status
+
+
+class Connector:
+    """
+    Core class for managing the connection between workers in a distributed environment.
+    Use this class to create readable and writable operations, or read and write data to remote workers.
+    """
+
+    def __init__(
+        self,
+        namespace: Optional[str] = None,
+        runtime: Optional[DistributedRuntime] = None,
+        worker_id: Optional[str] = None,
+    ) -> None:
+        """
+        Creates a new Connector instance.
+
+        Parameters
+        ----------
+        namespace : Optional[str], optional
+            Dynamo namespace of the component, defaults to "dynamo" when `None`.
+        runtime : Optional[DistributedRuntime], optional
+            Reference to the Dynamo runtime used by the component; may be `None` when no runtime is available.
+        worker_id : Optional[str], optional
+            Unique identifier of the worker, defaults to a new UUID when `None`.
+
+        Raises
+        ------
+        TypeError
+            When `namespace` is provided and not of type `str`.
+        TypeError
+            When `runtime` is provided and not of type `dynamo.runtime.DistributedRuntime`.
+        TypeError
+            When `worker_id` is provided and not of type `str`.
+        """
+        namespace = "dynamo" if namespace is None else namespace
+        if not isinstance(namespace, str):
+            raise TypeError("Argument `namespace` must be `str` or `None`.")
+        if runtime is not None and not isinstance(runtime, DistributedRuntime):
+            raise TypeError("Argument `runtime` must be `dynamo.runtime.DistributedRuntime` or `None`.")
+        worker_id = worker_id if worker_id is not None else str(uuid.uuid4())
+        if not isinstance(worker_id, str) or len(worker_id) == 0:
+            raise TypeError("Argument `worker_id` must be a non-empty `str` or `None`.")
+
+        self._worker_id = worker_id
+        self._is_initialized = False
+        self._runtime = runtime
+        self._namespace = namespace
+        self._nixl = nixl_api.nixl_agent(self._worker_id)
+        self._hostname = socket.gethostname()
+        self._agent_metadata: Optional[bytes] = None
+
+        logger.debug(f"Created {self.__repr__()}.")
+
+    def __repr__(self) -> str:
+        return str(
+            f"{self.__class__.__name__}("
+            f"worker_id='{self._worker_id}', "
+            f"namespace={self._namespace}, "
+            f"hostname={self._hostname}, "
+            f"metadata=<{0 if self._agent_metadata is None else len(self._agent_metadata)} bytes>"
+            ")"
+        )
+
+    def __str__(self) -> str:
+        return self._worker_id
+
+    @cached_property
+    def is_cuda_available(self) -> bool:
+        # Note: cuda.is_available() initializes CUDA and can't be called when
+        # forking subprocesses; care should be taken to only call it within
+        # subprocesses or to use the 'spawn' start method.
+        try:
+            return getattr(array_module, "cuda", None) is not None and array_module.cuda.is_available()
+        except CUDARuntimeError:
+            return False
+
+    @property
+    def metadata(self) -> bytes:
+        """
+        Get the metadata of the worker.
+        """
+        return self._nixl.get_agent_metadata()
+
+    @property
+    def name(self) -> str:
+        """
+        Get the name of the worker.
+        """
+        return self._worker_id
+
+    @property
+    def namespace(self) -> str:
+        """
+        Get the namespace of the worker.
+        """
+        return self._namespace
+
+    @property
+    def runtime(self) -> DistributedRuntime:
+        """
+        Get the runtime of the worker.
+        """
+        if self._runtime is None:
+            raise RuntimeError("Runtime is not set. This Connector was not initialized with a runtime.")
+        return self._runtime
+
+    async def begin_read(
+        self,
+        remote_request: SerializedRequest,
+        local_descriptors: Descriptor | list[Descriptor],
+    ) -> ReadOperation:
+        """
+        Creates a read operation for fulfilling a remote readable operation.
+
+        Parameters
+        ----------
+        remote_request : SerializedRequest
+            Serialized request from a remote worker that has created a readable operation.
+        local_descriptors : Descriptor | list[Descriptor]
+            Local descriptor(s) to receive data from the remote worker described by `remote_request`.
+
+        Returns
+        -------
+        ReadOperation
+            Awaitable read operation that can be used to transfer data from a remote worker.
+
+        Raises
+        ------
+        TypeError
+            When `remote_request` is not of type `SerializedRequest`.
+        TypeError
+            When `local_descriptors` is not of type `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
+        """
+        if remote_request is None or not isinstance(remote_request, SerializedRequest):
+            raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
+        if not (
+            isinstance(local_descriptors, Descriptor)
+            or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
+        ):
+            raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
+        if remote_request.operation_kind != OperationKind.READ.value:
+            raise RuntimeError("Cannot create a `dynamo.connect.ReadOperation` to read from a remote `dynamo.connect.WritableOperation`.")
+
+        if not self._is_initialized:
+            raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
+
+        op = ReadOperation(self, remote_request, local_descriptors)
+        return op
+
+    async def begin_write(
+        self,
+        local_descriptors: Descriptor | list[Descriptor],
+        remote_request: SerializedRequest,
+    ) -> WriteOperation:
+        """
+        Creates a write operation for transferring data to a remote worker.
+
+        Parameters
+        ----------
+        local_descriptors : Descriptor | list[Descriptor]
+            Local descriptors of one or more data objects to be transferred to the remote worker.
+        remote_request : SerializedRequest
+            Serialized request from a remote worker that has created a writable operation.
+        """
+        if remote_request is None or not isinstance(remote_request, SerializedRequest):
+            raise TypeError("Argument `remote_request` must be `SerializedRequest`.")
+        if not (
+            isinstance(local_descriptors, Descriptor)
+            or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
+        ):
+            raise TypeError("Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`.")
+        if remote_request.operation_kind != OperationKind.WRITE.value:
+            raise RuntimeError("Cannot create a `WriteOperation` to write to a remote `ReadableOperation`.")
+        if not isinstance(remote_request.nixl_metadata, str):
+            raise TypeError("Argument `remote_request.nixl_metadata` must be `str`.")
+
+        if not self._is_initialized:
+            raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
+
+        op = WriteOperation(self, local_descriptors, remote_request)
+        return op
+
+    def create_readable(
+        self,
+        local_descriptors: Descriptor | list[Descriptor],
+    ) -> ReadableOperation:
+        """
+        Creates a readable operation for transferring data to a remote worker.
+
+        Returns
+        -------
+        ReadableOperation
+            A readable operation that a remote worker can use to read data from this worker.
+        """
+        if not self._is_initialized:
+            raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
+
+        op = ReadableOperation(self, local_descriptors)
+        return op
+
+    def create_writable(
+        self,
+        local_descriptors: Descriptor | list[Descriptor],
+    ) -> WritableOperation:
+        """
+        Creates a writable operation for receiving data from a remote worker.
+
+        Returns
+        -------
+        WritableOperation
+            A writable operation that a remote worker can use to write data to this worker.
+        """
+        if not self._is_initialized:
+            raise RuntimeError("Connector not initialized. Call `initialize()` before calling this method.")
+
+        op = WritableOperation(self, local_descriptors)
+        return op
+
+    async def initialize(self) -> None:
+        # Only initialize the connector once.
+        if self._is_initialized:
+            return
+
+        self._is_initialized = True
+        # This method is a no-op for now; in the future it may be used to initialize the connector.
+ logger.debug(f"Initialized Connector {{ name: '{self._worker_id}', namespace '{self._namespace}' }} completed.") + + +class Descriptor: + """ + Memory descriptor that ensures memory is registered w/ NIXL, used for transferring data between workers. + """ + + def __init__( + self, + data: torch.Tensor | tuple[array_module.ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any], + ) -> None: + """ + Memory descriptor for transferring data between workers. + + Parameters + ---------- + data : torch.Tensor | tuple[ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any] + The data to be transferred. + + When `torch.Tensor` is provided, the attributes of the tensor will be used to create the descriptor. + + When `tuple[ndarray, Device]` is provided, the tuple must contain: + - `ndarray`: The CuPy or NumPy array to be transferred. + - `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu"). + + When `bytes` is provided, the pointer and size derived from the bytes object and memory type will be assumed to be CPU. + + When `tuple[int, int, Device|str, Any]` is provided, the tuple must contain the following elements: + - `int`: Pointer to the data in memory. + - `int`: Size of the data in bytes. + - `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu"). + - `Any`: Optional reference to the data (e.g., the original tensor or bytes object). + This is useful for keeping a reference to the data in memory, but it is not required. + + Raises + ------ + ValueError + When `data` is `None`. + TypeError + When `data` is not a valid type (i.e., not `torch.Tensor`, `bytes`, or a valid tuple). + TypeError + When `data` is a tuple but the elements are not of the expected types (i.e., [`ndarray`, `Device|str`] OR [`int`, `int`, `Device|str`, `Any`]). + """ + TYPE_ERROR_MESSAGE = "Argument `data` must be `torch.Tensor`, `tuple[ndarray, Device|str]`, `bytes`, or `tuple[int, int, Device|str, Any]`." + if data is None: + raise ValueError("Argument `data` cannot be `None`.") + if not (isinstance(data, torch.Tensor) or isinstance(data, bytes) or isinstance(data, tuple)): + raise TypeError(TYPE_ERROR_MESSAGE) + + self._data_device: Device = Device("cpu") + self._data_ptr: int = 0 + self._data_ref: Optional[Any] = None + self._data_size: int = 0 + + # Member fields for managing NIXL memory registration. + # Note: ONLY local descriptors should be registered with NIXL, + # remote descriptors do not have a valid memory address and registration will fault. + self._connector: Optional[Connector] = None + self._nixl_hndl: Optional[nixl_bindings.nixlRegDList] = None + + # Initially `None` cached serialized descriptor reference, populated when `to_serialized()` is called. + self._serialized: Optional[SerializedDescriptor] = None + + # Data is `torch.Tensor`. + if isinstance(data, torch.Tensor): + self._data_ptr = data.data_ptr() + self._data_size = data.numel() * data.element_size() + if data.is_cuda: + self._data_device = Device((DeviceKind.CUDA, data.get_device())) + self._data_ref = data + + logger.debug(f"Created {self.__repr__()} from `torch.Tensor`.") + + # Data is `tuple[ndarray, Device]`. 
+        elif (
+            isinstance(data, tuple)
+            and len(data) == 2
+            and isinstance(data[0], array_module.ndarray)
+            and (isinstance(data[1], Device) or isinstance(data[1], str))
+        ):
+            if hasattr(data[0], "__array_interface__"):
+                self._data_ptr = data[0].__array_interface__["data"][0]
+            elif hasattr(data[0], "__cuda_array_interface__"):
+                self._data_ptr = data[0].__cuda_array_interface__["data"][0]
+            else:
+                raise TypeError("Argument `data[0]` must be an `ndarray` with a valid array interface.")
+            self._data_size = data[0].nbytes
+            self._data_device = data[1] if isinstance(data[1], Device) else Device(data[1])
+            self._data_ref = data[0]
+
+            logger.debug(f"Created {self.__repr__()} from `tuple[ndarray, Device|str]`.")
+
+        # Data is `bytes`.
+        elif isinstance(data, bytes):
+            # Note: `id(data)` would be the address of the Python object header, not of
+            # its payload, so copy the payload into a ctypes buffer and describe that instead.
+            buffer = (ctypes.c_char * len(data)).from_buffer_copy(data)
+            self._data_ptr = ctypes.addressof(buffer)
+            self._data_size = len(data)
+            self._data_ref = buffer
+
+            logger.debug(f"Created {self.__repr__()} from `bytes`.")
+
+        # Data is `tuple[int, int, Device|str, Any]`.
+        elif isinstance(data, tuple) and len(data) >= 2 and isinstance(data[0], int) and isinstance(data[1], int):
+            if len(data) >= 3 and not (isinstance(data[2], Device) or isinstance(data[2], str)):
+                raise TypeError("Argument `data` must be a `tuple[int, int, Device|str, Any]`.")
+
+            self._data_ptr = data[0]
+            self._data_size = data[1]
+            if len(data) >= 3:
+                self._data_device = data[2] if isinstance(data[2], Device) else Device(data[2])
+            self._data_ref = data[3] if len(data) >= 4 else None
+
+            logger.debug(f"Created {self.__repr__()} from `tuple[int, int, Device|str, Any]`.")
+        else:
+            raise TypeError(TYPE_ERROR_MESSAGE)
+
+    def __del__(self) -> None:
+        if self._nixl_hndl is not None and self._connector is not None:
+            # Unregister the memory with NIXL.
+            self._connector._nixl.deregister_memory(self._nixl_hndl)
+            self._nixl_hndl = None
+
+        if self._data_ref is not None:
+            # Release the reference to the data.
+            del self._data_ref
+
+        logger.debug(f"Deleted {self.__repr__()}.")
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self})"
+
+    def __str__(self) -> str:
+        return f"ptr={hex(self._data_ptr)}, size={self._data_size}, device={self._data_device}"
+
+    @property
+    def device(self) -> Device:
+        """
+        Gets the device of the descriptor.
+        """
+        return self._data_device
+
+    @property
+    def ptr(self) -> int:
+        """
+        Gets the pointer of the descriptor.
+        """
+        return self._data_ptr
+
+    @property
+    def size(self) -> int:
+        """
+        Gets the size of the descriptor.
+        """
+        return self._data_size
+
+    @staticmethod
+    def from_serialized(
+        serialized: SerializedDescriptor,
+    ) -> Descriptor:
+        """
+        Deserializes a `SerializedDescriptor` into a `Descriptor` object.
+
+        Parameters
+        ----------
+        serialized : SerializedDescriptor
+            The serialized descriptor to deserialize.
+
+        Returns
+        -------
+        Descriptor
+            The deserialized descriptor.
+        """
+        if not isinstance(serialized, SerializedDescriptor):
+            raise TypeError("Argument `serialized` must be `SerializedDescriptor`.")
+
+        return serialized.to_descriptor()
+
+    def register_memory(
+        self,
+        connector: Connector,
+    ) -> None:
+        """
+        Registers the memory of the descriptor with NIXL.
+        """
+        if not isinstance(connector, Connector):
+            raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
+        if self._data_ptr == 0:
+            raise ValueError("Cannot register memory with a null pointer.")
+
+        if not (self._nixl_hndl is None and self._connector is None):
+            return
+
+        # Register the memory with NIXL.
+ self._connector = connector + if isinstance(self._data_ref, torch.Tensor): + self._nixl_hndl = connector._nixl.register_memory(self._data_ref) + else: + mem_type = str(self._data_device.kind) + reg_list = [(self._data_ptr, self._data_size, self._data_device.id, mem_type)] + self._nixl_hndl = connector._nixl.register_memory(reg_list, mem_type) + + logger.debug(f"Registered {self.__repr__()} with NIXL.") + + def to_serialized(self) -> SerializedDescriptor: + """ + Serializes the descriptor into a `SerializedDescriptor` object. + """ + if self._serialized is None: + self._serialized = SerializedDescriptor( + device=f"{self._data_device}", + ptr=self._data_ptr, + size=self._data_size, + ) + + return self._serialized + + +class Device: + """ + Represents a device in the system. + """ + + def __init__( + self, + metadata: str | tuple[DeviceKind, int], + ) -> None: + if metadata is None: + raise ValueError("Argument `metadata` cannot be `None`.") + if isinstance(metadata, tuple) and len(metadata) == 2 and isinstance(metadata[0], DeviceKind) and isinstance(metadata[1], int): + kind, device_id = metadata + elif isinstance(metadata, str): + metadata = metadata.strip().lower() + if metadata.startswith("cuda") or metadata.startswith("gpu"): + kind = DeviceKind.CUDA + device_id = 0 if metadata.find(":") == -1 else int(metadata.split(":")[1]) + elif metadata.startswith("cpu") or metadata.startswith("host"): + kind = DeviceKind.HOST + device_id = 0 + else: + raise ValueError("Argument `metadata` must be in the format 'cuda:' or 'cpu'.") + else: + raise TypeError("Argument `metadata` must be a `tuple[MemoryKind, int]` or a `str`.") + + + self._device_id = device_id + self._kind = kind + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(kind={self._kind}, id={self._device_id})" + + def __str__(self) -> str: + return f"{self._kind}:{self._device_id}" if self._kind is DeviceKind.CUDA else f"{self._kind}" + + @property + def id(self) -> int: + """ + Gets the device ID of the device. + """ + return self._device_id + + @property + def kind(self) -> DeviceKind: + """ + Gets the memory kind of the device. + """ + return self._kind + + +class DeviceKind(IntEnum): + """ + Type of memory a descriptor has been allocated to. + """ + + UNSPECIFIED = 0 + HOST = 1 + CUDA = 2 + + def __str__(self) -> str: + if self == DeviceKind.HOST: + return "cpu" + elif self == DeviceKind.CUDA: + return "cuda" + else: + return "" + + +class OperationKind(IntEnum): + """ + Kind of an operation. + """ + + UNSPECIFIED = 0 + READ = 1 + WRITE = 2 + + def __str__(self) -> str: + if self == OperationKind.READ: + return "READ" + elif self == OperationKind.WRITE: + return "WRITE" + else: + return "" + + +class OperationStatus(IntEnum): + """ + Status of an operation. + """ + + UNINTIALIZED = 0 + INITIALIZED = 1 + IN_PROGRESS = 2 + COMPLETE = 3 + CANCELLED = 4 + ERRORED = 5 + + def __str__(self) -> str: + if self == OperationStatus.INITIALIZED: + return "INIT" + elif self == OperationStatus.IN_PROGRESS: + return "PROC" + elif self == OperationStatus.COMPLETE: + return "DONE" + elif self == OperationStatus.ERRORED: + return "ERR" + elif self == OperationStatus.CANCELLED: + return "STOP" + else: + return "" + + +class PassiveOperation(AbstractOperation): + """ + Abstract class for common functionality of passive operations. 
+ """ + + def __init__( + self, + connector: Connector, + operation_kind: OperationKind, + local_descriptors: Descriptor | list[Descriptor], + ) -> None: + if operation_kind is not OperationKind.READ and operation_kind is not OperationKind.WRITE: + raise ValueError("Argument `operation_kind` must be either `READ` or `WRITE`.") + + self._status = OperationStatus.UNINTIALIZED + + super().__init__(connector, operation_kind, local_descriptors, None, None) + + self._serialized_request: Optional[SerializedRequest] = None + self._status = OperationStatus.INITIALIZED + + def __del__(self) -> None: + super().__del__() + + def __enter__(self) -> AbstractOperation: + super().__enter__() + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + super().__exit__(exc_type, exc_value, traceback) + + def __repr__(self) -> str: + return str( + f"{self.__class__.__name__}(" + f"operation_kind={self._operation_kind}, " + f"local_descriptors={self._local_descriptors}, " + f"notification_key='{self._notification_key}', " + f"status='{self._status}'" + f")" + ) + + async def _wait_for_completion_(self) -> None: + # Loop until the operation is no longer in progress (or "initalized"), + # yielding control to the event loop to allow other operations to run. + while True: + match self.status: + # "in progress" or "initialized" means the operation is ongoing. + case OperationStatus.INITIALIZED: + await asyncio.sleep(0.1) + case OperationStatus.IN_PROGRESS: + await asyncio.sleep(0.1) + # Any other state indicates completion or error. + case _: + return + + @property + def status(self) -> OperationStatus: + """ + Gets the status of the operation. + """ + # Early return if the operation is already complete, errored, or cancelled. + match self._status: + case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED: + return self._status + + old_status = self._status + + # Query NIXL for any notifications. + notifications = self._connector._nixl.update_notifs() + + if isinstance(notifications, dict): + remote_state = OperationStatus.IN_PROGRESS + logger.debug(f"NIXL reported notifications: {len(notifications)}.") + + for key, values in notifications.items(): + if not isinstance(values, list): + raise TypeError(f"Expected `dict[str, list[bytes]]` from NIXL notification query; got {type(notifications)}.") + for value in values: + if not isinstance(value, bytes): + continue + notification_key = value.decode("utf-8") + + # Once we've found the notification key, we know the operation is complete. + if notification_key == self._notification_key: + remote_state = OperationStatus.COMPLETE + break + + if remote_state == OperationStatus.COMPLETE: + self._status = remote_state + logger.debug(f"{self.__class__.__name__} {{ remote: '{self._connector.name}' status: '{old_status}' => '{self._status}' }}.") + + return self._status + + def to_serialized(self) -> SerializedRequest: + """ + Gets the request descriptor for the operation. + """ + if self._serialized_request is None: + # When we've not yet cached the serialized request, we need to generate one before returning it. + # Handle both cases: multiple and single descriptors. 
+        if isinstance(self._local_descriptors, list):
+            descriptors = [desc.to_serialized() for desc in self._local_descriptors]
+        else:
+            descriptors = [self._local_descriptors.to_serialized()]
+
+        nixl_metadata = self._connector.metadata
+        original_len = len(nixl_metadata)
+        nixl_metadata = zlib.compress(nixl_metadata, level=6)
+        compressed_len = len(nixl_metadata)
+        logger.debug(f"Compressed NIXL metadata from {original_len} bytes to {compressed_len} bytes.")
+        if compressed_len > original_len:
+            logger.warning(f"Compressed NIXL metadata is larger than original ({compressed_len} > {original_len}).")
+
+        self._serialized_request = SerializedRequest(
+            descriptors=descriptors,
+            nixl_metadata=nixl_metadata.hex(),
+            notification_key=self._notification_key,
+            operation_kind=int(self._operation_kind),
+        )
+
+        return self._serialized_request
+
+    @abstractmethod
+    async def wait_for_completion(self) -> None:
+        """
+        Blocks the caller asynchronously until the operation has completed.
+        """
+        raise NotImplementedError("Abstract method not implemented by derived class.")
+
+
+class ReadOperation(ActiveOperation):
+    """
+    Operation that initiates an RDMA read operation to transfer data from a remote worker's `ReadableOperation`,
+    as described by `remote_request`, to local buffers.
+    """
+
+    def __init__(
+        self,
+        connector: Connector,
+        remote_request: SerializedRequest,
+        local_descriptors: Descriptor | list[Descriptor],
+    ) -> None:
+        """
+        Creates a new instance of `ReadOperation`, registers `local_descriptors` with NIXL,
+        and begins an RDMA read operation which will transfer data described by `remote_request`
+        to `local_descriptors`.
+
+        Parameters
+        ----------
+        connector : Connector
+            Connector instance to use for the operation.
+        remote_request : SerializedRequest
+            Serialized request from the remote worker.
+        local_descriptors : Descriptor | list[Descriptor]
+            Local descriptor(s) to receive the data from the remote worker.
+        """
+        if not isinstance(connector, Connector):
+            raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
+        if not isinstance(remote_request, SerializedRequest):
+            raise TypeError("Argument `remote_request` must be `dynamo.connect.SerializedRequest`.")
+        if remote_request.operation_kind != OperationKind.READ.value:
+            raise ValueError("Argument `remote_request` must be of kind `READ`.")
+
+        remote = Remote(connector, remote_request.nixl_metadata)
+        remote_descriptors = remote_request.to_descriptors()
+
+        if not (
+            isinstance(local_descriptors, Descriptor)
+            or (isinstance(local_descriptors, list) and all(isinstance(d, Descriptor) for d in local_descriptors))
+        ):
+            raise TypeError("Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.")
+
+        super().__init__(remote, OperationKind.READ, local_descriptors, remote_descriptors, remote_request.notification_key)
+        logger.debug(f"Created {self.__repr__()}")
+
+    def __del__(self) -> None:
+        super().__del__()
+        logger.debug(f"Deleted {self.__repr__()}")
+
+    def __enter__(self) -> ReadOperation:
+        super().__enter__()
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        super().__exit__(exc_type, exc_value, traceback)
+
+    def __repr__(self) -> str:
+        return super().__repr__()
+
+    def cancel(self) -> None:
+        """
+        Cancels the operation.
+        No effect if the operation has already completed, errored, or been cancelled.
+        """
+        super()._cancel_()
+
+    def results(self) -> list[Descriptor]:
+        """
+        Gets the results of the operation.
+        Always returns a list of the local descriptors; when a single descriptor was provided,
+        it is returned wrapped in a one-element list.
+        """
+        if self._status != OperationStatus.COMPLETE:
+            raise RuntimeError("Operation has not completed; results are not yet available.")
+
+        return self._local_descriptors if isinstance(self._local_descriptors, list) else [self._local_descriptors]
+
+    async def wait_for_completion(self) -> None:
+        """
+        Blocks the caller asynchronously until the operation has completed.
+        """
+        await super()._wait_for_completion_()
+
+
+class ReadableOperation(PassiveOperation):
+    """
+    Operation that can be awaited until a remote worker has completed a `ReadOperation`.
+    """
+
+    def __init__(
+        self,
+        connector: Connector,
+        local_descriptors: Descriptor | list[Descriptor],
+    ) -> None:
+        super().__init__(connector, OperationKind.READ, local_descriptors)
+        logger.debug(f"Created {self.__repr__()}")
+
+    def __del__(self) -> None:
+        super().__del__()
+        logger.debug(f"Deleted {self.__repr__()}")
+
+    def __enter__(self) -> ReadableOperation:
+        super().__enter__()
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        super().__exit__(exc_type, exc_value, traceback)
+
+    def __repr__(self) -> str:
+        return super().__repr__()
+
+    async def wait_for_completion(self) -> None:
+        """
+        Blocks the caller asynchronously until the operation has completed.
+        """
+        await super()._wait_for_completion_()
+
+
+class Remote:
+    """
+    Identifies a remote NIXL-enabled worker relative to a local NIXL-enabled worker.
+    """
+
+    def __init__(
+        self,
+        connector: Connector,
+        nixl_metadata: bytes | str,
+    ) -> None:
+        if not isinstance(connector, Connector):
+            raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
+        if not isinstance(nixl_metadata, (bytes, str)):
+            raise TypeError("Argument `nixl_metadata` must be `bytes` or `str`.")
+        if len(nixl_metadata) == 0:
+            raise ValueError("Argument `nixl_metadata` cannot be empty.")
+
+        self._connector = connector
+
+        # When `nixl_metadata` is a string, it is assumed to have come from a remote worker
+        # via a `SerializedRequest` object and can therefore be assumed to be a hex-encoded,
+        # compressed representation of the NIXL metadata.
+        if isinstance(nixl_metadata, str):
+            # Decode the hex-encoded string into bytes.
+            nixl_metadata = bytes.fromhex(nixl_metadata)
+            # Decompress the NIXL metadata.
+            nixl_metadata = zlib.decompress(nixl_metadata)
+
+        self._name = connector._nixl.add_remote_agent(nixl_metadata)
+        if isinstance(self._name, bytes):
+            self._name = self._name.decode("utf-8")
+
+        logger.debug(f"Created {self.__repr__()}.")
+
+    def __del__(self) -> None:
+        self._release()
+
+    def __enter__(self) -> Remote:
+        """
+        Context manager entry method. Returns the current instance.
+        """
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        """
+        Context manager exit method. Cleans up the instance.
+        """
+        self._release()
+
+    def __repr__(self) -> str:
+        return f"Remote(name={self._name}, connector={self._connector.name})"
+
+    def __str__(self) -> str:
+        return self._name
+
+    def _release(self) -> None:
+        """
+        Private method for releasing NIXL resources. Not intended for public use.
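+        Unregisters the remote agent from NIXL so that the same remote can be
+        registered again later, possibly with updated connection info.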
+        """
+        # We must unregister the remote agent from NIXL because we cannot know whether the
+        # remote worker has updated its descriptors, and NIXL returns an error when a remote
+        # agent is registered with the same name but different descriptors (aka conn_info).
+        self._connector._nixl.remove_remote_agent(self._name)
+        logger.debug(f"dynamo.connect.{self.__class__.__name__}: Unregistered NIXL remote {{ name: \"{self._name}\" }}.")
+
+    @property
+    def connector(self) -> Connector:
+        """
+        Gets the local connector associated with this remote worker.
+        """
+        return self._connector
+
+    @property
+    def name(self) -> str:
+        """
+        Gets the name of the remote worker.
+        """
+        return self._name
+
+
+class SerializedDescriptor(BaseModel):
+    """
+    Pydantic serialization type for memory descriptors.
+    """
+    model_config = ConfigDict(
+        extra="forbid",
+        frozen=True,
+        arbitrary_types_allowed=True,
+    )
+
+    device: str = "cpu"
+    ptr: int = 0
+    size: int = 0
+
+    def to_descriptor(self) -> Descriptor:
+        """
+        Deserializes the serialized descriptor into a `Descriptor` object.
+        """
+        return Descriptor(data=(self.ptr, self.size, self.device, None))
+
+    @field_validator("device")
+    def validate_memtype(cls, v: str) -> str:
+        if not isinstance(v, str):
+            raise TypeError("Argument `device` must be `str`.")
+        v = v.strip().lower()
+        if not (v.startswith("cuda") or v == "cpu"):
+            raise ValueError("Argument `device` must be 'cpu' or a CUDA device (e.g. 'cuda' or 'cuda:<index>').")
+        return v
+
+    @field_validator("ptr")
+    def validate_ptr(cls, v: int) -> int:
+        if v == 0:
+            raise ValueError("Argument `ptr` cannot be zero (a null pointer).")
+        return v
+
+    @field_validator("size")
+    def validate_size(cls, v: int) -> int:
+        if v < 0:
+            raise ValueError("Argument `size` must be an integer greater than or equal to zero.")
+        return v
+
+
+class SerializedRequest(BaseModel):
+    """
+    Pydantic serialization type for describing the passive side of a transfer.
+    """
+    model_config = ConfigDict(
+        extra="forbid",
+        frozen=True,
+        arbitrary_types_allowed=True,
+    )
+    descriptors: List[SerializedDescriptor] = []
+    nixl_metadata: str = ""
+    notification_key: str = ""
+    operation_kind: int = 0
+
+    def to_descriptors(self) -> Descriptor | list[Descriptor]:
+        """
+        Deserializes the request into a `dynamo.connect.Descriptor` or a list of `dynamo.connect.Descriptor` objects.
+        """
+        if len(self.descriptors) == 0:
+            raise ValueError("Request must contain at least one serialized descriptor.")
+        if len(self.descriptors) == 1:
+            return self.descriptors[0].to_descriptor()
+        return [item.to_descriptor() for item in self.descriptors]
+
+    @field_validator("operation_kind")
+    def validate_operation_kind(cls, v: int) -> int:
+        if v < 1 or v > 3:
+            raise ValueError("Argument `operation_kind` must be an integer value of `dynamo.connect.OperationKind`.")
+        return v
+
+
+class WritableOperation(PassiveOperation):
+    """
+    Operation which can be awaited until written to by a `WriteOperation` from a remote worker.
+    """
+
+    def __init__(
+        self,
+        connector: Connector,
+        local_descriptors: Descriptor | list[Descriptor],
+    ) -> None:
+        """
+        Creates a new instance of `WritableOperation`, registers the operation and descriptors with NIXL,
+        and enables an RDMA write operation to occur.
+
+        Parameters
+        ----------
+        connector : Connector
+            Connector instance to use for the operation.
+        local_descriptors : Descriptor | list[Descriptor]
+            Descriptors to receive data from a remote worker.
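+
+        Example (illustrative sketch; assumes `connector` is an initialized
+        `Connector` and `descriptor` wraps a local buffer to receive the data):
+
+        >>> with WritableOperation(connector, descriptor) as op:
+        ...     request = op.to_serialized()
+        ...     # Send `request` to the remote worker out-of-band (e.g. via NATS),
+        ...     # then wait for its `WriteOperation` to complete.
+        ...     await op.wait_for_completion()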
+
+        Raises
+        ------
+        TypeError
+            When `connector` is not a `dynamo.connect.Connector`.
+        TypeError
+            When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
+        """
+        super().__init__(connector, OperationKind.WRITE, local_descriptors)
+        logger.debug(f"Created {self.__repr__()}")
+
+    def __del__(self) -> None:
+        super().__del__()
+        logger.debug(f"Deleted {self.__repr__()}")
+
+    def __enter__(self) -> WritableOperation:
+        super().__enter__()
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        super().__exit__(exc_type, exc_value, traceback)
+
+    def __repr__(self) -> str:
+        return super().__repr__()
+
+    async def wait_for_completion(self) -> None:
+        """
+        Blocks the caller asynchronously until the operation has completed.
+        """
+        await super()._wait_for_completion_()
+
+
+class WriteOperation(ActiveOperation):
+    """
+    Awaitable operation that initiates an RDMA write to a remote worker, targeting buffers
+    described by a `SerializedRequest` obtained from that worker's `WritableOperation`.
+    """
+
+    def __init__(
+        self,
+        connector: Connector,
+        local_descriptors: Descriptor | list[Descriptor],
+        remote_request: SerializedRequest,
+    ) -> None:
+        """
+        Creates a new instance of `WriteOperation`, registers `local_descriptors` with NIXL,
+        and begins an RDMA write operation which will transfer data from `local_descriptors`
+        to the remote target(s) described by `remote_request`.
+
+        Parameters
+        ----------
+        connector : Connector
+            Connector instance to use for the operation.
+        local_descriptors : Descriptor | list[Descriptor]
+            Local descriptor(s) containing the data to send to the remote worker.
+        remote_request : SerializedRequest
+            Serialized request from the remote worker that describes the target(s) to send to.
+
+        Raises
+        ------
+        TypeError
+            When `connector` is not a `dynamo.connect.Connector`.
+        TypeError
+            When `remote_request` is not a `dynamo.connect.SerializedRequest`.
+        ValueError
+            When `remote_request` is not of kind `WRITE`.
+        ValueError
+            When `remote_request.nixl_metadata` is not a non-empty `str`.
+        TypeError
+            When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
+        """
+        if not isinstance(connector, Connector):
+            raise TypeError("Argument `connector` must be `dynamo.connect.Connector`.")
+        if not isinstance(remote_request, SerializedRequest):
+            raise TypeError("Argument `remote_request` must be `dynamo.connect.SerializedRequest`.")
+        if remote_request.operation_kind != OperationKind.WRITE.value:
+            raise ValueError("Argument `remote_request` must be of kind `WRITE`.")
+
+        remote = Remote(connector, remote_request.nixl_metadata)
+        remote_descriptors = remote_request.to_descriptors()
+
+        super().__init__(remote, OperationKind.WRITE, local_descriptors, remote_descriptors, remote_request.notification_key)
+        logger.debug(f"Created {self.__repr__()}")
+
+    def __del__(self) -> None:
+        super().__del__()
+        logger.debug(f"Deleted {self.__repr__()}")
+
+    def __enter__(self) -> WriteOperation:
+        super().__enter__()
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        super().__exit__(exc_type, exc_value, traceback)
+
+    def __repr__(self) -> str:
+        return super().__repr__()
+
+    def cancel(self) -> None:
+        """
+        Cancels the operation.
+        No effect if the operation has already completed, errored, or been cancelled.
+        """
+        super()._cancel_()
+
+    async def wait_for_completion(self) -> None:
+        """
+        Blocks the caller asynchronously until the operation has completed.
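+        Equivalent to polling the `status` property until it reports a terminal state.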
+        """
+        await super()._wait_for_completion_()
diff --git a/examples/multimodal_v1/launch/agg.sh b/examples/multimodal_v1/launch/agg.sh
new file mode 100755
index 0000000000..297f3fc1f2
--- /dev/null
+++ b/examples/multimodal_v1/launch/agg.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+set -e
+trap 'echo Cleaning up...; kill 0' EXIT
+
+# Default values
+MODEL_NAME="llava-hf/llava-1.5-7b-hf"
+PROMPT_TEMPLATE="USER: \n ASSISTANT:"
+PROVIDED_PROMPT_TEMPLATE=""
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+  case $1 in
+    --model)
+      MODEL_NAME="$2"
+      shift 2
+      ;;
+    --prompt-template)
+      PROVIDED_PROMPT_TEMPLATE="$2"
+      shift 2
+      ;;
+    -h|--help)
+      echo "Usage: $0 [OPTIONS]"
+      echo "Options:"
+      echo "  --model            Specify the model to use (default: $MODEL_NAME)"
+      echo "  --prompt-template