From f212ad8af4788c678cd9c9e56fb995d0a7e663d9 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 31 Oct 2023 12:24:47 -0700 Subject: [PATCH 1/5] Implement aliases --- .../examples/04_additional/11_aliases.rst | 98 +++++++++++++++++++ examples/04_additional/11_aliases.py | 41 ++++++++ tests/test_conf.py | 75 ++++++++++++++ tyro/_arguments.py | 15 ++- tyro/_fields.py | 28 ++++-- tyro/conf/_confstruct.py | 27 +++-- 6 files changed, 269 insertions(+), 15 deletions(-) create mode 100644 docs/source/examples/04_additional/11_aliases.rst create mode 100644 examples/04_additional/11_aliases.py diff --git a/docs/source/examples/04_additional/11_aliases.rst b/docs/source/examples/04_additional/11_aliases.rst new file mode 100644 index 00000000..e364f40c --- /dev/null +++ b/docs/source/examples/04_additional/11_aliases.rst @@ -0,0 +1,98 @@ +.. Comment: this file is automatically generated by `update_example_docs.py`. + It should not be modified manually. + +Argument aliases +========================================== + + +:func:`tyro.conf.arg()` can be used to attach aliases to arguments. + + + +.. code-block:: python + :linenos: + + + from typing_extensions import Annotated + + import tyro + + + def checkout( + branch: Annotated[str, tyro.conf.arg(aliases=["-b"])], + ) -> None: + """Check out a branch.""" + print(f"{branch=}") + + + def commit( + message: Annotated[str, tyro.conf.arg(aliases=["-m"])], + all: Annotated[str, tyro.conf.arg(aliases=["-a"])], + ) -> None: + """Make a commit.""" + print(f"{message=} {all=}") + + + if __name__ == "__main__": + tyro.extras.subcommand_cli_from_dict( + { + "checkout": checkout, + "commit": commit, + } + ) + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py --help + +.. program-output:: python ../../examples/04_additional/11_aliases.py --help + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py commit --help + +.. program-output:: python ../../examples/04_additional/11_aliases.py commit --help + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py commit --message hello --all + +.. program-output:: python ../../examples/04_additional/11_aliases.py commit --message hello --all + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py commit -m hello -a + +.. program-output:: python ../../examples/04_additional/11_aliases.py commit -m hello -a + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py checkout --help + +.. program-output:: python ../../examples/04_additional/11_aliases.py checkout --help + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py checkout --branch main + +.. program-output:: python ../../examples/04_additional/11_aliases.py checkout --branch main + +------------ + +.. raw:: html + + python 04_additional/11_aliases.py checkout -b main + +.. program-output:: python ../../examples/04_additional/11_aliases.py checkout -b main diff --git a/examples/04_additional/11_aliases.py b/examples/04_additional/11_aliases.py new file mode 100644 index 00000000..2e55c529 --- /dev/null +++ b/examples/04_additional/11_aliases.py @@ -0,0 +1,41 @@ +"""Argument aliases + +:func:`tyro.conf.arg()` can be used to attach aliases to arguments. + +Usage: +`python ./11_aliases.py --help` +`python ./11_aliases.py commit --help` +`python ./11_aliases.py commit --message hello --all` +`python ./11_aliases.py commit -m hello -a` +`python ./11_aliases.py checkout --help` +`python ./11_aliases.py checkout --branch main` +`python ./11_aliases.py checkout -b main` +""" + +from typing_extensions import Annotated + +import tyro + + +def checkout( + branch: Annotated[str, tyro.conf.arg(aliases=["-b"])], +) -> None: + """Check out a branch.""" + print(f"{branch=}") + + +def commit( + message: Annotated[str, tyro.conf.arg(aliases=["-m"])], + all: Annotated[str, tyro.conf.arg(aliases=["-a"])], +) -> None: + """Make a commit.""" + print(f"{message=} {all=}") + + +if __name__ == "__main__": + tyro.extras.subcommand_cli_from_dict( + { + "checkout": checkout, + "commit": commit, + } + ) diff --git a/tests/test_conf.py b/tests/test_conf.py index 96a6f5f8..57338d11 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -1090,3 +1090,78 @@ class Config: assert "We're missing arguments" in error assert "'a'" in error assert "'b'" not in error + + +def test_alias() -> None: + """Arguments with aliases.""" + + @dataclasses.dataclass + class Struct: + a: Annotated[int, tyro.conf.arg(aliases=["--all", "-d"])] + b: int + c: int = 3 + + def make_float(struct: Struct) -> float: + return struct.a * struct.b * struct.c + + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23 + + assert tyro.cli(Config, args=[]) == Config(x=3.23) + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 --x.struct.a 5".split(" ") + ) == Config(x=30) + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 -d 5".split(" ") + ) == Config(x=30) + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 --all 5".split(" ") + ) == Config(x=30) + assert tyro.cli(Config, args="--x.struct.b 2 --x.struct.a 5".split(" ")) == Config(x=30) + + # --x.struct.b is required! + with pytest.raises(SystemExit): + tyro.cli(Config, args="--x.struct.a 5".split(" ")) + + # --x.struct.a and --x.struct.b are required! + target = io.StringIO() + with pytest.raises(SystemExit), contextlib.redirect_stdout(target): + tyro.cli(Config, args="--x.struct.b 5".split(" ")) + error = target.getvalue() + assert "We're missing arguments" in error + assert "'a'" in error + assert "'b'" not in error + + +def test_positional_alias() -> None: + """Positional arguments with aliases (which will be ignored).""" + + @dataclasses.dataclass + class Struct: + a: Annotated[tyro.conf.Positional[int], tyro.conf.arg(aliases=["--all", "-d"])] + b: int + c: int = 3 + + def make_float(struct: Struct) -> float: + return struct.a * struct.b * struct.c + + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23 + + with pytest.warns(UserWarning): + assert tyro.cli(Config, args=[]) == Config(x=3.23) + with pytest.warns(UserWarning): + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 5".split(" ") + ) == Config(x=30) + + with pytest.raises(SystemExit), pytest.warns(UserWarning): + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 -d 5".split(" ") + ) == Config(x=30) + with pytest.raises(SystemExit), pytest.warns(UserWarning): + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 --all 5".split(" ") + ) == Config(x=30) diff --git a/tyro/_arguments.py b/tyro/_arguments.py index 79924af0..4a6cd8d7 100644 --- a/tyro/_arguments.py +++ b/tyro/_arguments.py @@ -141,8 +141,19 @@ def add_argument( if self.field.argconf.metavar is not None: kwargs["metavar"] = self.field.argconf.metavar - # Add argument! - arg = parser.add_argument(name_or_flag, **kwargs) + # Add argument, with aliases if available. + if self.field.argconf.aliases is not None and not self.field.is_positional(): + arg = parser.add_argument( + name_or_flag, *self.field.argconf.aliases, **kwargs + ) + else: + if self.field.argconf.aliases is not None: + import warnings + + warnings.warn( + f"Aliases were specified, but {name_or_flag} is positional. Aliases will be ignored." + ) + arg = parser.add_argument(name_or_flag, **kwargs) # Do our best to tab complete paths. # There will be false positives here, but if choices is unset they should be diff --git a/tyro/_fields.py b/tyro/_fields.py index 295bcb95..95e9ecdb 100644 --- a/tyro/_fields.py +++ b/tyro/_fields.py @@ -77,7 +77,7 @@ def __post_init__(self): @staticmethod def make( name: str, - type_or_callable: TypeForm[Any], + type_or_callable: Union[TypeForm[Any], Callable], default: Any, helptext: Optional[str], call_argname_override: Optional[Any] = None, @@ -88,12 +88,26 @@ def make( _, argconfs = _resolver.unwrap_annotated( type_or_callable, _confstruct._ArgConfiguration ) - if len(argconfs) == 0: - argconf = _confstruct._ArgConfiguration(None, None, None, True, None) - else: - assert len(argconfs) == 1 - (argconf,) = argconfs - helptext = argconf.help + argconf = _confstruct._ArgConfiguration( + None, + None, + help=None, + aliases=None, + prefix_name=True, + constructor_factory=None, + ) + for overwrite_argconf in argconfs: + # Apply any annotated argument configuration values. + argconf = dataclasses.replace( + argconf, + **{ + field.name: getattr(overwrite_argconf, field.name) + for field in dataclasses.fields(overwrite_argconf) + if getattr(overwrite_argconf, field.name) is not None + }, + ) + if argconf.help is not None: + helptext = argconf.help type_or_callable, inferred_markers = _resolver.unwrap_annotated( type_or_callable, _markers._Marker diff --git a/tyro/conf/_confstruct.py b/tyro/conf/_confstruct.py index 1b5c39fd..f225c6b0 100644 --- a/tyro/conf/_confstruct.py +++ b/tyro/conf/_confstruct.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Callable, Optional, Type, Union, overload +from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload from .._fields import MISSING_NONPROP @@ -112,10 +112,13 @@ def subcommand( @dataclasses.dataclass(frozen=True) class _ArgConfiguration: + # These are all optional by default in order to support multiple tyro.conf.arg() + # annotations. A None value means "don't overwrite the current value". name: Optional[str] metavar: Optional[str] help: Optional[str] - prefix_name: bool + aliases: Optional[Tuple[str, ...]] + prefix_name: Optional[bool] constructor_factory: Optional[Callable[[], Union[Type, Callable]]] @@ -125,7 +128,8 @@ def arg( name: Optional[str] = None, metavar: Optional[str] = None, help: Optional[str] = None, - prefix_name: bool = True, + aliases: Optional[Sequence[str]] = None, + prefix_name: Optional[bool] = None, constructor: None = None, constructor_factory: Optional[Callable[[], Union[Type, Callable]]] = None, ) -> Any: @@ -138,7 +142,8 @@ def arg( name: Optional[str] = None, metavar: Optional[str] = None, help: Optional[str] = None, - prefix_name: bool = True, + aliases: Optional[Sequence[str]] = None, + prefix_name: Optional[bool] = None, constructor: Optional[Union[Type, Callable]] = None, constructor_factory: None = None, ) -> Any: @@ -150,7 +155,8 @@ def arg( name: Optional[str] = None, metavar: Optional[str] = None, help: Optional[str] = None, - prefix_name: bool = True, + aliases: Optional[Sequence[str]] = None, + prefix_name: Optional[bool] = None, constructor: Optional[Union[Type, Callable]] = None, constructor_factory: Optional[Callable[[], Union[Type, Callable]]] = None, ) -> Any: @@ -164,8 +170,11 @@ def arg( name: A new name for the argument. metavar: Argument name in usage messages. The type is used by default. help: Helptext for this argument. The docstring is used by default. + aliases: Aliases for this argument. All strings in the sequence should start + with a hyphen (-). Aliases will _not_ currently be prefixed in a nested + structure, and are not supported for positional arguments. prefix_name: Whether or not to prefix the name of the argument based on where - it is in a nested structure. + it is in a nested structure. Arguments are prefixed by default. constructor: A constructor type or function. This should either be (a) a subtype of an argument's annotated type, or (b) a function with type-annotated inputs that returns an instance of the annotated type. This will be used in @@ -179,10 +188,16 @@ def arg( assert not ( constructor is not None and constructor_factory is not None ), "`constructor` and `constructor_factory` cannot both be set." + + if aliases is not None: + for alias in aliases: + assert alias.startswith("-"), "Argument alias needs to start with a hyphen!" + return _ArgConfiguration( name=name, metavar=metavar, help=help, + aliases=tuple(aliases) if aliases is not None else None, prefix_name=prefix_name, constructor_factory=constructor_factory if constructor is None From 442f5fe70f8f2f8b466272ba04c4c336ae0fb841 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 31 Oct 2023 12:27:09 -0700 Subject: [PATCH 2/5] Format --- tests/test_conf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_conf.py b/tests/test_conf.py index 57338d11..a74cb94e 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -1118,7 +1118,9 @@ class Config: assert tyro.cli( Config, args="--x.struct.b 2 --x.struct.c 3 --all 5".split(" ") ) == Config(x=30) - assert tyro.cli(Config, args="--x.struct.b 2 --x.struct.a 5".split(" ")) == Config(x=30) + assert tyro.cli(Config, args="--x.struct.b 2 --x.struct.a 5".split(" ")) == Config( + x=30 + ) # --x.struct.b is required! with pytest.raises(SystemExit): From a8f72bc00b652aa13c843bd68b5a0e5e671d4fac Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 31 Oct 2023 12:29:34 -0700 Subject: [PATCH 3/5] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d08495e7..7bb0f2d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "tyro" authors = [ {name = "brentyi", email = "brentyi@berkeley.edu"}, ] -version = "0.5.11" +version = "0.5.12" description = "Strongly typed, zero-effort CLI interfaces" readme = "README.md" license = { text="MIT" } From 983910993f0d2c761b69edf8cdb34192f9a51521 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 31 Oct 2023 12:31:55 -0700 Subject: [PATCH 4/5] Add test for alias helptext --- tests/test_conf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_conf.py b/tests/test_conf.py index a74cb94e..5e2a574b 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -1135,6 +1135,8 @@ class Config: assert "'a'" in error assert "'b'" not in error + assert "--x.struct.a INT, --all INT, -d INT" in get_helptext(Config) + def test_positional_alias() -> None: """Positional arguments with aliases (which will be ignored).""" From 588423ef3c43989d44d83de89f8f3f3b198220d3 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 31 Oct 2023 12:35:38 -0700 Subject: [PATCH 5/5] Re-generate Python 3.11 tests --- .../test_conf_generated.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tests/test_py311_generated/test_conf_generated.py b/tests/test_py311_generated/test_conf_generated.py index 15dcdac1..1e7bc887 100644 --- a/tests/test_py311_generated/test_conf_generated.py +++ b/tests/test_py311_generated/test_conf_generated.py @@ -1089,3 +1089,82 @@ class Config: assert "We're missing arguments" in error assert "'a'" in error assert "'b'" not in error + + +def test_alias() -> None: + """Arguments with aliases.""" + + @dataclasses.dataclass + class Struct: + a: Annotated[int, tyro.conf.arg(aliases=["--all", "-d"])] + b: int + c: int = 3 + + def make_float(struct: Struct) -> float: + return struct.a * struct.b * struct.c + + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23 + + assert tyro.cli(Config, args=[]) == Config(x=3.23) + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 --x.struct.a 5".split(" ") + ) == Config(x=30) + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 -d 5".split(" ") + ) == Config(x=30) + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 --all 5".split(" ") + ) == Config(x=30) + assert tyro.cli(Config, args="--x.struct.b 2 --x.struct.a 5".split(" ")) == Config( + x=30 + ) + + # --x.struct.b is required! + with pytest.raises(SystemExit): + tyro.cli(Config, args="--x.struct.a 5".split(" ")) + + # --x.struct.a and --x.struct.b are required! + target = io.StringIO() + with pytest.raises(SystemExit), contextlib.redirect_stdout(target): + tyro.cli(Config, args="--x.struct.b 5".split(" ")) + error = target.getvalue() + assert "We're missing arguments" in error + assert "'a'" in error + assert "'b'" not in error + + assert "--x.struct.a INT, --all INT, -d INT" in get_helptext(Config) + + +def test_positional_alias() -> None: + """Positional arguments with aliases (which will be ignored).""" + + @dataclasses.dataclass + class Struct: + a: Annotated[tyro.conf.Positional[int], tyro.conf.arg(aliases=["--all", "-d"])] + b: int + c: int = 3 + + def make_float(struct: Struct) -> float: + return struct.a * struct.b * struct.c + + @dataclasses.dataclass + class Config: + x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23 + + with pytest.warns(UserWarning): + assert tyro.cli(Config, args=[]) == Config(x=3.23) + with pytest.warns(UserWarning): + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 5".split(" ") + ) == Config(x=30) + + with pytest.raises(SystemExit), pytest.warns(UserWarning): + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 -d 5".split(" ") + ) == Config(x=30) + with pytest.raises(SystemExit), pytest.warns(UserWarning): + assert tyro.cli( + Config, args="--x.struct.b 2 --x.struct.c 3 --all 5".split(" ") + ) == Config(x=30)