Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support union type syntax in runtime contexts #10770

Merged
merged 4 commits into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions mypy/exprtotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

from mypy.nodes import (
Expression, NameExpr, MemberExpr, IndexExpr, RefExpr, TupleExpr, IntExpr, FloatExpr, UnaryExpr,
ComplexExpr, ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr, CallExpr,
ComplexExpr, ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr, CallExpr, OpExpr,
get_member_expr_fullname
)
from mypy.fastparse import parse_type_string
from mypy.types import (
Type, UnboundType, TypeList, EllipsisType, AnyType, CallableArgument, TypeOfAny,
RawExpressionType, ProperType
RawExpressionType, ProperType, UnionType
)
from mypy.options import Options


class TypeTranslationError(Exception):
Expand All @@ -29,7 +30,9 @@ def _extract_argument_name(expr: Expression) -> Optional[str]:
raise TypeTranslationError()


def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = None) -> ProperType:
def expr_to_unanalyzed_type(expr: Expression,
options: Optional[Options] = None,
_parent: Optional[Expression] = None) -> ProperType:
"""Translate an expression to the corresponding type.

The result is not semantically analyzed. It can be UnboundType or TypeList.
Expand All @@ -53,7 +56,7 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
else:
raise TypeTranslationError()
elif isinstance(expr, IndexExpr):
base = expr_to_unanalyzed_type(expr.base, expr)
base = expr_to_unanalyzed_type(expr.base, options, expr)
if isinstance(base, UnboundType):
if base.args:
raise TypeTranslationError()
Expand All @@ -69,14 +72,20 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
# of the Annotation definition and only returning the type information,
# losing all the annotations.

return expr_to_unanalyzed_type(args[0], expr)
return expr_to_unanalyzed_type(args[0], options, expr)
else:
base.args = tuple(expr_to_unanalyzed_type(arg, expr) for arg in args)
base.args = tuple(expr_to_unanalyzed_type(arg, options, expr) for arg in args)
if not base.args:
base.empty_tuple_index = True
return base
else:
raise TypeTranslationError()
elif (isinstance(expr, OpExpr)
and expr.op == '|'
and options
and options.python_version >= (3, 10)):
return UnionType([expr_to_unanalyzed_type(expr.left, options),
expr_to_unanalyzed_type(expr.right, options)])
elif isinstance(expr, CallExpr) and isinstance(_parent, ListExpr):
c = expr.callee
names = []
Expand Down Expand Up @@ -109,19 +118,19 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
if typ is not default_type:
# Two types
raise TypeTranslationError()
typ = expr_to_unanalyzed_type(arg, expr)
typ = expr_to_unanalyzed_type(arg, options, expr)
continue
else:
raise TypeTranslationError()
elif i == 0:
typ = expr_to_unanalyzed_type(arg, expr)
typ = expr_to_unanalyzed_type(arg, options, expr)
elif i == 1:
name = _extract_argument_name(arg)
else:
raise TypeTranslationError()
return CallableArgument(typ, name, arg_const, expr.line, expr.column)
elif isinstance(expr, ListExpr):
return TypeList([expr_to_unanalyzed_type(t, expr) for t in expr.items],
return TypeList([expr_to_unanalyzed_type(t, options, expr) for t in expr.items],
line=expr.line, column=expr.column)
elif isinstance(expr, StrExpr):
return parse_type_string(expr.value, 'builtins.str', expr.line, expr.column,
Expand All @@ -133,7 +142,7 @@ def expr_to_unanalyzed_type(expr: Expression, _parent: Optional[Expression] = No
return parse_type_string(expr.value, 'builtins.unicode', expr.line, expr.column,
assume_str_is_unicode=True)
elif isinstance(expr, UnaryExpr):
typ = expr_to_unanalyzed_type(expr.expr)
typ = expr_to_unanalyzed_type(expr.expr, options)
if isinstance(typ, RawExpressionType):
if isinstance(typ.literal_value, int) and expr.op == '-':
typ.literal_value *= -1
Expand Down
2 changes: 1 addition & 1 deletion mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
type_arg = _get_argument(rvalue, 'type')
if type_arg and not init_type:
try:
un_type = expr_to_unanalyzed_type(type_arg)
un_type = expr_to_unanalyzed_type(type_arg, ctx.api.options)
except TypeTranslationError:
ctx.api.fail('Invalid argument to type', type_arg)
else:
Expand Down
25 changes: 16 additions & 9 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,7 +1267,7 @@ class Foo(Bar, Generic[T]): ...
self.analyze_type_expr(base_expr)

try:
base = expr_to_unanalyzed_type(base_expr)
base = expr_to_unanalyzed_type(base_expr, self.options)
except TypeTranslationError:
# This error will be caught later.
continue
Expand Down Expand Up @@ -1373,7 +1373,7 @@ def get_all_bases_tvars(self,
for i, base_expr in enumerate(base_type_exprs):
if i not in removed:
try:
base = expr_to_unanalyzed_type(base_expr)
base = expr_to_unanalyzed_type(base_expr, self.options)
except TypeTranslationError:
# This error will be caught later.
continue
Expand Down Expand Up @@ -2101,7 +2101,7 @@ def should_wait_rhs(self, rv: Expression) -> bool:
return self.should_wait_rhs(rv.callee)
return False

def can_be_type_alias(self, rv: Expression) -> bool:
def can_be_type_alias(self, rv: Expression, allow_none: bool = False) -> bool:
"""Is this a valid r.h.s. for an alias definition?

