|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +# Example cli using the Python bindings, similar to `dynamo-run`. |
| 5 | +# Usage: `python cli.py in=text out=mistralrs <your-model>`. |
| 6 | +# Must be in a virtualenv with the Dynamo bindings (or wheel) installed. |
| 7 | + |
| 8 | +import argparse |
| 9 | +import asyncio |
| 10 | +import sys |
| 11 | +from pathlib import Path |
| 12 | + |
| 13 | +import uvloop |
| 14 | + |
| 15 | +from dynamo.llm import EngineType, EntrypointArgs, make_engine, run_input |
| 16 | +from dynamo.runtime import DistributedRuntime |
| 17 | + |
| 18 | + |
def _bounded_uint(bits):
    """Return an argparse ``type=`` callable for integers in ``[0, 2**bits - 1]``.

    The flags below are annotated as ``u16``/``u32`` values, so out-of-range
    input is rejected here as a normal usage error instead of surfacing later.
    """
    upper = (1 << bits) - 1

    def convert(text):
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
        if not 0 <= value <= upper:
            raise argparse.ArgumentTypeError(
                f"{value} is out of range (expected 0..{upper})"
            )
        return value

    return convert


def parse_args(argv=None):
    """Parse the command line into input mode, output mode, model path and flags.

    The ``in=...`` and ``out=...`` tokens are not argparse-style options, so
    they are pulled out by hand first; everything else (``--flags`` and the
    positional model path) is handed to argparse.

    Args:
        argv: Argument list to parse, without the program name. Defaults to
            ``sys.argv[1:]`` when omitted or ``None``.

    Returns:
        A dict with the keys:
            ``in_mode``: the value after ``in=`` (``"batch"`` for
                ``in=batch:FILE``); defaults to ``"text"``.
            ``out_mode``: the value after ``out=``; defaults to ``"echo"``.
            ``batch_file``: the FILE part of ``in=batch:FILE``, else ``None``.
            ``model_path``: the positional model argument, or ``None``.
            ``flags``: the full ``argparse.Namespace``.

    Raises:
        SystemExit: via ``parser.error`` when a flag is invalid, batch mode has
            no file, or the model path is missing while ``out`` is not ``dyn``.
    """
    if argv is None:
        argv = sys.argv[1:]

    in_mode = "text"
    out_mode = "echo"
    batch_file = None  # Only set while in_mode == "batch"

    # --- Step 1: manual pre-parsing of 'in=' and 'out=' ---
    # NOTE: any token starting with "in=" or "out=" is consumed here, even if
    # it was meant as the value of a preceding flag.
    remaining = []
    for arg in argv:
        if arg.startswith("in="):
            in_val = arg[len("in="):]
            if in_val.startswith("batch:"):
                in_mode = "batch"
                batch_file = in_val[len("batch:"):]
            else:
                in_mode = in_val
                # A later in= overrides an earlier in=batch:FILE completely.
                batch_file = None
        elif arg.startswith("out="):
            out_mode = arg[len("out="):]
        else:
            # Either a flag or the model path; argparse sorts that out.
            remaining.append(arg)

    # --- Step 2: argparse for flags and the model path ---
    parser = argparse.ArgumentParser(
        description="Dynamo CLI: Connect inputs to an engine",
        # Preserve the explicit line breaks in help/epilog text.
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=(
            "Input and output are selected with bare key=value tokens:\n"
            "  in=MODE     input mode (default: text); use in=batch:FILE for batch mode\n"
            "  out=ENGINE  output engine (default: echo); out=dyn needs no model path"
        ),
    )

    # model_name: Option<String>
    parser.add_argument("--model-name", type=str, help="Name of the model to load.")
    # model_config: Option<PathBuf>
    parser.add_argument(
        "--model-config", type=Path, help="Path to the model configuration file."
    )
    # context_length: Option<u32>
    parser.add_argument(
        "--context-length",
        type=_bounded_uint(32),
        help="Maximum context length for the model (u32).",
    )
    # template_file: Option<PathBuf>
    parser.add_argument(
        "--template-file",
        type=Path,
        help="Path to the template file for text generation.",
    )
    # kv_cache_block_size: Option<u32>
    parser.add_argument(
        "--kv-cache-block-size",
        type=_bounded_uint(32),
        help="KV cache block size (u32).",
    )
    # http_port: Option<u16>
    parser.add_argument(
        "--http-port",
        type=_bounded_uint(16),
        help="HTTP port for the engine (u16).",
    )

    # TODO: Not yet used here
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        help="Tensor parallel size for the model (e.g., 4).",
    )

    # The positional model is optional for argparse because whether it is
    # required depends on out_mode, which argparse never sees.
    parser.add_argument(
        "model",
        nargs="?",
        help="Path to the model (e.g., Qwen/Qwen3-0.6B).\nRequired unless out=dyn.",
    )

    flags = parser.parse_args(remaining)

    # --- Step 3: post-parsing validation ---
    if in_mode == "batch" and not batch_file:
        parser.error("Batch mode requires a file path: in=batch:FILE")

    if out_mode != "dyn" and flags.model is None:
        parser.error("Model path is required unless out=dyn.")

    return {
        "in_mode": in_mode,
        "out_mode": out_mode,
        "batch_file": batch_file,
        "model_path": flags.model,
        "flags": flags,
    }
| 112 | + |
| 113 | + |
async def run():
    """Parse the command line, build the requested engine and run the input loop.

    Command-line problems (``--help``, usage errors, an unknown ``out=`` value)
    are handled before the ``DistributedRuntime`` is constructed, so they do
    not depend on runtime start-up.

    Raises:
        SystemExit: on argument errors (from ``parse_args``) or when ``out=``
            names an engine this script does not know (exit status 1).
    """
    args = parse_args()

    engine_type_map = {
        "echo": EngineType.Echo,
        "mistralrs": EngineType.MistralRs,
        "llamacpp": EngineType.LlamaCpp,
        "dyn": EngineType.Dynamic,
    }
    out_mode = args["out_mode"]
    engine_type = engine_type_map.get(out_mode)
    if engine_type is None:
        # Diagnostics belong on stderr so they are not mixed into engine output.
        print(f"Unsupported output type: {out_mode}", file=sys.stderr)
        sys.exit(1)

    # TODO: The "vllm", "sglang" and "trtllm" cases should call Python directly

    # Only forward flags the user actually set; unset ones are left out of the
    # EntrypointArgs call entirely.
    entrypoint_kwargs = {"model_path": args["model_path"]}
    flags = args["flags"]
    for name in (
        "model_name",
        "model_config",
        "context_length",
        "template_file",
        "kv_cache_block_size",
        "http_port",
    ):
        value = getattr(flags, name)
        if value is not None:
            entrypoint_kwargs[name] = value

    # parse_args splits "in=batch:FILE" into a mode and a path; put them back
    # together so the batch file is not silently dropped.
    # NOTE(review): assumes run_input accepts the same "batch:FILE" spelling as
    # the command line - confirm against the dynamo.llm bindings.
    input_spec = args["in_mode"]
    if input_spec == "batch":
        input_spec = f"batch:{args['batch_file']}"

    loop = asyncio.get_running_loop()
    runtime = DistributedRuntime(loop, False)

    e = EntrypointArgs(engine_type, **entrypoint_kwargs)
    engine = await make_engine(runtime, e)
    await run_input(runtime, input_spec, engine)
| 153 | + |
| 154 | + |
if __name__ == "__main__":
    # Script entry point: hand the CLI coroutine to uvloop and run it to completion.
    cli_coro = run()
    uvloop.run(cli_coro)
0 commit comments