Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement aliases #83

Merged
merged 5 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions docs/source/examples/04_additional/11_aliases.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
.. Comment: this file is automatically generated by `update_example_docs.py`.
It should not be modified manually.

Argument aliases
==========================================


:func:`tyro.conf.arg()` can be used to attach aliases to arguments.



.. code-block:: python
:linenos:


from typing_extensions import Annotated

import tyro


def checkout(
branch: Annotated[str, tyro.conf.arg(aliases=["-b"])],
) -> None:
"""Check out a branch."""
print(f"{branch=}")


def commit(
message: Annotated[str, tyro.conf.arg(aliases=["-m"])],
all: Annotated[str, tyro.conf.arg(aliases=["-a"])],
) -> None:
"""Make a commit."""
print(f"{message=} {all=}")


if __name__ == "__main__":
tyro.extras.subcommand_cli_from_dict(
{
"checkout": checkout,
"commit": commit,
}
)

------------

.. raw:: html

<kbd>python 04_additional/11_aliases.py --help</kbd>

.. program-output:: python ../../examples/04_additional/11_aliases.py --help

------------

.. raw:: html

<kbd>python 04_additional/11_aliases.py commit --help</kbd>

.. program-output:: python ../../examples/04_additional/11_aliases.py commit --help

------------

.. raw:: html

<kbd>python 04_additional/11_aliases.py commit --message hello --all</kbd>

.. program-output:: python ../../examples/04_additional/11_aliases.py commit --message hello --all

------------

.. raw:: html

<kbd>python 04_additional/11_aliases.py commit -m hello -a</kbd>

.. program-output:: python ../../examples/04_additional/11_aliases.py commit -m hello -a

------------

.. raw:: html

<kbd>python 04_additional/11_aliases.py checkout --help</kbd>

.. program-output:: python ../../examples/04_additional/11_aliases.py checkout --help

------------

.. raw:: html

<kbd>python 04_additional/11_aliases.py checkout --branch main</kbd>

.. program-output:: python ../../examples/04_additional/11_aliases.py checkout --branch main

------------

.. raw:: html

<kbd>python 04_additional/11_aliases.py checkout -b main</kbd>

.. program-output:: python ../../examples/04_additional/11_aliases.py checkout -b main
41 changes: 41 additions & 0 deletions examples/04_additional/11_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Argument aliases

:func:`tyro.conf.arg()` can be used to attach aliases to arguments.

