-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
More complete custom constructor docs, refinements (#201)
* Constructors + examples refinements * formatting/regenerate docs * Move test * ruff * Python 3.7 * ruff
- Loading branch information
Showing
23 changed files
with
619 additions
and
158 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""Simple Constructors | ||
For simple custom constructors, we can pass a constructor function into | ||
:func:`tyro.conf.arg` or :func:`tyro.conf.subcommand`. Arguments for will be | ||
generated by parsing the signature of the constructor function. | ||
In this example, we use this pattern to define custom behavior for | ||
instantiating a NumPy array. | ||
Usage: | ||
python ./01_simple_constructors.py --help | ||
python ./01_simple_constructors.py --array.values 1 2 3 | ||
python ./01_simple_constructors.py --array.values 1 2 3 4 5 --array.dtype float32 | ||
""" | ||
|
||
from typing import Literal | ||
|
||
import numpy as np | ||
from typing_extensions import Annotated | ||
|
||
import tyro | ||
|
||
|
||
def construct_array( | ||
values: tuple[float, ...], dtype: Literal["float32", "float64"] = "float64" | ||
) -> np.ndarray: | ||
"""A custom constructor for 1D NumPy arrays.""" | ||
return np.array( | ||
values, | ||
dtype={"float32": np.float32, "float64": np.float64}[dtype], | ||
) | ||
|
||
|
||
def main( | ||
# We can specify a custom constructor for an argument in `tyro.conf.arg()`. | ||
array: Annotated[np.ndarray, tyro.conf.arg(constructor=construct_array)], | ||
) -> None: | ||
print(f"{array=}") | ||
|
||
|
||
if __name__ == "__main__": | ||
tyro.cli(main) |
20 changes: 5 additions & 15 deletions
20
...m_constructors/01_primitive_annotation.py → ...m_constructors/02_primitive_annotation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""Custom Structs (Registry) | ||
In this example, we use a :class:`tyro.constructors.ConstructorRegistry` to | ||
add support for a custom type. | ||
.. warning:: | ||
This will be complicated! | ||
Usage: | ||
python ./03_primitive_registry.py --help | ||
python ./03_primitive_registry.py --dict1 '{"hello": "world"}' | ||
python ./03_primitive_registry.py --dict1 '{"hello": "world"}' --dict2 '{"hello": "world"}' | ||
""" | ||
|
||
import tyro | ||
|
||
|
||
# A custom type that we'll add support for to tyro. | ||
class Bounds: | ||
def __init__(self, lower: int, upper: int): | ||
self.bounds = (lower, upper) | ||
|
||
|
||
# Create a custom registry, which stores constructor rules. | ||
custom_registry = tyro.constructors.ConstructorRegistry() | ||
|
||
|
||
# Define a rule that applies to all types that match `Bounds`. | ||
@custom_registry.struct_rule | ||
def _( | ||
type_info: tyro.constructors.StructTypeInfo, | ||
) -> tyro.constructors.StructConstructorSpec | None: | ||
# We return `None` if the rule does not apply. | ||
if type_info.type != Bounds: | ||
return None | ||
|
||
# We can extract the default value of the field from `type_info`. | ||
if isinstance(type_info.default, Bounds): | ||
# If the default value is a `Bounds` instance, we don't need to generate a constructor. | ||
default = (type_info.default.bounds[0], type_info.default.bounds[1]) | ||
is_default_overridden = True | ||
else: | ||
# Otherwise, the default value is missing. We'll mark the child defaults as missing as well. | ||
assert type_info.default in ( | ||
tyro.constructors.MISSING, | ||
tyro.constructors.MISSING_NONPROP, | ||
) | ||
default = (tyro.MISSING, tyro.MISSING) | ||
is_default_overridden = False | ||
|
||
# If the rule applies, we return the constructor spec. | ||
return tyro.constructors.StructConstructorSpec( | ||
# The instantiate function will be called with the fields as keyword arguments. | ||
instantiate=Bounds, | ||
fields=( | ||
tyro.constructors.StructFieldSpec( | ||
name="lower", | ||
type=int, | ||
default=default[0], | ||
is_default_overridden=is_default_overridden, | ||
helptext="Lower bound." "", | ||
), | ||
tyro.constructors.StructFieldSpec( | ||
name="upper", | ||
type=int, | ||
default=default[1], | ||
is_default_overridden=is_default_overridden, | ||
helptext="Upper bound." "", | ||
), | ||
), | ||
) | ||
|
||
|
||
def main( | ||
bounds: Bounds, | ||
bounds_with_default: Bounds = Bounds(0, 100), | ||
) -> None: | ||
"""A function with two `Bounds` instances as input.""" | ||
print(f"{bounds=}") | ||
print(f"{bounds_with_default=}") | ||
|
||
|
||
if __name__ == "__main__": | ||
# To activate a custom registry, we should use it as a context manager. | ||
with custom_registry: | ||
tyro.cli(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,15 @@ | ||
Custom constructors | ||
=================== | ||
|
||
In these examples, we show how custom types can be parsed by and registered with :func:`tyro.cli`. | ||
:func:`tyro.cli` is designed for comprehensive support of standard Python type | ||
constructs. In some cases, however, it can be useful to extend the set of types | ||
supported by :mod:`tyro`. | ||
|
||
We provide two complementary approaches for doing so: | ||
|
||
- :mod:`tyro.conf` provides a simple API for specifying custom constructor | ||
functions. | ||
- :mod:`tyro.constructors` provides a more flexible API for defining behavior | ||
for different types. There are two categories of types: *primitive* types are | ||
instantiated from a single commandline argument, while *struct* types are | ||
broken down into multiple arguments. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.