Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix subcommand edge cases associated with generics, tyro.conf.Suppress, tyro.conf.AvoidSubcomands #220

Merged
merged 6 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
<br />

<strong><code>tyro.cli()</code></strong> is a tool for generating CLI
interfaces in Python.
interfaces from type-annotated Python.

We can define configurable scripts using functions:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

|ruff| |nbsp| |mypy| |nbsp| |pyright| |nbsp| |coverage| |nbsp| |versions|

:func:`tyro.cli()` is a tool for generating CLI interfaces in Python.
:func:`tyro.cli()` is a tool for generating CLI interfaces from type-annotated Python.

We can define configurable scripts using functions:

Expand Down
7 changes: 5 additions & 2 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._typing import TypeForm
from .conf import _confstruct, _markers
from .constructors._primitive_spec import UnsupportedTypeAnnotationError
from .constructors._registry import ConstructorRegistry
from .constructors._registry import ConstructorRegistry, check_default_instances
from .constructors._struct_spec import (
StructFieldSpec,
StructTypeInfo,
Expand Down Expand Up @@ -92,7 +92,10 @@ def make(
# called for functions.
typ = _resolver.type_from_typevar_constraints(typ)
typ = _resolver.narrow_collection_types(typ, default)
typ = _resolver.narrow_union_type(typ, default)

# Be forgiving about default instances.
if not check_default_instances():
typ = _resolver.expand_union_types(typ, default)

# Try to extract argconf overrides from type.
_, argconfs = _resolver.unwrap_annotated(typ, _confstruct._ArgConfig)
Expand Down
67 changes: 31 additions & 36 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,7 @@
import dataclasses
import numbers
import warnings
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from typing import Any, Callable, Dict, List, Set, Tuple, Type, TypeVar, Union, cast

from typing_extensions import Annotated, get_args, get_origin

Expand Down Expand Up @@ -54,12 +42,12 @@ class ParserSpecification:
args: List[_arguments.ArgumentDefinition]
field_list: List[_fields.FieldDefinition]
child_from_prefix: Dict[str, ParserSpecification]
helptext_from_intern_prefixed_field_name: Dict[str, Optional[str]]
helptext_from_intern_prefixed_field_name: Dict[str, str | None]

# We have two mechanics for tracking subparser groups:
# - A single subparser group, which is what gets added in the tree structure built
# by the argparse parser.
subparsers: Optional[SubparsersSpecification]
subparsers: SubparsersSpecification | None
# - A set of subparser groups, which reflect the tree structure built by the
# hierarchy of a nested config structure.
subparsers_from_intern_prefix: Dict[str, SubparsersSpecification]
Expand All @@ -72,7 +60,7 @@ class ParserSpecification:
def from_callable_or_type(
f: Callable[..., T],
markers: Set[_markers._Marker],
description: Optional[str],
description: str | None,
parent_classes: Set[Type[Any]],
default_instance: Union[
T, _singleton.PropagatingMissingType, _singleton.NonpropagatingMissingType
Expand Down Expand Up @@ -113,7 +101,7 @@ def from_callable_or_type(

has_required_args = False
args = []
helptext_from_intern_prefixed_field_name: Dict[str, Optional[str]] = {}
helptext_from_intern_prefixed_field_name: Dict[str, str | None] = {}

child_from_prefix: Dict[str, ParserSpecification] = {}

Expand Down Expand Up @@ -242,7 +230,7 @@ def apply(
def apply_args(
self,
parser: argparse.ArgumentParser,
parent: Optional[ParserSpecification] = None,
parent: ParserSpecification | None = None,
) -> None:
"""Create defined arguments and subparsers."""

Expand Down Expand Up @@ -356,21 +344,25 @@ def handle_field(

if not force_primitive:
# (1) Handle Unions over callables; these result in subparsers.
subparsers_attempt = SubparsersSpecification.from_field(
field,
parent_classes=parent_classes,
intern_prefix=_strings.make_field_name([intern_prefix, field.intern_name]),
extern_prefix=_strings.make_field_name([extern_prefix, field.extern_name]),
)
if subparsers_attempt is not None:
if not subparsers_attempt.required and (
_markers.AvoidSubcommands in field.markers
or _markers.Suppress in field.markers
):
# Don't make a subparser.
field = field.with_new_type_stripped(type(field.default))
else:
return subparsers_attempt
if _markers.Suppress not in field.markers:
subparsers_attempt = SubparsersSpecification.from_field(
field,
parent_classes=parent_classes,
intern_prefix=_strings.make_field_name(
[intern_prefix, field.intern_name]
),
extern_prefix=_strings.make_field_name(
[extern_prefix, field.extern_name]
),
)
if subparsers_attempt is not None:
if subparsers_attempt.default_parser is not None and (
_markers.AvoidSubcommands in field.markers
):
# Don't make a subparser, just use the default subcommand.
return subparsers_attempt.default_parser
else:
return subparsers_attempt

# (2) Handle nested callables.
if force_primitive == "struct" or _fields.is_struct_type(
Expand Down Expand Up @@ -414,8 +406,9 @@ class SubparsersSpecification:
"""Structure for defining subparsers. Each subparser is a parser with a name."""

name: str
description: Optional[str]
description: str | None
parser_from_name: Dict[str, ParserSpecification]
default_parser: ParserSpecification | None
intern_prefix: str
required: bool
default_instance: Any
Expand All @@ -427,7 +420,7 @@ def from_field(
parent_classes: Set[Type[Any]],
intern_prefix: str,
extern_prefix: str,
) -> Optional[SubparsersSpecification]:
) -> SubparsersSpecification | None:
# Union of classes should create subparsers.
typ = _resolver.unwrap_annotated(field.type_stripped)
if get_origin(typ) not in (Union, _resolver.UnionType):
Expand Down Expand Up @@ -593,6 +586,7 @@ def from_field(

# Required if a default is passed in, but the default value has missing
# parameters.
default_parser = None
if default_name is not None:
default_parser = parser_from_name[default_name]
if any(map(lambda arg: arg.lowered.required, default_parser.args)):
Expand Down Expand Up @@ -626,6 +620,7 @@ def from_field(
# the user to include it in the docstring.
description=description,
parser_from_name=parser_from_name,
default_parser=default_parser,
intern_prefix=intern_prefix,
required=required,
default_instance=field.default,
Expand Down Expand Up @@ -678,7 +673,7 @@ def apply(


def add_subparsers_to_leaves(
root: Optional[SubparsersSpecification], leaf: SubparsersSpecification
root: SubparsersSpecification | None, leaf: SubparsersSpecification
) -> SubparsersSpecification:
if root is None:
return leaf
Expand Down
4 changes: 2 additions & 2 deletions src/tyro/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,8 @@ def __exit__(self, exc_type, exc_value, traceback):


@_unsafe_cache.unsafe_cache(maxsize=1024)
def narrow_union_type(typ: TypeOrCallable, default_instance: Any) -> TypeOrCallable:
"""Narrow union types.
def expand_union_types(typ: TypeOrCallable, default_instance: Any) -> TypeOrCallable:
"""Expand union types if necessary.

This is a shim for failing more gracefully when we we're given a Union type that
doesn't match the default value.
Expand Down
44 changes: 43 additions & 1 deletion tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest
from helptext_utils import get_helptext_with_checks
from typing_extensions import Annotated
from typing_extensions import Annotated, TypedDict

import tyro

Expand Down Expand Up @@ -1558,3 +1558,45 @@ class Aconfig:
b_conf: Bconfig = dataclasses.field(default_factory=Bconfig)

assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig()


def test_suppressed_subcommand() -> None:
class Person(TypedDict):
name: str
age: int

@dataclasses.dataclass
class Train:
person: tyro.conf.Suppress[Union[Person, None]] = None

assert tyro.cli(Train, args=[]) == Train(None)


def test_avoid_subcommands_with_generics() -> None:
T = TypeVar("T")

@dataclasses.dataclass(frozen=True)
class Person(Generic[T]):
field: Union[T, bool]

@dataclasses.dataclass
class Train:
person: Union[Person[int], Person[bool], Person[str], Person[float]] = Person(
"hello"
)

assert tyro.cli(Train, config=(tyro.conf.AvoidSubcommands,), args=[]) == Train(
person=Person("hello")
)

# No subcommand should be created.
assert "STR|{True,False}" in get_helptext_with_checks(
tyro.conf.AvoidSubcommands[Train]
)
assert "person:person-str" not in get_helptext_with_checks(
tyro.conf.AvoidSubcommands[Train]
)

# Subcommand should be created.
assert "STR|{True,False}" not in get_helptext_with_checks(Train)
assert "person:person-str" in get_helptext_with_checks(Train)
54 changes: 53 additions & 1 deletion tests/test_py311_generated/test_conf_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import json as json_
import shlex
import sys
from typing import Annotated, Any, Dict, Generic, List, Tuple, Type, TypeVar
from typing import (
Annotated,
Any,
Dict,
Generic,
List,
Tuple,
Type,
TypedDict,
TypeVar,
)

import pytest
from helptext_utils import get_helptext_with_checks
Expand Down Expand Up @@ -1553,3 +1563,45 @@ class Aconfig:
b_conf: Bconfig = dataclasses.field(default_factory=Bconfig)

assert tyro.cli(Aconfig, config=(tyro.conf.Suppress,), args=[]) == Aconfig()


def test_suppressed_subcommand() -> None:
class Person(TypedDict):
name: str
age: int

@dataclasses.dataclass
class Train:
person: tyro.conf.Suppress[Person | None] = None

assert tyro.cli(Train, args=[]) == Train(None)


def test_avoid_subcommands_with_generics() -> None:
T = TypeVar("T")

@dataclasses.dataclass(frozen=True)
class Person(Generic[T]):
field: T | bool

@dataclasses.dataclass
class Train:
person: Person[int] | Person[bool] | Person[str] | Person[float] = Person(
"hello"
)

assert tyro.cli(Train, config=(tyro.conf.AvoidSubcommands,), args=[]) == Train(
person=Person("hello")
)

# No subcommand should be created.
assert "STR|{True,False}" in get_helptext_with_checks(
tyro.conf.AvoidSubcommands[Train]
)
assert "person:person-str" not in get_helptext_with_checks(
tyro.conf.AvoidSubcommands[Train]
)

# Subcommand should be created.
assert "STR|{True,False}" not in get_helptext_with_checks(Train)
assert "person:person-str" in get_helptext_with_checks(Train)
Loading