Note: this function should be only called for expressions where self.should_wait_rhs()
Expand All @@ -2113,6 +2113,13 @@ def can_be_type_alias(self, rv: Expression) -> bool:
return True
if self.is_none_alias(rv):
return True
if allow_none and isinstance(rv, NameExpr) and rv.fullname == 'builtins.None':
return True
if (isinstance(rv, OpExpr)
and rv.op == '|'
and self.can_be_type_alias(rv.left, allow_none=True)
and self.can_be_type_alias(rv.right, allow_none=True)):
return True
return False

def is_type_ref(self, rv: Expression, bare: bool = False) -> bool:
Expand Down Expand Up @@ -3195,7 +3202,7 @@ def analyze_value_types(self, items: List[Expression]) -> List[Type]:
result: List[Type] = []
for node in items:
try:
analyzed = self.anal_type(expr_to_unanalyzed_type(node),
analyzed = self.anal_type(expr_to_unanalyzed_type(node, self.options),
allow_placeholder=True)
if analyzed is None:
# Type variables are special: we need to place them in the symbol table
Expand Down Expand Up @@ -3638,7 +3645,7 @@ def visit_call_expr(self, expr: CallExpr) -> None:
return
# Translate first argument to an unanalyzed type.
try:
target = expr_to_unanalyzed_type(expr.args[0])
target = expr_to_unanalyzed_type(expr.args[0], self.options)
except TypeTranslationError:
self.fail('Cast target is not a type', expr)
return
Expand Down Expand Up @@ -3696,7 +3703,7 @@ def visit_call_expr(self, expr: CallExpr) -> None:
return
# Translate first argument to an unanalyzed type.
try:
target = expr_to_unanalyzed_type(expr.args[0])
target = expr_to_unanalyzed_type(expr.args[0], self.options)
except TypeTranslationError:
self.fail('Argument 1 to _promote is not a type', expr)
return
Expand Down Expand Up @@ -3892,7 +3899,7 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]]
items = [index]
for item in items:
try:
typearg = expr_to_unanalyzed_type(item)
typearg = expr_to_unanalyzed_type(item, self.options)
except TypeTranslationError:
self.fail('Type expected within [...]', expr)
return None
Expand Down Expand Up @@ -4199,7 +4206,7 @@ def lookup_qualified(self, name: str, ctx: Context,

def lookup_type_node(self, expr: Expression) -> Optional[SymbolTableNode]:
try:
t = expr_to_unanalyzed_type(expr)
t = expr_to_unanalyzed_type(expr, self.options)
except TypeTranslationError:
return None
if isinstance(t, UnboundType):
Expand Down Expand Up @@ -4919,7 +4926,7 @@ def expr_to_analyzed_type(self,
assert info.tuple_type, "NamedTuple without tuple type"
fallback = Instance(info, [])
return TupleType(info.tuple_type.items, fallback=fallback)
typ = expr_to_unanalyzed_type(expr)
typ = expr_to_unanalyzed_type(expr, self.options)
return self.anal_type(typ, report_invalid_types=report_invalid_types,
allow_placeholder=allow_placeholder)

Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def parse_namedtuple_fields_with_types(self, nodes: List[Expression], context: C
self.fail("Invalid NamedTuple() field name", item)
return None
try:
type = expr_to_unanalyzed_type(type_node)
type = expr_to_unanalyzed_type(type_node, self.options)
except TypeTranslationError:
self.fail('Invalid field type', type_node)
return None
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal_newtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def check_newtype_args(self, name: str, call: CallExpr,
# Check second argument
msg = "Argument 2 to NewType(...) must be a valid type"
try:
unanalyzed_type = expr_to_unanalyzed_type(args[1])
unanalyzed_type = expr_to_unanalyzed_type(args[1], self.options)
except TypeTranslationError:
self.fail(msg, context)
return None, False
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def parse_typeddict_fields_with_types(
self.fail_typeddict_arg("Invalid TypedDict() field name", name_context)
return [], [], False
try:
type = expr_to_unanalyzed_type(field_type_expr)
type = expr_to_unanalyzed_type(field_type_expr, self.options)
except TypeTranslationError:
self.fail_typeddict_arg('Invalid field type', field_type_expr)
return [], [], False
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def analyze_type_alias(node: Expression,
Return None otherwise. 'node' must have been semantically analyzed.
"""
try:
type = expr_to_unanalyzed_type(node)
type = expr_to_unanalyzed_type(node, options)
except TypeTranslationError:
api.fail('Invalid type alias: expression is not a valid type', node)
return None
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -3154,7 +3154,7 @@ def foo(arg: Type[Any]):
from typing import Type, Any
def foo(arg: Type[Any]):
reveal_type(arg.__str__) # N: Revealed type is "def () -> builtins.str"
reveal_type(arg.mro()) # N: Revealed type is "builtins.list[builtins.type]"
reveal_type(arg.mro()) # N: Revealed type is "builtins.list[builtins.type[Any]]"
[builtins fixtures/type.pyi]
[out]

Expand Down
58 changes: 41 additions & 17 deletions test-data/unit/check-union-or-syntax.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def f(x: int | str) -> int | str:
reveal_type(f) # N: Revealed type is "def (x: Union[builtins.int, builtins.str]) -> Union[builtins.int, builtins.str]"
[builtins fixtures/tuple.pyi]


[case testUnionOrSyntaxWithThreeBuiltinsTypes]
# flags: --python-version 3.10
def f(x: int | str | float) -> int | str | float:
Expand All @@ -21,7 +20,6 @@ def f(x: int | str | float) -> int | str | float:
return x
reveal_type(f) # N: Revealed type is "def (x: Union[builtins.int, builtins.str, builtins.float]) -> Union[builtins.int, builtins.str, builtins.float]"


[case testUnionOrSyntaxWithTwoTypes]
# flags: --python-version 3.10
class A: pass
Expand All @@ -33,7 +31,6 @@ def f(x: A | B) -> A | B:
return x
reveal_type(f) # N: Revealed type is "def (x: Union[__main__.A, __main__.B]) -> Union[__main__.A, __main__.B]"


[case testUnionOrSyntaxWithThreeTypes]
# flags: --python-version 3.10
class A: pass
Expand All @@ -46,34 +43,29 @@ def f(x: A | B | C) -> A | B | C:
return x
reveal_type(f) # N: Revealed type is "def (x: Union[__main__.A, __main__.B, __main__.C]) -> Union[__main__.A, __main__.B, __main__.C]"


[case testUnionOrSyntaxWithLiteral]
# flags: --python-version 3.10
from typing_extensions import Literal
reveal_type(Literal[4] | str) # N: Revealed type is "Any"
[builtins fixtures/tuple.pyi]


[case testUnionOrSyntaxWithBadOperator]
# flags: --python-version 3.10
x: 1 + 2 # E: Invalid type comment or annotation


[case testUnionOrSyntaxWithBadOperands]
# flags: --python-version 3.10
x: int | 42 # E: Invalid type: try using Literal[42] instead?
y: 42 | int # E: Invalid type: try using Literal[42] instead?
z: str | 42 | int # E: Invalid type: try using Literal[42] instead?


[case testUnionOrSyntaxWithGenerics]
# flags: --python-version 3.10
from typing import List
x: List[int | str]
reveal_type(x) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]"
[builtins fixtures/list.pyi]


[case testUnionOrSyntaxWithQuotedFunctionTypes]
# flags: --python-version 3.4
from typing import Union
Expand All @@ -87,47 +79,79 @@ def g(x: "int | str | None") -> "int | None":
return 42
reveal_type(g) # N: Revealed type is "def (x: Union[builtins.int, builtins.str, None]) -> Union[builtins.int, None]"


[case testUnionOrSyntaxWithQuotedVariableTypes]
# flags: --python-version 3.6
y: "int | str" = 42
reveal_type(y) # N: Revealed type is "Union[builtins.int, builtins.str]"


[case testUnionOrSyntaxWithTypeAliasWorking]
# flags: --python-version 3.10
from typing import Union
T = Union[int, str]
T = int | str
x: T
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]"
S = list[int] | str | None
y: S
reveal_type(y) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.str, None]"
U = str | None
z: U
reveal_type(z) # N: Revealed type is "Union[builtins.str, None]"

def f(): pass

X = int | str | f()
b: X # E: Variable "__main__.X" is not valid as a type \
# N: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
[builtins fixtures/type.pyi]

[case testUnionOrSyntaxWithTypeAliasNotAllowed]
[case testUnionOrSyntaxWithinRuntimeContextNotAllowed]
# flags: --python-version 3.9
from __future__ import annotations
T = int | str # E: Unsupported left operand type for | ("Type[int]")
from typing import List
T = int | str # E: Invalid type alias: expression is not a valid type \
# E: Unsupported left operand type for | ("Type[int]")
class C(List[int | str]): # E: Type expected within [...] \
# E: Invalid base class "List"
pass
C()
[builtins fixtures/tuple.pyi]

[case testUnionOrSyntaxWithinRuntimeContextNotAllowed2]
# flags: --python-version 3.9
from __future__ import annotations
from typing import cast
cast(str | int, 'x') # E: Cast target is not a type
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]

[case testUnionOrSyntaxInComment]
# flags: --python-version 3.6
x = 1 # type: int | str


[case testUnionOrSyntaxFutureImport]
# flags: --python-version 3.7
from __future__ import annotations
x: int | None
[builtins fixtures/tuple.pyi]


[case testUnionOrSyntaxMissingFutureImport]
# flags: --python-version 3.9
x: int | None # E: X | Y syntax for unions requires Python 3.10


[case testUnionOrSyntaxInStubFile]
# flags: --python-version 3.6
from lib import x
[file lib.pyi]
x: int | None

[case testUnionOrSyntaxInMiscRuntimeContexts]
# flags: --python-version 3.10
from typing import cast

class C(list[int | None]):
pass

def f() -> object: pass

reveal_type(cast(str | None, f())) # N: Revealed type is "Union[builtins.str, None]"
reveal_type(list[str | None]()) # N: Revealed type is "builtins.list[Union[builtins.str, None]]"
[builtins fixtures/type.pyi]
Loading