Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TypeGuard (PEP 647) #9865

Merged
merged 27 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9469ffc
Tentative first steps for TypeGuard (PEP 647)
gvanrossum Dec 29, 2020
884d9af
Rename is_type_guard to type_guard, fix repr and serialization
gvanrossum Dec 29, 2020
f55b284
Perhaps naive but passes the test
gvanrossum Dec 30, 2020
bdc97ef
Add more tests, tweak tests (one new test fails)
gvanrossum Dec 30, 2020
0ed5434
Update typeshed
gvanrossum Dec 30, 2020
47df17d
Make is_str_list() test work
gvanrossum Dec 30, 2020
fbfeeb6
Add a test showing a controversial behavior
gvanrossum Dec 30, 2020
03e566f
Fix mypy and lint
gvanrossum Dec 30, 2020
93387bd
Update typeshed and fix test name typo
gvanrossum Dec 30, 2020
708f2e5
Sync typeshed to the version that has TypeGuard
gvanrossum Dec 30, 2020
cf067d1
Merge branch 'master' into typeguard
gvanrossum Dec 30, 2020
0d2eb06
Add two new tests
gvanrossum Dec 31, 2020
8264f8d
Minimal changes to make filter() test pass
gvanrossum Jan 1, 2021
4335785
Fix mypy error in new code (corrected)
gvanrossum Jan 1, 2021
b34d2ac
Make methods work (adds a field to RefExpr)
gvanrossum Jan 2, 2021
9bdd779
Move walrus test to 3.8-only test file
gvanrossum Jan 2, 2021
249b6e5
Merge branch 'master' into typeguard
gvanrossum Jan 4, 2021
639b0fc
Add cross-module test; remove test TODOs
gvanrossum Jan 4, 2021
4400702
Merge remote-tracking branch 'origin/master' into typeguard
gvanrossum Jan 5, 2021
d108c6e
Merge remote-tracking branch 'origin/master' into typeguard
gvanrossum Jan 8, 2021
d96beea
Add many new tests
gvanrossum Jan 8, 2021
370818f
Require that a type guard's first argument is positional
gvanrossum Jan 11, 2021
9062adb
Capitalize error message
gvanrossum Jan 11, 2021
5e76923
Avoid using **extra if possible
gvanrossum Jan 11, 2021
37d2a5f
Clean up testTypeGuardOverload -- it still fails, though
gvanrossum Jan 12, 2021
9d7b6c6
Merge remote-tracking branch 'origin/master' into typeguard
gvanrossum Jan 18, 2021
896c90e
Fix lint
gvanrossum Jan 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
is_named_instance, union_items, TypeQuery, LiteralType,
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
get_proper_types, is_literal_type, TypeAliasType)
get_proper_types, is_literal_type, TypeAliasType, TypeGuardType)
from mypy.sametypes import is_same_type
from mypy.messages import (
MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq,
Expand Down Expand Up @@ -3957,6 +3957,7 @@ def find_isinstance_check(self, node: Expression
) -> Tuple[TypeMap, TypeMap]:
"""Find any isinstance checks (within a chain of ands). Includes
implicit and explicit checks for None and calls to callable.
Also includes TypeGuard functions.

Return value is a map of variables to their types if the condition
is true and a map of variables to their types if the condition is false.
Expand Down Expand Up @@ -4001,6 +4002,14 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
if literal(expr) == LITERAL_TYPE:
vartype = type_map[expr]
return self.conditional_callable_type_map(expr, vartype)
elif isinstance(node.callee, RefExpr):
if node.callee.type_guard is not None:
# TODO: Follow keyword args or *args, **kwargs
if node.arg_kinds[0] != nodes.ARG_POS:
self.fail("Type guard requires positional argument", node)
return {}, {}
if literal(expr) == LITERAL_TYPE:
return {expr: TypeGuardType(node.callee.type_guard)}, {}
elif isinstance(node, ComparisonExpr):
# Step 1: Obtain the types of each operand and whether or not we can
# narrow their types. (For example, we shouldn't try narrowing the
Expand Down
11 changes: 10 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
make_optional_type,
)
from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneType, TypeVarDef,
Type, AnyType, CallableType, Overloaded, NoneType, TypeGuardType, TypeVarDef,
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
is_named_instance, FunctionLike,
Expand Down Expand Up @@ -317,6 +317,11 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
ret_type=self.object_type(),
fallback=self.named_type('builtins.function'))
callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True))
if (isinstance(e.callee, RefExpr)
and isinstance(callee_type, CallableType)
and callee_type.type_guard is not None):
# Cache it for find_isinstance_check()
e.callee.type_guard = callee_type.type_guard
if (self.chk.options.disallow_untyped_calls and
self.chk.in_checked_function() and
isinstance(callee_type, CallableType)
Expand Down Expand Up @@ -4163,6 +4168,10 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type,
"""
if literal(expr) >= LITERAL_TYPE:
restriction = self.chk.binder.get(expr)
# Ignore the error about using get_proper_type().
if isinstance(restriction, TypeGuardType): # type: ignore[misc]
# A type guard forces the new type even if it doesn't overlap the old.
return restriction.type_guard
# If the current node is deferred, some variables may get Any types that they
# otherwise wouldn't have. We don't want to narrow down these since it may
# produce invalid inferred Optional[Any] types, at least.
Expand Down
7 changes: 6 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,12 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
for t, a in zip(template.arg_types, cactual.arg_types):
# Negate direction due to function argument type contravariance.
res.extend(infer_constraints(t, a, neg_op(self.direction)))
res.extend(infer_constraints(template.ret_type, cactual.ret_type,
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
if template.type_guard is not None:
template_ret_type = template.type_guard
if cactual.type_guard is not None:
cactual_ret_type = cactual.type_guard
res.extend(infer_constraints(template_ret_type, cactual_ret_type,
self.direction))
return res
elif isinstance(self.actual, AnyType):
Expand Down
4 changes: 3 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def visit_type_var(self, t: TypeVarType) -> Type:

def visit_callable_type(self, t: CallableType) -> Type:
return t.copy_modified(arg_types=self.expand_types(t.arg_types),
ret_type=t.ret_type.accept(self))
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self)
if t.type_guard is not None else None))

