|
33 | 33 | import warnings |
34 | 34 | import weakref |
35 | 35 | from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, |
36 | | - ArgumentTypeError) |
| 36 | + ArgumentTypeError, _ArgumentGroup) |
37 | 37 | from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task |
38 | 38 | from collections import UserDict, defaultdict |
39 | 39 | from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, |
40 | 40 | Iterable, Iterator, KeysView, Mapping) |
41 | 41 | from concurrent.futures.process import ProcessPoolExecutor |
42 | 42 | from dataclasses import dataclass, field |
43 | 43 | from functools import cache, lru_cache, partial, wraps |
| 44 | +from gettext import gettext as _gettext |
44 | 45 | from types import MappingProxyType |
45 | 46 | from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, |
46 | 47 | Optional, Sequence, Tuple, Type, TypeVar, Union, cast, |
|
71 | 72 | from vllm.logger import enable_trace_function_call, init_logger |
72 | 73 |
|
73 | 74 | if TYPE_CHECKING: |
| 75 | + from argparse import Namespace |
| 76 | + |
74 | 77 | from vllm.config import ModelConfig, VllmConfig |
75 | 78 |
|
76 | 79 | logger = init_logger(__name__) |
@@ -1324,16 +1327,78 @@ def add_arguments(self, actions): |
1324 | 1327 | super().add_arguments(actions) |
1325 | 1328 |
|
1326 | 1329 |
|
| 1330 | +class _FlexibleArgumentGroup(_ArgumentGroup): |
| 1331 | + |
| 1332 | + def __init__(self, parser: FlexibleArgumentParser, *args, **kwargs): |
| 1333 | + self._parser = parser |
| 1334 | + super().__init__(*args, **kwargs) |
| 1335 | + |
| 1336 | + def add_argument(self, *args: Any, **kwargs: Any): |
| 1337 | + if sys.version_info < (3, 13): |
| 1338 | + deprecated = kwargs.pop('deprecated', False) |
| 1339 | + action = super().add_argument(*args, **kwargs) |
| 1340 | + object.__setattr__(action, 'deprecated', deprecated) |
| 1341 | + if deprecated and action.dest not in \ |
| 1342 | + self._parser.__class__._deprecated: |
| 1343 | + self._parser._deprecated.add(action) |
| 1344 | + return action |
| 1345 | + |
| 1346 | + # python>3.13 |
| 1347 | + return super().add_argument(*args, **kwargs) |
| 1348 | + |
| 1349 | + |
1327 | 1350 | class FlexibleArgumentParser(ArgumentParser): |
1328 | 1351 | """ArgumentParser that allows both underscore and dash in names.""" |
1329 | 1352 |
|
| 1353 | + _deprecated: set[Action] = set() |
| 1354 | + _seen: set[str] = set() |
| 1355 | + |
1330 | 1356 | def __init__(self, *args, **kwargs): |
1331 | 1357 | # Set the default 'formatter_class' to SortedHelpFormatter |
1332 | 1358 | if 'formatter_class' not in kwargs: |
1333 | 1359 | kwargs['formatter_class'] = SortedHelpFormatter |
1334 | 1360 | super().__init__(*args, **kwargs) |
1335 | 1361 |
|
1336 | | - def parse_args(self, args=None, namespace=None): |
| 1362 | + if sys.version_info < (3, 13): |
| 1363 | + |
| 1364 | + def parse_known_args( # type: ignore[override] |
| 1365 | + self, |
| 1366 | + args: Sequence[str] | None = None, |
| 1367 | + namespace: Namespace | None = None, |
| 1368 | + ) -> tuple[Namespace | None, list[str]]: |
| 1369 | + namespace, args = super().parse_known_args(args, namespace) |
| 1370 | + for action in FlexibleArgumentParser._deprecated: |
| 1371 | + if action.dest not in FlexibleArgumentParser._seen and getattr( |
| 1372 | + namespace, action.dest, |
| 1373 | + None) != action.default: # noqa: E501 |
| 1374 | + self._warning( |
| 1375 | + _gettext("argument '%(argument_name)s' is deprecated") |
| 1376 | + % {'argument_name': action.dest}) |
| 1377 | + FlexibleArgumentParser._seen.add(action.dest) |
| 1378 | + return namespace, args |
| 1379 | + |
| 1380 | + def add_argument(self, *args: Any, **kwargs: Any): |
| 1381 | + # add a deprecated=True compatibility |
| 1382 | + # for python < 3.13 |
| 1383 | + deprecated = kwargs.pop('deprecated', False) |
| 1384 | + action = super().add_argument(*args, **kwargs) |
| 1385 | + object.__setattr__(action, 'deprecated', deprecated) |
| 1386 | + if deprecated and \ |
| 1387 | + action not in FlexibleArgumentParser._deprecated: |
| 1388 | + self._deprecated.add(action) |
| 1389 | + |
| 1390 | + return action |
| 1391 | + |
| 1392 | + def _warning(self, message: str): |
| 1393 | + self._print_message( |
| 1394 | + _gettext('warning: %(message)s\n') % {'message': message}, |
| 1395 | + sys.stderr) |
| 1396 | + |
| 1397 | + def parse_args( # type: ignore[override] |
| 1398 | + self, |
| 1399 | + args: list[str] | None = None, |
| 1400 | + namespace: Namespace | None = None, |
| 1401 | + ): |
1337 | 1402 | if args is None: |
1338 | 1403 | args = sys.argv[1:] |
1339 | 1404 |
|
@@ -1504,6 +1569,15 @@ def _load_config_file(self, file_path: str) -> list[str]: |
1504 | 1569 |
|
1505 | 1570 | return processed_args |
1506 | 1571 |
|
| 1572 | + def add_argument_group( |
| 1573 | + self, |
| 1574 | + *args: Any, |
| 1575 | + **kwargs: Any, |
| 1576 | + ) -> _FlexibleArgumentGroup: |
| 1577 | + group = _FlexibleArgumentGroup(self, self, *args, **kwargs) |
| 1578 | + self._action_groups.append(group) |
| 1579 | + return group |
| 1580 | + |
1507 | 1581 |
|
1508 | 1582 | async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, |
1509 | 1583 | **kwargs): |
|
0 commit comments