Skip to content

Commit 6d42ce8

Browse files
[CLI] Improve CLI arg parsing for -O/--compilation-config (#20156)
Signed-off-by: luka <luka@neuralmagic.com>
1 parent ded1fb6 commit 6d42ce8

File tree

5 files changed

+124
-40
lines changed

5 files changed

+124
-40
lines changed

tests/engine/test_arg_utils.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -239,32 +239,40 @@ def test_compilation_config():
239239
assert args.compilation_config == CompilationConfig()
240240

241241
# set to O0
242-
args = parser.parse_args(["-O3"])
243-
assert args.compilation_config.level == 3
242+
args = parser.parse_args(["-O0"])
243+
assert args.compilation_config.level == 0
244244

245245
# set to O 1 (space)
246-
args = parser.parse_args(["-O", "3"])
247-
assert args.compilation_config.level == 3
246+
args = parser.parse_args(["-O", "1"])
247+
assert args.compilation_config.level == 1
248248

249249
# set to O 2 (equals)
250-
args = parser.parse_args(["-O=3"])
250+
args = parser.parse_args(["-O=2"])
251+
assert args.compilation_config.level == 2
252+
253+
# set to O.level 3
254+
args = parser.parse_args(["-O.level", "3"])
251255
assert args.compilation_config.level == 3
252256

253257
# set to string form of a dict
254258
args = parser.parse_args([
255-
"--compilation-config",
256-
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
259+
"-O",
260+
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
261+
'"use_inductor": false}',
257262
])
258263
assert (args.compilation_config.level == 3 and
259-
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
264+
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
265+
and not args.compilation_config.use_inductor)
260266

261267
# set to string form of a dict
262268
args = parser.parse_args([
263269
"--compilation-config="
264-
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
270+
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
271+
'"use_inductor": true}',
265272
])
266273
assert (args.compilation_config.level == 3 and
267-
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
274+
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
275+
and args.compilation_config.use_inductor)
268276

269277

270278
def test_prefix_cache_default():

tests/test_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import asyncio
66
import hashlib
77
import json
8+
import logging
89
import pickle
910
import socket
1011
from collections.abc import AsyncIterator
@@ -142,6 +143,7 @@ def parser():
142143
parser.add_argument('--batch-size', type=int)
143144
parser.add_argument('--enable-feature', action='store_true')
144145
parser.add_argument('--hf-overrides', type=json.loads)
146+
parser.add_argument('-O', '--compilation-config', type=json.loads)
145147
return parser
146148

147149

@@ -265,6 +267,11 @@ def test_dict_args(parser):
265267
"val2",
266268
"--hf-overrides.key2.key4",
267269
"val3",
270+
# Test compile config and compilation level
271+
"-O.use_inductor=true",
272+
"-O.backend",
273+
"custom",
274+
"-O1",
268275
# Test = sign
269276
"--hf-overrides.key5=val4",
270277
# Test underscore to dash conversion
@@ -281,6 +288,13 @@ def test_dict_args(parser):
281288
"true",
282289
"--hf_overrides.key12.key13",
283290
"null",
291+
# Test '-' and '.' in value
292+
"--hf_overrides.key14.key15",
293+
"-minus.and.dot",
294+
# Test array values
295+
"-O.custom_ops+",
296+
"-quant_fp8",
297+
"-O.custom_ops+=+silu_mul,-rms_norm",
284298
]
285299
parsed_args = parser.parse_args(args)
286300
assert parsed_args.model_name == "something.something"
@@ -301,9 +315,42 @@ def test_dict_args(parser):
301315
"key12": {
302316
"key13": None,
303317
},
318+
"key14": {
319+
"key15": "-minus.and.dot",
320+
}
321+
}
322+
assert parsed_args.compilation_config == {
323+
"level": 1,
324+
"use_inductor": True,
325+
"backend": "custom",
326+
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
304327
}
305328

306329

