48 changes: 39 additions & 9 deletions components/backends/mocker/README.md
@@ -7,17 +7,47 @@ The mocker engine is a mock vLLM implementation designed for testing and development
- Developing and debugging Dynamo components
- Load testing and performance analysis

**Basic usage:**
## Basic usage

The `--model-path` is required but can point to any valid model path - the mocker doesn't actually load the model weights (but the pre-processor needs the tokenizer). The arguments `block_size`, `num_gpu_blocks`, `max_num_seqs`, `max_num_batched_tokens`, `enable_prefix_caching`, and `enable_chunked_prefill` are common arguments shared with the real VLLM engine.
The mocker engine now supports a vLLM-style CLI interface with individual arguments for all configuration options.

And below are arguments that are mocker-specific:
- `speedup_ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulation engines run faster.
- `dp_size`: Number of data parallel workers to simulate (default: 1)
- `watermark`: KV cache watermark threshold as a fraction (default: 0.01). This argument also exists for the real VLLM engine but cannot be passed as an engine arg.
### Required arguments:
- `--model-path`: Path to a model directory or HuggingFace model ID (required; the mocker never loads the model weights, but the pre-processor needs the tokenizer)

### MockEngineArgs parameters (vLLM-style):
- `--num-gpu-blocks-override`: Number of GPU blocks for KV cache (default: 16384)
- `--block-size`: Token block size for KV cache blocks (default: 64)
- `--max-num-seqs`: Maximum number of sequences per iteration (default: 256)
- `--max-num-batched-tokens`: Maximum number of batched tokens per iteration (default: 8192)
- `--enable-prefix-caching` / `--no-enable-prefix-caching`: Enable/disable automatic prefix caching (default: True)
- `--enable-chunked-prefill` / `--no-enable-chunked-prefill`: Enable/disable chunked prefill (default: True)
- `--watermark`: KV cache watermark threshold as a fraction (default: 0.01)
- `--speedup-ratio`: Speed multiplier for token generation (default: 1.0). Higher values make the simulated engines run faster
- `--data-parallel-size`: Number of data parallel workers to simulate (default: 1)

### Example with individual arguments (vLLM-style):
```bash
echo '{"speedup_ratio": 10.0}' > mocker_args.json
python -m dynamo.mocker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 --extra-engine-args mocker_args.json
# Start mocker with custom configuration
python -m dynamo.mocker \
--model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--num-gpu-blocks-override 8192 \
--block-size 16 \
--speedup-ratio 10.0 \
--max-num-seqs 512 \
--enable-prefix-caching

# Start frontend server
python -m dynamo.frontend --http-port 8080
```
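
Once both processes are up, requests can be sent through the frontend's OpenAI-compatible API. The snippet below is a minimal sketch, assuming the frontend exposes `/v1/chat/completions` on the port chosen above and serves the model under its `--model-path` name; adjust the route and model name to match your deployment:

```bash
# Hypothetical request against the frontend started above (port 8080)
curl http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "messages": [{"role": "user", "content": "Hello from the mocker"}],
    "max_tokens": 32
  }'
```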

### Legacy JSON file support:
For backward compatibility, you can still provide configuration via a JSON file:

```bash
echo '{"speedup_ratio": 10.0, "num_gpu_blocks": 8192}' > mocker_args.json
python -m dynamo.mocker \
--model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--extra-engine-args mocker_args.json
```

Note: If `--extra-engine-args` is provided, it overrides all individual CLI arguments.
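
For example, in the following (hypothetical) invocation the JSON file wins: the run uses `speedup_ratio: 10.0` from the file, and the `--block-size 16` flag is ignored:

```bash
echo '{"speedup_ratio": 10.0}' > mocker_args.json

# --block-size is ignored here because --extra-engine-args takes precedence
python -m dynamo.mocker \
  --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
  --block-size 16 \
  --extra-engine-args mocker_args.json
```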
181 changes: 163 additions & 18 deletions components/backends/mocker/src/dynamo/mocker/main.py
@@ -1,10 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Usage: `python -m dynamo.mocker --model-path /data/models/Qwen3-0.6B-Q8_0.gguf --extra-engine-args args.json`
# Usage: `python -m dynamo.mocker --model-path /data/models/Qwen3-0.6B-Q8_0.gguf`
# Now supports vLLM-style individual arguments for MockEngineArgs

import argparse
import json
import logging
import os
import tempfile
from pathlib import Path

import uvloop
@@ -19,35 +19,94 @@
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"

configure_dynamo_logging()
logger = logging.getLogger(__name__)


def create_temp_engine_args_file(args) -> Path:
"""
Create a temporary JSON file with MockEngineArgs from CLI arguments.
Returns the path to the temporary file.
"""
engine_args = {}

# Only include non-None values that differ from defaults
# Note: argparse converts hyphens to underscores in attribute names
# Extract all potential engine arguments, using None as default for missing attributes
engine_args = {
"num_gpu_blocks": getattr(args, "num_gpu_blocks", None),
"block_size": getattr(args, "block_size", None),
"max_num_seqs": getattr(args, "max_num_seqs", None),
"max_num_batched_tokens": getattr(args, "max_num_batched_tokens", None),
"enable_prefix_caching": getattr(args, "enable_prefix_caching", None),
"enable_chunked_prefill": getattr(args, "enable_chunked_prefill", None),
"watermark": getattr(args, "watermark", None),
"speedup_ratio": getattr(args, "speedup_ratio", None),
"dp_size": getattr(args, "dp_size", None),
"startup_time": getattr(args, "startup_time", None),
}

# Remove None values to only include explicitly set arguments
engine_args = {k: v for k, v in engine_args.items() if v is not None}

# Create temporary file
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(engine_args, f, indent=2)
temp_path = Path(f.name)

logger.debug(f"Created temporary MockEngineArgs file at {temp_path}")
logger.debug(f"MockEngineArgs: {engine_args}")

return temp_path


@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
args = cmd_line_args()

# Create engine configuration
entrypoint_args = EntrypointArgs(
engine_type=EngineType.Mocker,
model_path=args.model_path,
model_name=args.model_name,
endpoint_id=args.endpoint,
extra_engine_args=args.extra_engine_args,
)

# Create and run the engine
# NOTE: only supports dyn endpoint for now
engine_config = await make_engine(runtime, entrypoint_args)
await run_input(runtime, args.endpoint, engine_config)
# Handle extra_engine_args: either use provided file or create from CLI args
if args.extra_engine_args:
# User provided explicit JSON file
extra_engine_args_path = args.extra_engine_args
logger.info(f"Using provided MockEngineArgs from {extra_engine_args_path}")
else:
# Create temporary JSON file from CLI arguments
extra_engine_args_path = create_temp_engine_args_file(args)
logger.info("Created MockEngineArgs from CLI arguments")

try:
# Create engine configuration
entrypoint_args = EntrypointArgs(
engine_type=EngineType.Mocker,
model_path=args.model_path,
model_name=args.model_name,
endpoint_id=args.endpoint,
extra_engine_args=extra_engine_args_path,
)

# Create and run the engine
# NOTE: only supports dyn endpoint for now
engine_config = await make_engine(runtime, entrypoint_args)
await run_input(runtime, args.endpoint, engine_config)
finally:
# Clean up temporary file if we created one
if not args.extra_engine_args and extra_engine_args_path.exists():
try:
extra_engine_args_path.unlink()
logger.debug(f"Cleaned up temporary file {extra_engine_args_path}")
except Exception as e:
logger.warning(f"Failed to clean up temporary file: {e}")


def cmd_line_args():
parser = argparse.ArgumentParser(
description="Mocker engine for testing Dynamo LLM infrastructure.",
description="Mocker engine for testing Dynamo LLM infrastructure with vLLM-style CLI.",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--version", action="version", version=f"Dynamo Mocker {__version__}"
)

# Basic configuration
parser.add_argument(
"--model-path",
type=str,
@@ -63,13 +63,95 @@ def cmd_line_args():
"--model-name",
type=str,
default=None,
help="Model name for API responses (default: mocker-engine)",
help="Model name for API responses (default: derived from model-path)",
)

# MockEngineArgs parameters (similar to vLLM style)
parser.add_argument(
"--num-gpu-blocks-override",
type=int,
dest="num_gpu_blocks", # Maps to num_gpu_blocks in MockEngineArgs
default=None,
help="Number of GPU blocks for KV cache (default: 16384)",
)
parser.add_argument(
"--block-size",
type=int,
default=None,
help="Token block size for KV cache blocks (default: 64)",
)
parser.add_argument(
"--max-num-seqs",
type=int,
default=None,
help="Maximum number of sequences per iteration (default: 256)",
)
parser.add_argument(
"--max-num-batched-tokens",
type=int,
default=None,
help="Maximum number of batched tokens per iteration (default: 8192)",
)
parser.add_argument(
"--enable-prefix-caching",
action="store_true",
dest="enable_prefix_caching",
default=None,
help="Enable automatic prefix caching (default: True)",
)
parser.add_argument(
"--no-enable-prefix-caching",
action="store_false",
dest="enable_prefix_caching",
default=None,
help="Disable automatic prefix caching",
)
parser.add_argument(
"--enable-chunked-prefill",
action="store_true",
dest="enable_chunked_prefill",
default=None,
help="Enable chunked prefill (default: True)",
)
parser.add_argument(
"--no-enable-chunked-prefill",
action="store_false",
dest="enable_chunked_prefill",
default=None,
help="Disable chunked prefill",
)
parser.add_argument(
"--watermark",
type=float,
default=None,
help="Watermark value for the mocker engine (default: 0.01)",
)
parser.add_argument(
"--speedup-ratio",
type=float,
default=None,
help="Speedup ratio for mock execution (default: 1.0)",
)
parser.add_argument(
"--data-parallel-size",
type=int,
dest="dp_size",
default=None,
help="Number of data parallel replicas (default: 1)",
)
parser.add_argument(
"--startup-time",
type=float,
default=None,
help="Simulated engine startup time in seconds (default: None)",
)

# Legacy support - allow direct JSON file specification
parser.add_argument(
"--extra-engine-args",
type=Path,
help="Path to JSON file with mocker configuration "
"(num_gpu_blocks, speedup_ratio, etc.)",
help="Path to JSON file with mocker configuration. "
"If provided, overrides individual CLI arguments.",
)

return parser.parse_args()
2 changes: 1 addition & 1 deletion components/frontend/src/dynamo/frontend/main.py
@@ -101,7 +101,7 @@ def parse_args():
parser.add_argument(
"--http-port",
type=int,
default=int(os.environ.get("DYN_HTTP_PORT", "8080")),
default=int(os.environ.get("DYN_HTTP_PORT", "8000")),
help="HTTP port for the engine (u16). Can be set via DYN_HTTP_PORT env var.",
)
parser.add_argument(
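
With this change the frontend listens on port 8000 by default; a short sketch of the ways to override it (assuming a standard shell):

```bash
# Default: DYN_HTTP_PORT unset, frontend binds to 8000
python -m dynamo.frontend

# Override via environment variable or flag
DYN_HTTP_PORT=8080 python -m dynamo.frontend
python -m dynamo.frontend --http-port 8080
```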
19 changes: 7 additions & 12 deletions lib/llm/src/mocker/engine.rs
@@ -1,17 +1,5 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! MockSchedulerEngine - AsyncEngine wrapper around the Scheduler
//!
@@ -76,6 +76,13 @@ impl MockVllmEngine {
pub async fn start(&self, component: Component) -> Result<()> {
let cancel_token = component.drt().runtime().child_token();

// Simulate engine startup time if configured
if let Some(startup_time_secs) = self.engine_args.startup_time {
tracing::info!("Simulating engine startup time: {:.2}s", startup_time_secs);
tokio::time::sleep(Duration::from_secs_f64(startup_time_secs)).await;
tracing::info!("Engine startup simulation completed");
}

let (schedulers, kv_event_receiver) = self.start_schedulers(
self.engine_args.clone(),
self.active_requests.clone(),
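
The simulated startup delay pairs with the `--startup-time` flag added to the mocker CLI above; a minimal sketch:

```bash
# Sleep for 5 seconds before the schedulers start serving requests
python -m dynamo.mocker \
  --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
  --startup-time 5.0
```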
12 changes: 0 additions & 12 deletions lib/llm/src/mocker/evictor.rs
@@ -1,17 +1,5 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::{Eq, Ordering};
use std::collections::{BTreeSet, HashMap};
12 changes: 0 additions & 12 deletions lib/llm/src/mocker/kv_manager.rs
@@ -1,17 +1,5 @@
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! # KV Manager
//! A synchronous implementation of a block manager that handles MoveBlock signals for caching KV blocks.