Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for typing.TypeAliasType as valid parameter type. #970

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ classifiers = [
]
dependencies = [
"click >= 8.0.0",
"typing-extensions >= 3.7.4.3",
"typing-extensions >= 4.6.0",
]
readme = "README.md"
[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion requirements-tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pytest-cov >=2.10.0,<6.0.0
coverage[toml] >=6.2,<8.0
pytest-xdist >=1.32.0,<4.0.0
pytest-sugar >=0.9.4,<1.1.0
mypy ==1.4.1
mypy >=1.10.1
ruff ==0.6.3
# Needed explicitly by typer-slim
rich >=10.11.0
Expand Down
139 changes: 139 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from uuid import UUID

import click
import pytest
from typer.main import get_click_type
from typer.models import FileBinaryRead, FileTextWrite, ParameterInfo
from typing_extensions import TypeAliasType


def test_get_click_type_with_custom_click_type():
custom_click_type = click.INT
param_info = ParameterInfo(click_type=custom_click_type)
result = get_click_type(annotation=int, parameter_info=param_info)
assert result is custom_click_type


def test_get_click_type_with_custom_parser():
def mock_parser(x):
return 42

param_info = ParameterInfo(parser=mock_parser)
result = get_click_type(annotation=int, parameter_info=param_info)
assert isinstance(result, click.types.FuncParamType)
assert result.convert("42", None, None) == 42


def test_get_click_type_with_str_annotation():
param_info = ParameterInfo()
result = get_click_type(annotation=str, parameter_info=param_info)
assert result is click.STRING


def test_get_click_type_with_int_annotation_no_min_max():
param_info = ParameterInfo()
result = get_click_type(annotation=int, parameter_info=param_info)
assert result is click.INT


def test_get_click_type_with_int_annotation_with_min_max():
param_info = ParameterInfo(min=10, max=100)
result = get_click_type(annotation=int, parameter_info=param_info)
assert isinstance(result, click.IntRange)
assert result.min == 10
assert result.max == 100


def test_get_click_type_with_float_annotation_no_min_max():
param_info = ParameterInfo()
result = get_click_type(annotation=float, parameter_info=param_info)
assert result is click.FLOAT


def test_get_click_type_with_float_annotation_with_min_max():
param_info = ParameterInfo(min=0.1, max=10.5)
result = get_click_type(annotation=float, parameter_info=param_info)
assert isinstance(result, click.FloatRange)
assert result.min == 0.1
assert result.max == 10.5


def test_get_click_type_with_bool_annotation():
param_info = ParameterInfo()
result = get_click_type(annotation=bool, parameter_info=param_info)
assert result is click.BOOL


def test_get_click_type_with_uuid_annotation():
param_info = ParameterInfo()
result = get_click_type(annotation=UUID, parameter_info=param_info)
assert result is click.UUID


def test_get_click_type_with_datetime_annotation():
param_info = ParameterInfo(formats=["%Y-%m-%d"])
result = get_click_type(annotation=datetime, parameter_info=param_info)
assert isinstance(result, click.DateTime)
assert result.formats == ["%Y-%m-%d"]


def test_get_click_type_with_path_annotation():
param_info = ParameterInfo(resolve_path=True)
result = get_click_type(annotation=Path, parameter_info=param_info)
assert isinstance(result, click.Path)
assert result.resolve_path is True


def test_get_click_type_with_enum_annotation():
class Color(Enum):
RED = "red"
BLUE = "blue"

param_info = ParameterInfo()
result = get_click_type(annotation=Color, parameter_info=param_info)
assert isinstance(result, click.Choice)
assert result.choices == ["red", "blue"]


def test_get_click_type_with_file_text_write_annotation():
param_info = ParameterInfo(mode="w", encoding="utf-8")
result = get_click_type(annotation=FileTextWrite, parameter_info=param_info)
assert isinstance(result, click.File)
assert result.mode == "w"
assert result.encoding == "utf-8"


def test_get_click_type_with_file_binary_read_annotation():
param_info = ParameterInfo(mode="rb")
result = get_click_type(annotation=FileBinaryRead, parameter_info=param_info)
assert isinstance(result, click.File)
assert result.mode == "rb"


def test_get_click_type_with_type_alias_type():
# define TypeAliasType
Name = TypeAliasType(name="Name", value=str)
Surname = TypeAliasType(name="Surname", value=Name)

param_info = ParameterInfo()
result = get_click_type(annotation=Name, parameter_info=param_info)
assert result is click.STRING

# recursive types
param_info = ParameterInfo()
result = get_click_type(annotation=Surname, parameter_info=param_info)
assert result is click.STRING


def test_get_click_type_with_unsupported_type():
class UnsupportedType:
pass

param_info = ParameterInfo()
with pytest.raises(
RuntimeError, match="Type not yet supported: <class '.*UnsupportedType.*'>"
):
get_click_type(annotation=UnsupportedType, parameter_info=param_info)
19 changes: 16 additions & 3 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@
from pathlib import Path
from traceback import FrameSummary, StackSummary
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from uuid import UUID

import click
from typing_extensions import get_args, get_origin
from typing_extensions import TypeAliasType, get_args, get_origin

from ._typing import is_union
from .completion import get_completion_inspect_parameters
Expand Down Expand Up @@ -43,7 +53,7 @@
Required,
TyperInfo,
)
from .utils import get_params_from_function
from .utils import get_original_type, get_params_from_function

try:
import rich
Expand Down Expand Up @@ -710,6 +720,9 @@ def wrapper(**kwargs: Any) -> Any:
def get_click_type(
*, annotation: Any, parameter_info: ParameterInfo
) -> click.ParamType:
if isinstance(annotation, TypeAliasType):
annotation = get_original_type(annotation)

if parameter_info.click_type is not None:
return parameter_info.click_type

Expand Down
33 changes: 32 additions & 1 deletion typer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@
from copy import copy
from typing import Any, Callable, Dict, List, Tuple, Type, cast

from typing_extensions import Annotated, get_args, get_origin, get_type_hints
from typing_extensions import (
Annotated,
TypeAliasType,
TypeVar,
get_args,
get_origin,
get_type_hints,
)

from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta

T = TypeVar("T")
TypeAliasTypeVar = TypeAliasType("TypeAliasTypeVar", value=T, type_params=(T,))


def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str:
# Render a `ParameterInfo` subclass for use in error messages.
Expand Down Expand Up @@ -189,3 +199,24 @@ def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]:
name=param.name, default=default, annotation=annotation
)
return params


def get_original_type(alias: TypeAliasTypeVar[T]) -> T:
"""Return the original type of an alias.

Examples
--------
>>> Name = TypeAliasType(name="Name", value=str)
>>> Surname = TypeAliasType(name="Surname", value=Name)
>>> get_original_type(Name)
str
>>> get_original_type(Surname)
str
>>> get_original_type(int)
int
"""
otype = alias
while isinstance(otype, TypeAliasType):
otype = otype.__value__

return otype
Loading