diff --git a/components/backends/trtllm/README.md b/components/backends/trtllm/README.md
index bf365e6039..9dbc266a9b 100644
--- a/components/backends/trtllm/README.md
+++ b/components/backends/trtllm/README.md
@@ -185,6 +185,7 @@ For comprehensive instructions on multinode serving, see the [multinode-examples
 ### Speculative Decoding
 - **[Llama 4 Maverick Instruct + Eagle Speculative Decoding](./llama4_plus_eagle.md)**
+- **[Async Speculative Decoding](./async_spec_dec.md)**
 
 ### Kubernetes Deployment
diff --git a/components/backends/trtllm/async_spec_dec.md b/components/backends/trtllm/async_spec_dec.md
new file mode 100644
index 0000000000..93fb76b639
--- /dev/null
+++ b/components/backends/trtllm/async_spec_dec.md
@@ -0,0 +1,54 @@
+<!--
+SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+SPDX-License-Identifier: Apache-2.0
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+# Async Speculative Decoding
+
+This guide demonstrates how to run Draft-Target Model (DTM) speculative decoding asynchronously in Dynamo, where the draft model and the target model run as separate Dynamo workers with the TRT-LLM backend.
+
+## Setup
+
+Follow the [Quickstart setup](./README.md#quick-start) instructions. Then, inside the container, run the following example:
+
+```
+cd $DYNAMO_HOME/components/backends/trtllm
+./launch/spec_dec.sh
+```
+
+To scale up the number of drafters:
+
+```
+cd $DYNAMO_HOME/components/backends/trtllm
+export NUM_DRAFTERS=2
+export DRAFTER_CUDA_VISIBLE_DEVICES="1,2"
+./launch/spec_dec.sh
+```
+
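+Once the workers are up, you can send a test request through the frontend. This is a minimal sketch assuming the defaults from `spec_dec.sh` (frontend on port 8000, the Llama 3.1 8B served model name):
+
+```
+curl localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
+    "max_tokens": 64
+  }'
+```
+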
+## Parallel Speculative Decoding
+
+To enable parallel speculative decoding, add the `--parallel-spec-dec` flag to the verifier worker:
+
+```
+# run verifier worker with speculative decoding
+CUDA_VISIBLE_DEVICES=$VERIFIER_CUDA_VISIBLE_DEVICES \
+python3 -m dynamo.trtllm \
+    --model-path "$MODEL_PATH" \
+    --served-model-name "$SERVED_MODEL_NAME" \
+    --extra-engine-args "$VERIFIER_ENGINE_ARGS" \
+    --spec-dec-mode "verifier" \
+    --parallel-spec-dec &
+VERIFIER_PID=$!
+```
\ No newline at end of file
diff --git a/components/backends/trtllm/engine_configs/drafter.yaml b/components/backends/trtllm/engine_configs/drafter.yaml
new file mode 100644
index 0000000000..4ea714f96c
--- /dev/null
+++ b/components/backends/trtllm/engine_configs/drafter.yaml
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+tensor_parallel_size: 1
+moe_expert_parallel_size: 1
+enable_attention_dp: false
+max_num_tokens: 8192
+max_batch_size: 16
+trust_remote_code: true
+backend: pytorch
+enable_chunked_prefill: true
+
+kv_cache_config:
+  free_gpu_memory_fraction: 0.95
+
+# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
+# NOTE: overlap_scheduler enabled by default since this commit and changed
+# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
+# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
+cuda_graph_config:
+  max_batch_size: 16
diff --git a/components/backends/trtllm/engine_configs/verifier.yaml b/components/backends/trtllm/engine_configs/verifier.yaml
new file mode 100644
index 0000000000..67f8bf8fd1
--- /dev/null
+++ b/components/backends/trtllm/engine_configs/verifier.yaml
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+backend: pytorch
+tensor_parallel_size: 4
+moe_expert_parallel_size: 1
+enable_attention_dp: false
+max_num_tokens: 1024
+max_batch_size: 16
+trust_remote_code: true
+enable_chunked_prefill: true
+# External API does not support overlap scheduler
+disable_overlap_scheduler: true
+
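+# The drafter_endpoint below must match the --endpoint served by the drafter
+# workers (see launch/spec_dec.sh); main.py turns this section into a
+# UserProvidedDecodingConfig backed by DynamoAPIDrafter. max_draft_len bounds
+# the number of draft tokens requested per step (defaults to 10 when omitted).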
+drafter_config:
+  drafter_endpoint: dyn://dynamo.tensorrt_llm.generate_draft
+  max_draft_len: 10
+
+kv_cache_config:
+  enable_block_reuse: false
+
+next_client: dyn://dynamo.tensorrt_llm.generate_draft
+
+cuda_graph_config:
+  max_batch_size: 16
diff --git a/components/backends/trtllm/launch/spec_dec.sh b/components/backends/trtllm/launch/spec_dec.sh
new file mode 100755
index 0000000000..5d1992b2ad
--- /dev/null
+++ b/components/backends/trtllm/launch/spec_dec.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# Environment variables with defaults
+# Verifier variables
+export MODEL_PATH=${MODEL_PATH:-"meta-llama/Meta-Llama-3.1-8B-Instruct"}
+export SERVED_MODEL_NAME=${SERVED_MODEL_NAME:-"meta-llama/Meta-Llama-3.1-8B-Instruct"}
+export VERIFIER_ENGINE_ARGS=${VERIFIER_ENGINE_ARGS:-"engine_configs/verifier.yaml"}
+export VERIFIER_CUDA_VISIBLE_DEVICES=${VERIFIER_CUDA_VISIBLE_DEVICES:-"0"}
+
+# Drafter variables
+export NUM_DRAFTERS=${NUM_DRAFTERS:-1}
+export DRAFTER_MODEL_PATH=${DRAFTER_MODEL_PATH:-"meta-llama/Llama-3.2-1B-Instruct"}
+export DRAFTER_MODEL_NAME=${DRAFTER_MODEL_NAME:-"meta-llama/Llama-3.2-1B-Instruct"}
+export DRAFTER_ENGINE_ARGS=${DRAFTER_ENGINE_ARGS:-"engine_configs/drafter.yaml"}
+export DRAFTER_CUDA_VISIBLE_DEVICES=${DRAFTER_CUDA_VISIBLE_DEVICES:-"1"}
+
+declare -a DRAFTER_PIDS=()
+
+# Check enough GPUs for drafters
+IFS=',' read -ra CUDA_DEVICES <<< "$DRAFTER_CUDA_VISIBLE_DEVICES"
+if [ ${#CUDA_DEVICES[@]} -lt $NUM_DRAFTERS ]; then
+    echo "Error: Not enough CUDA devices specified for drafters. Need $NUM_DRAFTERS devices, but only ${#CUDA_DEVICES[@]} provided."
+    exit 1
+fi
+
+# Check num drafters >= 1
+if [[ $NUM_DRAFTERS -lt 1 ]]; then
+    echo "Error: NUM_DRAFTERS must be >= 1, got: $NUM_DRAFTERS"
+    exit 1
+fi
+
+# Setup cleanup trap
+cleanup() {
+    echo "Cleaning up background processes..."
+    kill $DYNAMO_PID $VERIFIER_PID "${DRAFTER_PIDS[@]}" 2>/dev/null || true
+    wait $DYNAMO_PID $VERIFIER_PID "${DRAFTER_PIDS[@]}" 2>/dev/null || true
+    echo "Cleanup complete."
+}
+trap cleanup EXIT INT TERM
+
+# run clear_namespace
+python3 utils/clear_namespace.py --namespace dynamo
+
+# run frontend
+python3 -m dynamo.frontend --http-port 8000 &
+DYNAMO_PID=$!
+
+# run verifier worker with speculative decoding
+CUDA_VISIBLE_DEVICES=$VERIFIER_CUDA_VISIBLE_DEVICES \
+python3 -m dynamo.trtllm \
+    --model-path "$MODEL_PATH" \
+    --served-model-name "$SERVED_MODEL_NAME" \
+    --extra-engine-args "$VERIFIER_ENGINE_ARGS" \
+    --spec-dec-mode "verifier" &
+VERIFIER_PID=$!
+
+# run drafter workers without speculative decoding
+start_drafter() {
+    local cuda_device=$1
+    CUDA_VISIBLE_DEVICES=$cuda_device \
+    python3 -m dynamo.trtllm \
+        --endpoint "dyn://dynamo.tensorrt_llm.generate_draft" \
+        --model-path "$DRAFTER_MODEL_PATH" \
+        --served-model-name "$DRAFTER_MODEL_NAME" \
+        --extra-engine-args "$DRAFTER_ENGINE_ARGS" \
+        --spec-dec-mode "drafter" &
+    DRAFTER_PIDS+=($!)
+}
+
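+# All drafter instances serve the same Dynamo endpoint
+# (dyn://dynamo.tensorrt_llm.generate_draft), so the verifier's drafter
+# client can spread draft requests across them round-robin.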
+for ((i=0; i<$NUM_DRAFTERS; i++)); do
+    start_drafter ${CUDA_DEVICES[$i]}
+done
+
+wait
diff --git a/components/backends/trtllm/src/dynamo/trtllm/main.py b/components/backends/trtllm/src/dynamo/trtllm/main.py
index 8e54f00df2..7d313d23a3 100644
--- a/components/backends/trtllm/src/dynamo/trtllm/main.py
+++ b/components/backends/trtllm/src/dynamo/trtllm/main.py
@@ -15,6 +15,7 @@
     KvCacheConfig,
     SchedulerConfig,
 )
+from tensorrt_llm.llmapi.llm_args import ExternalAPIConfig, UserProvidedDecodingConfig
 from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
 from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
 from torch.cuda import device_count
@@ -30,10 +31,13 @@
     RequestHandlerConfig,
     RequestHandlerFactory,
 )
+from dynamo.trtllm.utils.api_drafter import DynamoAPIDrafter
 from dynamo.trtllm.utils.trtllm_utils import (
     Config,
     cmd_line_args,
+    is_drafter,
     is_first_worker,
+    is_verifier,
     parse_endpoint,
 )
@@ -122,6 +126,7 @@
         dynamic_batch_config=dynamic_batch_config,
     )
     modality = getattr(config, "modality", None) or "text"
+    parallel_spec_dec = getattr(config, "parallel_spec_dec", False)
     arg_map = {
         "model": model_path,
         "scheduler_config": scheduler_config,
@@ -165,6 +170,39 @@
         )
         sys.exit(1)
 
+    if is_verifier(config):
+        # Set speculative_config to use DynamoAPIDrafter in the verifier worker.
+        try:
+            drafter_config = arg_map["drafter_config"]
+            drafter_endpoint = drafter_config["drafter_endpoint"]
+            if "max_draft_len" in drafter_config:
+                max_draft_len = drafter_config["max_draft_len"]
+            else:
+                max_draft_len = 10
+        except KeyError:
+            raise ValueError(
+                "Verifier worker requires drafter_config to be specified in the engine config with drafter_endpoint."
+            )
+        del arg_map["drafter_config"]
+
+        drafter_config = ExternalAPIConfig(
+            endpoint=drafter_endpoint,
+            max_draft_len=max_draft_len,
+        )
+        drafter = DynamoAPIDrafter(spec_config=drafter_config, draft_client=next_client)
+        spec_config = UserProvidedDecodingConfig(
+            drafter=drafter,
+            max_draft_len=max_draft_len,
+        )
+        arg_map["speculative_config"] = spec_config
+
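+    # Initial state for parallel speculative decoding. The request handler
+    # stores this dict (see RequestHandlerConfig) and forwards it per request
+    # to the engine's generate call as parallel_spec_dec_params; None
+    # disables the feature.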
+    parallel_spec_dec_config = None
+    if parallel_spec_dec:
+        parallel_spec_dec_config = {
+            "pre_verify": True,
+            "old_draft_tokens": [],
+        }
+
     logging.info(f"TensorRT-LLM engine args: {arg_map}")
     engine_args = arg_map
@@ -175,7 +213,6 @@
     default_sampling_params.stop = None
     modelType = ModelType.Backend
     multimodal_processor = None
-
     if modality == "multimodal":
         engine_args["skip_tokenizer_init"] = False
         modelType = ModelType.Chat
@@ -194,9 +231,9 @@
     async with get_llm_engine(engine_args) as engine:
         endpoint = component.endpoint(config.endpoint)
-
-        if is_first_worker(config):
+        if is_first_worker(config) and not is_drafter(config):
             # Register the model with the endpoint if only the worker is first in the disaggregation chain.
+            # Drafter workers should not register.
             await register_llm(
                 modelType,
                 endpoint,
@@ -215,8 +252,8 @@
             disaggregation_strategy=config.disaggregation_strategy,
             next_client=next_client,
             multimodal_processor=multimodal_processor,
+            parallel_spec_dec_config=parallel_spec_dec_config,
         )
-
         if config.publish_events_and_metrics and is_first_worker(config):
             # Initialize and pass in the publisher to the request handler to
             # publish events and metrics.
diff --git a/components/backends/trtllm/src/dynamo/trtllm/request_handlers/handler_base.py b/components/backends/trtllm/src/dynamo/trtllm/request_handlers/handler_base.py
index 55dae25de7..ea8f0034ef 100644
--- a/components/backends/trtllm/src/dynamo/trtllm/request_handlers/handler_base.py
+++ b/components/backends/trtllm/src/dynamo/trtllm/request_handlers/handler_base.py
@@ -60,6 +60,9 @@ class RequestHandlerConfig:
     multimodal_processor: Optional[
         MultimodalRequestProcessor
     ] = None  # for multimodal support
+    parallel_spec_dec_config: Optional[
+        dict
+    ] = None  # for parallel speculative decoding support
 
 
 class HandlerBase:
@@ -76,6 +79,7 @@ def __init__(self, config: RequestHandlerConfig):
         self.disaggregation_strategy = config.disaggregation_strategy
         self.next_client = config.next_client
         self.multimodal_processor = config.multimodal_processor
+        self.parallel_spec_dec_config = config.parallel_spec_dec_config
         self.first_generation = True
 
     def check_error(self, result: dict):
@@ -174,6 +178,7 @@ async def generate_locally(self, request: dict):
             sampling_params=sampling_params,
             disaggregated_params=disaggregated_params,
             streaming=streaming,
+            parallel_spec_dec_params=self.parallel_spec_dec_config,
         ):
             # TRTLLM engine needs to start generating tokens first before stats
             # can be retrieved.
diff --git a/components/backends/trtllm/src/dynamo/trtllm/utils/api_drafter.py b/components/backends/trtllm/src/dynamo/trtllm/utils/api_drafter.py
new file mode 100644
index 0000000000..c3d211f73f
--- /dev/null
+++ b/components/backends/trtllm/src/dynamo/trtllm/utils/api_drafter.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import List
+
+from tensorrt_llm._torch.speculative.external_api import APIDrafter
+
+from dynamo.runtime.logging import configure_dynamo_logging
+
+configure_dynamo_logging()
+
+
+class DynamoAPIDrafter(APIDrafter):
+    """
+    Custom Dynamo drafter to support internal Dynamo endpoints instead of only HTTP endpoints.
+    """
+
+    def __init__(self, spec_config, draft_client):
+        super().__init__(spec_config)
+        if draft_client is None:
+            raise ValueError(
+                "next_client must be provided when using parallel speculative decoding"
+            )
+        self.client = draft_client
+        self.max_draft_len = spec_config.max_draft_len
+
+    async def get_draft_tokens(
+        self,
+        prefix: List[int],
+        request_id: int,
+        end_id: int,
+        max_sequence_length: int,
+    ) -> List[int]:
+        request_data = {
+            "token_ids": prefix,
+            "sampling_options": {},
+            "stop_conditions": {
+                "max_tokens": self.max_draft_len,
+            },
+        }
+
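+        # Collect the streamed draft response. The round-robin client picks
+        # one drafter instance per call; each chunk carries newly generated
+        # token_ids. Stop on a finish_reason or once max_draft_len is reached.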
+        draft_tokens: List[int] = []
+        try:
+            response = await self.client.round_robin(request_data)
+            async for chunk in response:
+                chunk_data = chunk.data()
+                if chunk_data.get("finish_reason"):
+                    break
+                draft_tokens.extend(chunk_data.get("token_ids", []))
+                if len(draft_tokens) >= self.max_draft_len:
+                    break
+            return draft_tokens[: self.max_draft_len]
+        except Exception as e:
+            logging.error(
+                f"Failed to get draft tokens for Dynamo endpoint {self.endpoint} with error: {e}"
+            )
+            raise
diff --git a/components/backends/trtllm/src/dynamo/trtllm/utils/trtllm_utils.py b/components/backends/trtllm/src/dynamo/trtllm/utils/trtllm_utils.py
index 0aae7b0801..3f6350b3b5 100644
--- a/components/backends/trtllm/src/dynamo/trtllm/utils/trtllm_utils.py
+++ b/components/backends/trtllm/src/dynamo/trtllm/utils/trtllm_utils.py
@@ -46,7 +46,9 @@ def __init__(self) -> None:
             DEFAULT_DISAGGREGATION_STRATEGY
         )
         self.next_endpoint: str = ""
+        self.spec_dec_mode: Optional[str] = None
         self.modality: str = "text"
+        self.parallel_spec_dec: bool = False
 
     def __str__(self) -> str:
         return (
@@ -71,7 +73,9 @@ def __str__(self) -> str:
             f"disaggregation_mode={self.disaggregation_mode}, "
             f"disaggregation_strategy={self.disaggregation_strategy}, "
             f"next_endpoint={self.next_endpoint}, "
-            f"modality={self.modality})"
+            f"spec_dec_mode={self.spec_dec_mode}, "
+            f"modality={self.modality}, "
+            f"parallel_spec_dec={self.parallel_spec_dec})"
         )
@@ -93,6 +97,20 @@ def is_first_worker(config):
     return is_primary_worker
 
 
+def is_drafter(config):
+    """
+    Check if the current worker is a drafter worker.
+    """
+    return config.spec_dec_mode == "drafter"
+
+
+def is_verifier(config):
+    """
+    Check if the current worker is a verifier worker.
+    """
+    return config.spec_dec_mode == "verifier"
+
+
 def parse_endpoint(endpoint: str) -> tuple[str, str, str]:
     endpoint_str = endpoint.replace("dyn://", "", 1)
     endpoint_parts = endpoint_str.split(".")
@@ -230,6 +248,18 @@ def cmd_line_args():
         default="",
         help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) to send requests to when running in disaggregation mode. Default: {DEFAULT_NEXT_ENDPOINT} if first worker, empty if next worker",
     )
+    parser.add_argument(
+        "--spec-dec-mode",
+        type=str,
+        default=None,
+        choices=["drafter", "verifier"],
+        help="Mode to use for speculative decoding. Options: 'drafter', 'verifier'. Default: None",
+    )
+    parser.add_argument(
+        "--parallel-spec-dec",
+        action="store_true",
+        help="Enable parallel speculative decoding.",
+    )
     args = parser.parse_args()
 
     config = Config()
@@ -288,6 +318,8 @@
     config.migration_limit = args.migration_limit
     config.extra_engine_args = args.extra_engine_args
     config.publish_events_and_metrics = args.publish_events_and_metrics
+    config.spec_dec_mode = args.spec_dec_mode
     config.modality = args.modality
+    config.parallel_spec_dec = args.parallel_spec_dec
     return config