def visit_overloaded(self, t: Overloaded) -> Type:
items = [] # type: List[CallableType]
Expand Down
2 changes: 2 additions & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def visit_callable_type(self, ct: CallableType) -> None:
for arg in ct.bound_args:
if arg:
arg.accept(self)
if ct.type_guard is not None:
ct.type_guard.accept(self)

def visit_overloaded(self, t: Overloaded) -> None:
for ct in t.items():
Expand Down
5 changes: 4 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,8 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
class RefExpr(Expression):
"""Abstract base class for name-like constructs"""

__slots__ = ('kind', 'node', 'fullname', 'is_new_def', 'is_inferred_def', 'is_alias_rvalue')
__slots__ = ('kind', 'node', 'fullname', 'is_new_def', 'is_inferred_def', 'is_alias_rvalue',
'type_guard')

def __init__(self) -> None:
super().__init__()
Expand All @@ -1467,6 +1468,8 @@ def __init__(self) -> None:
self.is_inferred_def = False
# Is this expression appears as an rvalue of a valid type alias definition?
self.is_alias_rvalue = False
# Cache type guard from callable_type.type_guard
self.type_guard = None # type: Optional[mypy.types.Type]


class NameExpr(RefExpr):
Expand Down
1 change: 1 addition & 0 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
'check-annotated.test',
'check-parameter-specification.test',
'check-generic-alias.test',
'check-typeguard.test',
]

# Tests that use Python 3.8-only AST features (like expression-scoped ignores):
Expand Down
24 changes: 23 additions & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt
" and at least one annotation", t)
return AnyType(TypeOfAny.from_error)
return self.anal_type(t.args[0])
elif self.anal_type_guard_arg(t, fullname) is not None:
# In most contexts, TypeGuard[...] acts as an alias for bool (ignoring its args)
return self.named_type('builtins.bool')
return None

def get_omitted_any(self, typ: Type, fullname: Optional[str] = None) -> AnyType:
Expand Down Expand Up @@ -524,15 +527,34 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
variables = t.variables
else:
variables = self.bind_function_type_variables(t, t)
special = self.anal_type_guard(t.ret_type)
ret = t.copy_modified(arg_types=self.anal_array(t.arg_types, nested=nested),
ret_type=self.anal_type(t.ret_type, nested=nested),
# If the fallback isn't filled in yet,
# its type will be the falsey FakeInfo
fallback=(t.fallback if t.fallback.type
else self.named_type('builtins.function')),
variables=self.anal_var_defs(variables))
variables=self.anal_var_defs(variables),
type_guard=special,
)
return ret

def anal_type_guard(self, t: Type) -> Optional[Type]:
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
if sym is not None and sym.node is not None:
return self.anal_type_guard_arg(t, sym.node.fullname)
# TODO: What if it's an Instance? Then use t.type.fullname?
return None

def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Optional[Type]:
if fullname in ('typing_extensions.TypeGuard', 'typing.TypeGuard'):
if len(t.args) != 1:
self.fail("TypeGuard must have exactly one type argument", t)
return AnyType(TypeOfAny.from_error)
return self.anal_type(t.args[0])
return None

def visit_overloaded(self, t: Overloaded) -> Type:
# Overloaded types are manually constructed in semanal.py by analyzing the
# AST and combining together the Callable types this visitor converts.
Expand Down
30 changes: 27 additions & 3 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@ def copy_modified(self, *,
self.line, self.column)


class TypeGuardType(Type):
"""Only used by find_instance_check() etc."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add link to the PEP in the docstring.

def __init__(self, type_guard: Type):
super().__init__(line=type_guard.line, column=type_guard.column)
self.type_guard = type_guard

def __repr__(self) -> str:
return "TypeGuard({})".format(self.type_guard)


class ProperType(Type):
"""Not a type alias.