Usage:
`python ./11_aliases.py --help`
`python ./11_aliases.py commit --help`
`python ./11_aliases.py commit --message hello --all`
`python ./11_aliases.py commit -m hello -a`
`python ./11_aliases.py checkout --help`
`python ./11_aliases.py checkout --branch main`
`python ./11_aliases.py checkout -b main`
"""

from typing_extensions import Annotated

import tyro


def checkout(
branch: Annotated[str, tyro.conf.arg(aliases=["-b"])],
) -> None:
"""Check out a branch."""
print(f"{branch=}")


def commit(
message: Annotated[str, tyro.conf.arg(aliases=["-m"])],
all: Annotated[str, tyro.conf.arg(aliases=["-a"])],
) -> None:
"""Make a commit."""
print(f"{message=} {all=}")


if __name__ == "__main__":
tyro.extras.subcommand_cli_from_dict(
{
"checkout": checkout,
"commit": commit,
}
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "tyro"
authors = [
{name = "brentyi", email = "brentyi@berkeley.edu"},
]
version = "0.5.11"
version = "0.5.12"
description = "Strongly typed, zero-effort CLI interfaces"
readme = "README.md"
license = { text="MIT" }
Expand Down
79 changes: 79 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,3 +1090,82 @@ class Config:
assert "We're missing arguments" in error
assert "'a'" in error
assert "'b'" not in error


def test_alias() -> None:
"""Arguments with aliases."""

@dataclasses.dataclass
class Struct:
a: Annotated[int, tyro.conf.arg(aliases=["--all", "-d"])]
b: int
c: int = 3

def make_float(struct: Struct) -> float:
return struct.a * struct.b * struct.c

@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23

assert tyro.cli(Config, args=[]) == Config(x=3.23)
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 --x.struct.a 5".split(" ")
) == Config(x=30)
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 -d 5".split(" ")
) == Config(x=30)
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 --all 5".split(" ")
) == Config(x=30)
assert tyro.cli(Config, args="--x.struct.b 2 --x.struct.a 5".split(" ")) == Config(
x=30
)

# --x.struct.b is required!
with pytest.raises(SystemExit):
tyro.cli(Config, args="--x.struct.a 5".split(" "))

# --x.struct.a and --x.struct.b are required!
target = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
tyro.cli(Config, args="--x.struct.b 5".split(" "))
error = target.getvalue()
assert "We're missing arguments" in error
assert "'a'" in error
assert "'b'" not in error

assert "--x.struct.a INT, --all INT, -d INT" in get_helptext(Config)


def test_positional_alias() -> None:
"""Positional arguments with aliases (which will be ignored)."""

@dataclasses.dataclass
class Struct:
a: Annotated[tyro.conf.Positional[int], tyro.conf.arg(aliases=["--all", "-d"])]
b: int
c: int = 3

def make_float(struct: Struct) -> float:
return struct.a * struct.b * struct.c

@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23

with pytest.warns(UserWarning):
assert tyro.cli(Config, args=[]) == Config(x=3.23)
with pytest.warns(UserWarning):
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 5".split(" ")
) == Config(x=30)

with pytest.raises(SystemExit), pytest.warns(UserWarning):
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 -d 5".split(" ")
) == Config(x=30)
with pytest.raises(SystemExit), pytest.warns(UserWarning):
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 --all 5".split(" ")
) == Config(x=30)
79 changes: 79 additions & 0 deletions tests/test_py311_generated/test_conf_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,3 +1089,82 @@ class Config:
assert "We're missing arguments" in error
assert "'a'" in error
assert "'b'" not in error


def test_alias() -> None:
"""Arguments with aliases."""

@dataclasses.dataclass
class Struct:
a: Annotated[int, tyro.conf.arg(aliases=["--all", "-d"])]
b: int
c: int = 3

def make_float(struct: Struct) -> float:
return struct.a * struct.b * struct.c

@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23

assert tyro.cli(Config, args=[]) == Config(x=3.23)
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 --x.struct.a 5".split(" ")
) == Config(x=30)
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 -d 5".split(" ")
) == Config(x=30)
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 --all 5".split(" ")
) == Config(x=30)
assert tyro.cli(Config, args="--x.struct.b 2 --x.struct.a 5".split(" ")) == Config(
x=30
)

# --x.struct.b is required!
with pytest.raises(SystemExit):
tyro.cli(Config, args="--x.struct.a 5".split(" "))

# --x.struct.a and --x.struct.b are required!
target = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
tyro.cli(Config, args="--x.struct.b 5".split(" "))
error = target.getvalue()
assert "We're missing arguments" in error
assert "'a'" in error
assert "'b'" not in error

assert "--x.struct.a INT, --all INT, -d INT" in get_helptext(Config)


def test_positional_alias() -> None:
"""Positional arguments with aliases (which will be ignored)."""

@dataclasses.dataclass
class Struct:
a: Annotated[tyro.conf.Positional[int], tyro.conf.arg(aliases=["--all", "-d"])]
b: int
c: int = 3

def make_float(struct: Struct) -> float:
return struct.a * struct.b * struct.c

@dataclasses.dataclass
class Config:
x: Annotated[float, tyro.conf.arg(constructor=make_float)] = 3.23

with pytest.warns(UserWarning):
assert tyro.cli(Config, args=[]) == Config(x=3.23)
with pytest.warns(UserWarning):
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 5".split(" ")
) == Config(x=30)

with pytest.raises(SystemExit), pytest.warns(UserWarning):
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 -d 5".split(" ")
) == Config(x=30)
with pytest.raises(SystemExit), pytest.warns(UserWarning):
assert tyro.cli(
Config, args="--x.struct.b 2 --x.struct.c 3 --all 5".split(" ")
) == Config(x=30)
15 changes: 13 additions & 2 deletions tyro/_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,19 @@ def add_argument(
if self.field.argconf.metavar is not None:
kwargs["metavar"] = self.field.argconf.metavar

# Add argument!
arg = parser.add_argument(name_or_flag, **kwargs)
# Add argument, with aliases if available.
if self.field.argconf.aliases is not None and not self.field.is_positional():
arg = parser.add_argument(
name_or_flag, *self.field.argconf.aliases, **kwargs
)
else:
if self.field.argconf.aliases is not None:
import warnings

warnings.warn(
f"Aliases were specified, but {name_or_flag} is positional. Aliases will be ignored."
)
arg = parser.add_argument(name_or_flag, **kwargs)

# Do our best to tab complete paths.
# There will be false positives here, but if choices is unset they should be
Expand Down
28 changes: 21 additions & 7 deletions tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __post_init__(self):
@staticmethod
def make(
name: str,
type_or_callable: TypeForm[Any],
type_or_callable: Union[TypeForm[Any], Callable],
default: Any,
helptext: Optional[str],
call_argname_override: Optional[Any] = None,
Expand All @@ -88,12 +88,26 @@ def make(
_, argconfs = _resolver.unwrap_annotated(
type_or_callable, _confstruct._ArgConfiguration
)
if len(argconfs) == 0:
argconf = _confstruct._ArgConfiguration(None, None, None, True, None)
else:
assert len(argconfs) == 1
(argconf,) = argconfs
helptext = argconf.help
argconf = _confstruct._ArgConfiguration(
None,
None,
help=None,
aliases=None,
prefix_name=True,
constructor_factory=None,
)
for overwrite_argconf in argconfs:
# Apply any annotated argument configuration values.
argconf = dataclasses.replace(
argconf,
**{
field.name: getattr(overwrite_argconf, field.name)
for field in dataclasses.fields(overwrite_argconf)
if getattr(overwrite_argconf, field.name) is not None
},
)
if argconf.help is not None:
helptext = argconf.help

type_or_callable, inferred_markers = _resolver.unwrap_annotated(
type_or_callable, _markers._Marker
Expand Down
Loading