330+
def test_duplicate_dict_args(caplog_vllm, parser):
331+
args = [
332+
"--model-name=something.something",
333+
"--hf-overrides.key1",
334+
"val1",
335+
"--hf-overrides.key1",
336+
"val2",
337+
"-O1",
338+
"-O.level",
339+
"2",
340+
"-O3",
341+
]
342+
343+
parsed_args = parser.parse_args(args)
344+
# Should be the last value
345+
assert parsed_args.hf_overrides == {"key1": "val2"}
346+
assert parsed_args.compilation_config == {"level": 3}
347+
348+
assert len(caplog_vllm.records) == 1
349+
assert "duplicate" in caplog_vllm.text
350+
assert "--hf-overrides.key1" in caplog_vllm.text
351+
assert "-O.level" in caplog_vllm.text
352+
353+
307354
# yapf: enable
308355
@pytest.mark.parametrize(
309356
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",

vllm/config.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4140,9 +4140,9 @@ def __repr__(self) -> str:
41404140

41414141
@classmethod
41424142
def from_cli(cls, cli_value: str) -> "CompilationConfig":
4143-
"""Parse the CLI value for the compilation config."""
4144-
if cli_value in ["0", "1", "2", "3"]:
4145-
return cls(level=int(cli_value))
4143+
"""Parse the CLI value for the compilation config.
4144+
-O1, -O2, -O3, etc. is handled in FlexibleArgumentParser.
4145+
"""
41464146
return TypeAdapter(CompilationConfig).validate_json(cli_value)
41474147

41484148
def __post_init__(self) -> None:
@@ -4303,17 +4303,16 @@ class VllmConfig:
43034303
"""Quantization configuration."""
43044304
compilation_config: CompilationConfig = field(
43054305
default_factory=CompilationConfig)
4306-
"""`torch.compile` configuration for the model.
4306+
"""`torch.compile` and cudagraph capture configuration for the model.
43074307
4308-
When it is a number (0, 1, 2, 3), it will be interpreted as the
4309-
optimization level.
4308+
As a shorthand, `-O<n>` can be used to directly specify the compilation
4309+
level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
4310+
Currently, -O <n> and -O=<n> are supported as well but this will likely be
4311+
removed in favor of clearer -O<n> syntax in the future.
43104312
43114313
NOTE: level 0 is the default level without any optimization. level 1 and 2
43124314
are for internal testing only. level 3 is the recommended level for
4313-
production.
4314-
4315-
Following the convention of traditional compilers, using `-O` without space
4316-
is also supported. `-O3` is equivalent to `-O 3`.
4315+
production, also default in V1.
43174316
43184317
You can specify the full compilation config like so:
43194318
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`

vllm/engine/arg_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,10 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
202202
passed individually. For example, the following sets of arguments are
203203
equivalent:\n\n
204204
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
205-
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
205+
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n
206+
Additionally, list elements can be passed individually using '+':
207+
- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
208+
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n"""
206209
if dataclass_cls is not None:
207210

208211
def parse_dataclass(val: str, cls=dataclass_cls) -> Any:

vllm/utils.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,15 @@
8989

9090
STR_NOT_IMPL_ENC_DEC_SWA = \
9191
"Sliding window attention for encoder/decoder models " + \
92-
"is not currently supported."
92+
"is not currently supported."
9393

9494
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
9595
"Prefix caching for encoder/decoder models " + \
96-
"is not currently supported."
96+
"is not currently supported."
9797

9898
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
9999
"Chunked prefill for encoder/decoder models " + \
100-
"is not currently supported."
100+
"is not currently supported."
101101

102102
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
103103
"Models with logits_soft_cap "
@@ -752,7 +752,7 @@ def _generate_random_fp8(
752752
# to generate random data for fp8 data.
753753
# For example, s.11111.00 in fp8e5m2 format represents Inf.
754754
# | E4M3 | E5M2
755-
#-----|-------------|-------------------
755+
# -----|-------------|-------------------
756756
# Inf | N/A | s.11111.00
757757
# NaN | s.1111.111 | s.11111.{01,10,11}
758758
from vllm import _custom_ops as ops
@@ -840,7 +840,6 @@ def create_kv_caches_with_random(
840840
seed: Optional[int] = None,
841841
device: Optional[str] = "cuda",
842842
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
843-
844843
if cache_dtype == "fp8" and head_size % 16:
845844
raise ValueError(
846845
f"Does not support key cache of type fp8 with head_size {head_size}"
@@ -1205,7 +1204,6 @@ def deprecate_args(
12051204
is_deprecated: Union[bool, Callable[[], bool]] = True,
12061205
additional_message: Optional[str] = None,
12071206
) -> Callable[[F], F]:
1208-
12091207
if not callable(is_deprecated):
12101208
is_deprecated = partial(identity, is_deprecated)
12111209

@@ -1355,7 +1353,7 @@ def weak_bound(*args, **kwargs) -> None:
13551353
return weak_bound
13561354

13571355

1358-
#From: https://stackoverflow.com/a/4104188/2749989
1356+
# From: https://stackoverflow.com/a/4104188/2749989
13591357
def run_once(f: Callable[P, None]) -> Callable[P, None]:
13601358

13611359
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
@@ -1474,7 +1472,7 @@ def repl(match: re.Match) -> str:
14741472

14751473
# Convert underscores to dashes and vice versa in argument names
14761474
processed_args = list[str]()
1477-
for arg in args:
1475+
for i, arg in enumerate(args):
14781476
if arg.startswith('--'):
14791477
if '=' in arg:
14801478
key, value = arg.split('=', 1)
@@ -1483,10 +1481,17 @@ def repl(match: re.Match) -> str:
14831481
else:
14841482
key = pattern.sub(repl, arg, count=1)
14851483
processed_args.append(key)
1486-
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2:
1487-
# allow -O flag to be used without space, e.g. -O3
1488-
processed_args.append('-O')
1489-
processed_args.append(arg[2:])
1484+
elif arg.startswith('-O') and arg != '-O' and arg[2] != '.':
1485+
# allow -O flag to be used without space, e.g. -O3 or -Odecode
1486+
# -O.<...> handled later
1487+
# also handle -O=<level> here
1488+
level = arg[3:] if arg[2] == '=' else arg[2:]
1489+
processed_args.append(f'-O.level={level}')
1490+
elif arg == '-O' and i + 1 < len(args) and args[i + 1] in {
1491+
"0", "1", "2", "3"
1492+
}:
1493+
# Convert -O <n> to -O.level <n>
1494+
processed_args.append('-O.level')
14901495
else:
14911496
processed_args.append(arg)
14921497

@@ -1504,27 +1509,44 @@ def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
15041509
def recursive_dict_update(
15051510
original: dict[str, Any],
15061511
update: dict[str, Any],
1507-
):
1508-
"""Recursively updates a dictionary with another dictionary."""
1512+
) -> set[str]:
1513+
"""Recursively updates a dictionary with another dictionary.
1514+
Returns a set of duplicate keys that were overwritten.
1515+
"""
1516+
duplicates = set[str]()
15091517
for k, v in update.items():
15101518
if isinstance(v, dict) and isinstance(original.get(k), dict):
1511-
recursive_dict_update(original[k], v)
1519+
nested_duplicates = recursive_dict_update(original[k], v)
1520+
duplicates |= {f"{k}.{d}" for d in nested_duplicates}
1521+
elif isinstance(v, list) and isinstance(original.get(k), list):
1522+
original[k] += v
15121523
else:
1524+
if k in original:
1525+
duplicates.add(k)
15131526
original[k] = v
1527+
return duplicates
15141528

15151529
delete = set[int]()
15161530
dict_args = defaultdict[str, dict[str, Any]](dict)
1531+
duplicates = set[str]()
15171532
for i, processed_arg in enumerate(processed_args):
1518-
if processed_arg.startswith("--") and "." in processed_arg:
1533+
if i in delete: # skip if value from previous arg
1534+
continue
1535+
1536+
if processed_arg.startswith("-") and "." in processed_arg:
15191537
if "=" in processed_arg:
15201538
processed_arg, value_str = processed_arg.split("=", 1)
15211539
if "." not in processed_arg:
1522-
# False positive, . was only in the value
1540+
# False positive, '.' was only in the value
15231541
continue
15241542
else:
15251543
value_str = processed_args[i + 1]
15261544
delete.add(i + 1)
15271545

1546+
if processed_arg.endswith("+"):
1547+
processed_arg = processed_arg[:-1]
1548+
value_str = json.dumps(list(value_str.split(",")))
1549+
15281550
key, *keys = processed_arg.split(".")
15291551
try:
15301552
value = json.loads(value_str)
@@ -1533,12 +1555,17 @@ def recursive_dict_update(
15331555

15341556
# Merge all values with the same key into a single dict
15351557
arg_dict = create_nested_dict(keys, value)
1536-
recursive_dict_update(dict_args[key], arg_dict)
1558+
arg_duplicates = recursive_dict_update(dict_args[key],
1559+
arg_dict)
1560+
duplicates |= {f'{key}.{d}' for d in arg_duplicates}
15371561
delete.add(i)
15381562
# Filter out the dict args we set to None
15391563
processed_args = [
15401564
a for i, a in enumerate(processed_args) if i not in delete
15411565
]
1566+
if duplicates:
1567+
logger.warning("Found duplicate keys %s", ", ".join(duplicates))
1568+
15421569
# Add the dict args back as if they were originally passed as JSON
15431570
for dict_arg, dict_value in dict_args.items():
15441571
processed_args.append(dict_arg)
@@ -2405,7 +2432,7 @@ def memory_profiling(
24052432
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
24062433
24072434
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
2408-
""" # noqa
2435+
""" # noqa
24092436
gc.collect()
24102437
torch.cuda.empty_cache()
24112438
torch.cuda.reset_peak_memory_stats()

0 commit comments

Comments
 (0)