Expand Down Expand Up @@ -1005,6 +1015,7 @@ class CallableType(FunctionLike):
# tools that consume mypy ASTs
'def_extras', # Information about original definition we want to serialize.
# This is used for more detailed error messages.
'type_guard', # T, if -> TypeGuard[T] (ret_type is bool in this case).
)

def __init__(self,
Expand All @@ -1024,6 +1035,7 @@ def __init__(self,
from_type_type: bool = False,
bound_args: Sequence[Optional[Type]] = (),
def_extras: Optional[Dict[str, Any]] = None,
type_guard: Optional[Type] = None,
) -> None:
super().__init__(line, column)
assert len(arg_types) == len(arg_kinds) == len(arg_names)
Expand Down Expand Up @@ -1058,6 +1070,7 @@ def __init__(self,
not definition.is_static else None}
else:
self.def_extras = {}
self.type_guard = type_guard

def copy_modified(self,
arg_types: Bogus[Sequence[Type]] = _dummy,
Expand All @@ -1075,7 +1088,9 @@ def copy_modified(self,
special_sig: Bogus[Optional[str]] = _dummy,
from_type_type: Bogus[bool] = _dummy,
bound_args: Bogus[List[Optional[Type]]] = _dummy,
def_extras: Bogus[Dict[str, Any]] = _dummy) -> 'CallableType':
def_extras: Bogus[Dict[str, Any]] = _dummy,
type_guard: Bogus[Optional[Type]] = _dummy,
) -> 'CallableType':
return CallableType(
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
arg_kinds=arg_kinds if arg_kinds is not _dummy else self.arg_kinds,
Expand All @@ -1094,6 +1109,7 @@ def copy_modified(self,
from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type,
bound_args=bound_args if bound_args is not _dummy else self.bound_args,
def_extras=def_extras if def_extras is not _dummy else dict(self.def_extras),
type_guard=type_guard if type_guard is not _dummy else self.type_guard,
)

def var_arg(self) -> Optional[FormalArgument]:
Expand Down Expand Up @@ -1255,6 +1271,8 @@ def __eq__(self, other: object) -> bool:
def serialize(self) -> JsonDict:
# TODO: As an optimization, leave out everything related to
# generic functions for non-generic functions.
assert (self.type_guard is None
or isinstance(get_proper_type(self.type_guard), Instance)), str(self.type_guard)
return {'.class': 'CallableType',
'arg_types': [t.serialize() for t in self.arg_types],
'arg_kinds': self.arg_kinds,
Expand All @@ -1269,6 +1287,7 @@ def serialize(self) -> JsonDict:
'bound_args': [(None if t is None else t.serialize())
for t in self.bound_args],
'def_extras': dict(self.def_extras),
'type_guard': self.type_guard.serialize() if self.type_guard is not None else None,
}

@classmethod
Expand All @@ -1286,7 +1305,9 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
implicit=data['implicit'],
bound_args=[(None if t is None else deserialize_type(t))
for t in data['bound_args']],
def_extras=data['def_extras']
def_extras=data['def_extras'],
type_guard=(deserialize_type(data['type_guard'])
if data['type_guard'] is not None else None),
)


Expand Down Expand Up @@ -2097,7 +2118,10 @@ def visit_callable_type(self, t: CallableType) -> str:
s = '({})'.format(s)

if not isinstance(get_proper_type(t.ret_type), NoneType):
s += ' -> {}'.format(t.ret_type.accept(self))
if t.type_guard is not None:
s += ' -> TypeGuard[{}]'.format(t.type_guard.accept(self))
else:
s += ' -> {}'.format(t.ret_type.accept(self))

if t.variables:
vs = []
Expand Down
9 changes: 9 additions & 0 deletions test-data/unit/check-python38.test
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,12 @@ def func() -> None:
class Foo:
def __init__(self) -> None:
self.x = 123

[case testWalrusTypeGuard]
from typing_extensions import TypeGuard
def is_float(a: object) -> TypeGuard[float]: pass
def main(a: object) -> None:
if is_float(x := a):
reveal_type(x) # N: Revealed type is 'builtins.float'
reveal_type(a) # N: Revealed type is 'builtins.object'
[builtins fixtures/tuple.pyi]
15 changes: 15 additions & 0 deletions test-data/unit/check-serialize.test
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,21 @@ def f(x: int) -> int: pass
tmp/a.py:2: note: Revealed type is 'builtins.str'
tmp/a.py:3: error: Unexpected keyword argument "x" for "f"

[case testSerializeTypeGuardFunction]
import a
[file a.py]
import b
[file a.py.2]
import b
reveal_type(b.guard(''))
reveal_type(b.guard)
[file b.py]
from typing_extensions import TypeGuard
def guard(a: object) -> TypeGuard[str]: pass
[builtins fixtures/tuple.pyi]
[out2]
tmp/a.py:2: note: Revealed type is 'builtins.bool'
tmp/a.py:3: note: Revealed type is 'def (a: builtins.object) -> TypeGuard[builtins.str]'
--
-- Classes
--
Expand Down
Loading