Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TypeGuard (PEP 647) #9865

Merged
merged 27 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9469ffc
Tentative first steps for TypeGuard (PEP 647)
gvanrossum Dec 29, 2020
884d9af
Rename is_type_guard to type_guard, fix repr and serialization
gvanrossum Dec 29, 2020
f55b284
Perhaps naive but passes the test
gvanrossum Dec 30, 2020
bdc97ef
Add more tests, tweak tests (one new test fails)
gvanrossum Dec 30, 2020
0ed5434
Update typeshed
gvanrossum Dec 30, 2020
47df17d
Make is_str_list() test work
gvanrossum Dec 30, 2020
fbfeeb6
Add a test showing a controversial behavior
gvanrossum Dec 30, 2020
03e566f
Fix mypy and lint
gvanrossum Dec 30, 2020
93387bd
Update typeshed and fix test name typo
gvanrossum Dec 30, 2020
708f2e5
Sync typeshed to the version that has TypeGuard
gvanrossum Dec 30, 2020
cf067d1
Merge branch 'master' into typeguard
gvanrossum Dec 30, 2020
0d2eb06
Add two new tests
gvanrossum Dec 31, 2020
8264f8d
Minimal changes to make filter() test pass
gvanrossum Jan 1, 2021
4335785
Fix mypy error in new code (corrected)
gvanrossum Jan 1, 2021
b34d2ac
Make methods work (adds a field to RefExpr)
gvanrossum Jan 2, 2021
9bdd779
Move walrus test to 3.8-only test file
gvanrossum Jan 2, 2021
249b6e5
Merge branch 'master' into typeguard
gvanrossum Jan 4, 2021
639b0fc
Add cross-module test; remove test TODOs
gvanrossum Jan 4, 2021
4400702
Merge remote-tracking branch 'origin/master' into typeguard
gvanrossum Jan 5, 2021
d108c6e
Merge remote-tracking branch 'origin/master' into typeguard
gvanrossum Jan 8, 2021
d96beea
Add many new tests
gvanrossum Jan 8, 2021
370818f
Require that a type guard's first argument is positional
gvanrossum Jan 11, 2021
9062adb
Capitalize error message
gvanrossum Jan 11, 2021
5e76923
Avoid using **extra if possible
gvanrossum Jan 11, 2021
37d2a5f
Clean up testTypeGuardOverload -- it still fails, though
gvanrossum Jan 12, 2021
9d7b6c6
Merge remote-tracking branch 'origin/master' into typeguard
gvanrossum Jan 18, 2021
896c90e
Fix lint
gvanrossum Jan 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt,
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr,
Import, ImportFrom, ImportAll, ImportBase, TypeAlias,
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF,
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF, SYMBOL_FUNCBASE_TYPES,
CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr, AssignmentExpr,
is_final_node,
ARG_NAMED)
Expand All @@ -35,7 +35,7 @@
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
is_named_instance, union_items, TypeQuery, LiteralType,
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
get_proper_types, is_literal_type, TypeAliasType)
get_proper_types, is_literal_type, TypeAliasType, TypeGuardType)
from mypy.sametypes import is_same_type
from mypy.messages import (
MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq,
Expand Down Expand Up @@ -3957,6 +3957,7 @@ def find_isinstance_check(self, node: Expression
) -> Tuple[TypeMap, TypeMap]:
"""Find any isinstance checks (within a chain of ands). Includes
implicit and explicit checks for None and calls to callable.
Also includes TypeGuard functions.

Return value is a map of variables to their types if the condition
is true and a map of variables to their types if the condition is false.
Expand Down Expand Up @@ -4001,6 +4002,15 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
if literal(expr) == LITERAL_TYPE:
vartype = type_map[expr]
return self.conditional_callable_type_map(expr, vartype)
else:
if (isinstance(node.callee, RefExpr)
and isinstance(node.callee.node, SYMBOL_FUNCBASE_TYPES)
and isinstance(node.callee.node.type, CallableType)
and node.callee.node.type.type_guard is not None):
if len(node.args) < 1: # TODO: Is this an error?
Copy link
Collaborator

@hauntsaninja hauntsaninja Dec 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems fine to me, it's similar to what happens with if isinstance() + you have a TODO in your tests to error at the definition site for functions that don't take args

return {}, {}
if literal(expr) == LITERAL_TYPE:
return {expr: TypeGuardType(node.callee.node.type.type_guard)}, {}
elif isinstance(node, ComparisonExpr):
# Step 1: Obtain the types of each operand and whether or not we can
# narrow their types. (For example, we shouldn't try narrowing the
Expand Down
6 changes: 5 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
make_optional_type,
)
from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneType, TypeVarDef,
Type, AnyType, CallableType, Overloaded, NoneType, TypeGuardType, TypeVarDef,
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
is_named_instance, FunctionLike,
Expand Down Expand Up @@ -4163,6 +4163,10 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type,
"""
if literal(expr) >= LITERAL_TYPE:
restriction = self.chk.binder.get(expr)
# Ignore the error about using get_proper_type().
if isinstance(restriction, TypeGuardType): # type: ignore[misc]
# A type guard forces the new type even if it doesn't overlap the old.
return restriction.type_guard
# If the current node is deferred, some variables may get Any types that they
# otherwise wouldn't have. We don't want to narrow down these since it may
# produce invalid inferred Optional[Any] types, at least.
Expand Down
2 changes: 2 additions & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def visit_callable_type(self, ct: CallableType) -> None:
for arg in ct.bound_args:
if arg:
arg.accept(self)
if ct.type_guard is not None:
ct.type_guard.accept(self)

def visit_overloaded(self, t: Overloaded) -> None:
for ct in t.items():
Expand Down
1 change: 1 addition & 0 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
'check-annotated.test',
'check-parameter-specification.test',
'check-generic-alias.test',
'check-typeguard.test',
]

# Tests that use Python 3.8-only AST features (like expression-scoped ignores):
Expand Down
24 changes: 23 additions & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt
" and at least one annotation", t)
return AnyType(TypeOfAny.from_error)
return self.anal_type(t.args[0])
elif self.anal_type_guard_arg(t, fullname) is not None:
# In most contexts, TypeGuard[...] acts as an alias for bool (ignoring its args)
return self.named_type('builtins.bool')
return None

def get_omitted_any(self, typ: Type, fullname: Optional[str] = None) -> AnyType:
Expand Down Expand Up @@ -524,15 +527,34 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
variables = t.variables
else:
variables = self.bind_function_type_variables(t, t)
special = self.anal_type_guard(t.ret_type)
ret = t.copy_modified(arg_types=self.anal_array(t.arg_types, nested=nested),
ret_type=self.anal_type(t.ret_type, nested=nested),
# If the fallback isn't filled in yet,
# its type will be the falsey FakeInfo
fallback=(t.fallback if t.fallback.type
else self.named_type('builtins.function')),
variables=self.anal_var_defs(variables))
variables=self.anal_var_defs(variables),
type_guard=special,
)
return ret

def anal_type_guard(self, t: Type) -> Optional[Type]:
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
if sym is not None and sym.node is not None:
return self.anal_type_guard_arg(t, sym.node.fullname)
# TODO: What if it's an Instance? Then use t.type.fullname?
return None

def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Optional[Type]:
if fullname in ('typing_extensions.TypeGuard', 'typing.TypeGuard'):
if len(t.args) != 1:
self.fail("TypeGuard must have exactly one type argument", t)
return AnyType(TypeOfAny.from_error)
return self.anal_type(t.args[0])
return None

def visit_overloaded(self, t: Overloaded) -> Type:
# Overloaded types are manually constructed in semanal.py by analyzing the
# AST and combining together the Callable types this visitor converts.
Expand Down
30 changes: 27 additions & 3 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@ def copy_modified(self, *,
self.line, self.column)


class TypeGuardType(Type):
"""Only used by find_instance_check() etc."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add link to the PEP in the docstring.

def __init__(self, type_guard: Type):
super().__init__(line=type_guard.line, column=type_guard.column)
self.type_guard = type_guard

def __repr__(self) -> str:
return "TypeGuard({})".format(self.type_guard)


class ProperType(Type):
"""Not a type alias.

Expand Down Expand Up @@ -1005,6 +1015,7 @@ class CallableType(FunctionLike):
# tools that consume mypy ASTs
'def_extras', # Information about original definition we want to serialize.
# This is used for more detailed error messages.
'type_guard', # T, if -> TypeGuard[T] (ret_type is bool in this case).
)

def __init__(self,
Expand All @@ -1024,6 +1035,7 @@ def __init__(self,
from_type_type: bool = False,
bound_args: Sequence[Optional[Type]] = (),
def_extras: Optional[Dict[str, Any]] = None,
type_guard: Optional[Type] = None,
) -> None:
super().__init__(line, column)
assert len(arg_types) == len(arg_kinds) == len(arg_names)
Expand Down Expand Up @@ -1058,6 +1070,7 @@ def __init__(self,
not definition.is_static else None}
else:
self.def_extras = {}
self.type_guard = type_guard

def copy_modified(self,
arg_types: Bogus[Sequence[Type]] = _dummy,
Expand All @@ -1075,7 +1088,9 @@ def copy_modified(self,
special_sig: Bogus[Optional[str]] = _dummy,
from_type_type: Bogus[bool] = _dummy,
bound_args: Bogus[List[Optional[Type]]] = _dummy,
def_extras: Bogus[Dict[str, Any]] = _dummy) -> 'CallableType':
def_extras: Bogus[Dict[str, Any]] = _dummy,
type_guard: Bogus[Optional[Type]] = _dummy,
) -> 'CallableType':
return CallableType(
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
arg_kinds=arg_kinds if arg_kinds is not _dummy else self.arg_kinds,
Expand All @@ -1094,6 +1109,7 @@ def copy_modified(self,
from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type,
bound_args=bound_args if bound_args is not _dummy else self.bound_args,
def_extras=def_extras if def_extras is not _dummy else dict(self.def_extras),
type_guard=type_guard if type_guard is not _dummy else self.type_guard,
)

def var_arg(self) -> Optional[FormalArgument]:
Expand Down Expand Up @@ -1255,6 +1271,8 @@ def __eq__(self, other: object) -> bool:
def serialize(self) -> JsonDict:
# TODO: As an optimization, leave out everything related to
# generic functions for non-generic functions.
assert (self.type_guard is None
or isinstance(get_proper_type(self.type_guard), Instance)), str(self.type_guard)
return {'.class': 'CallableType',
'arg_types': [t.serialize() for t in self.arg_types],
'arg_kinds': self.arg_kinds,
Expand All @@ -1269,6 +1287,7 @@ def serialize(self) -> JsonDict:
'bound_args': [(None if t is None else t.serialize())
for t in self.bound_args],
'def_extras': dict(self.def_extras),
'type_guard': self.type_guard.serialize() if self.type_guard is not None else None,
}

@classmethod
Expand All @@ -1286,7 +1305,9 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
implicit=data['implicit'],
bound_args=[(None if t is None else deserialize_type(t))
for t in data['bound_args']],
def_extras=data['def_extras']
def_extras=data['def_extras'],
type_guard=(deserialize_type(data['type_guard'])
if data['type_guard'] is not None else None),
)


Expand Down Expand Up @@ -2097,7 +2118,10 @@ def visit_callable_type(self, t: CallableType) -> str:
s = '({})'.format(s)

if not isinstance(get_proper_type(t.ret_type), NoneType):
s += ' -> {}'.format(t.ret_type.accept(self))
if t.type_guard is not None:
s += ' -> TypeGuard[{}]'.format(t.type_guard.accept(self))
else:
s += ' -> {}'.format(t.ret_type.accept(self))

if t.variables:
vs = []
Expand Down
15 changes: 15 additions & 0 deletions test-data/unit/check-serialize.test
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,21 @@ def f(x: int) -> int: pass
tmp/a.py:2: note: Revealed type is 'builtins.str'
tmp/a.py:3: error: Unexpected keyword argument "x" for "f"

[case testSerializeTypeGuardFunction]
import a
[file a.py]
import b
[file a.py.2]
import b
reveal_type(b.guard(''))
reveal_type(b.guard)
[file b.py]
from typing_extensions import TypeGuard
def guard(a: object) -> TypeGuard[str]: pass
[builtins fixtures/tuple.pyi]
[out2]
tmp/a.py:2: note: Revealed type is 'builtins.bool'
tmp/a.py:3: note: Revealed type is 'def (a: builtins.object) -> TypeGuard[builtins.str]'
--
-- Classes
--
Expand Down
112 changes: 112 additions & 0 deletions test-data/unit/check-typeguard.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
[case testTypeGuardBasic]
from typing_extensions import TypeGuard
class Point: pass
def is_point(a: object) -> TypeGuard[Point]: pass
def main(a: object) -> None:
if is_point(a):
reveal_type(a) # N: Revealed type is '__main__.Point'
else:
reveal_type(a) # N: Revealed type is 'builtins.object'
[builtins fixtures/tuple.pyi]

[case testTypeGuardTypeArgsNone]
from typing_extensions import TypeGuard
def foo(a: object) -> TypeGuard: # E: TypeGuard must have exactly one type argument
pass
[builtins fixtures/tuple.pyi]

[case testTypeGuardTypeArgsTooMany]
from typing_extensions import TypeGuard
def foo(a: object) -> TypeGuard[int, int]: # E: TypeGuard must have exactly one type argument
pass
[builtins fixtures/tuple.pyi]

[case testTypeGuardTypeArgType]
from typing_extensions import TypeGuard
def foo(a: object) -> TypeGuard[42]: # E: Invalid type: try using Literal[42] instead?
pass
[builtins fixtures/tuple.pyi]

[case testTypeGuardRepr]
from typing_extensions import TypeGuard
def foo(a: object) -> TypeGuard[int]:
pass
reveal_type(foo) # N: Revealed type is 'def (a: builtins.object) -> TypeGuard[builtins.int]'
[builtins fixtures/tuple.pyi]

[case testTypeGuardCallArgsNone]
from typing_extensions import TypeGuard
class Point: pass
# TODO: error on the 'def' line (insufficient args for type guard)
def is_point() -> TypeGuard[Point]: pass
def main(a: object) -> None:
if is_point():
reveal_type(a) # N: Revealed type is 'builtins.object'
[builtins fixtures/tuple.pyi]

[case testTypeGuardCallArgsMultiple]
from typing_extensions import TypeGuard
class Point: pass
def is_point(a: object, b: object) -> TypeGuard[Point]: pass
def main(a: object, b: object) -> None:
if is_point(a, b):
reveal_type(a) # N: Revealed type is '__main__.Point'
reveal_type(b) # N: Revealed type is 'builtins.object'
[builtins fixtures/tuple.pyi]

[case testTypeGuardIsBool]
from typing_extensions import TypeGuard
def f(a: TypeGuard[int]) -> None: pass
reveal_type(f) # N: Revealed type is 'def (a: builtins.bool)'
a: TypeGuard[int]
reveal_type(a) # N: Revealed type is 'builtins.bool'
class C:
a: TypeGuard[int]
reveal_type(C().a) # N: Revealed type is 'builtins.bool'
[builtins fixtures/tuple.pyi]

[case testTypeGuardWithTypeVar]
from typing import TypeVar, Tuple
from typing_extensions import TypeGuard
T = TypeVar('T')
def is_two_element_tuple(a: Tuple[T, ...]) -> TypeGuard[Tuple[T, T]]: pass
def main(a: Tuple[T, ...]):
if is_two_element_tuple(a):
reveal_type(a) # N: Revealed type is 'Tuple[T`-1, T`-1]'
[builtins fixtures/tuple.pyi]

[case testTypeGuardNonOverlapping]
from typing import List
from typing_extensions import TypeGuard
def is_str_list(a: List[object]) -> TypeGuard[List[str]]: pass
def main(a: List[object]):
if is_str_list(a):
reveal_type(a) # N: Revealed type is 'builtins.list[builtins.str]'
[builtins fixtures/tuple.pyi]

[case testTypeGuardUnionIn]
from typing import Union
from typing_extensions import TypeGuard
def is_foo(a: Union[int, str]) -> TypeGuard[str]: pass
def main(a: Union[str, int]) -> None:
if is_foo(a):
reveal_type(a) # N: Revealed type is 'builtins.str'
[builtins fixtures/tuple.pyi]

[case testTypeGuardUnionOut]
from typing import Union
from typing_extensions import TypeGuard
def is_foo(a: object) -> TypeGuard[Union[int, str]]: pass
def main(a: object) -> None:
if is_foo(a):
reveal_type(a) # N: Revealed type is 'Union[builtins.int, builtins.str]'
[builtins fixtures/tuple.pyi]

[case testTypeGuardNonzeroFloat]
from typing import Union
from typing_extensions import TypeGuard
def is_nonzero(a: object) -> TypeGuard[float]: pass
def main(a: int):
if is_nonzero(a):
reveal_type(a) # N: Revealed type is 'builtins.float'
[builtins fixtures/tuple.pyi]
2 changes: 2 additions & 0 deletions test-data/unit/lib-stub/typing_extensions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Annotated: _SpecialForm = ...
ParamSpec: _SpecialForm
Concatenate: _SpecialForm

TypeGuard: _SpecialForm

# Fallback type for all typed dicts (does not exist at runtime).
class _TypedDict(Mapping[str, object]):
# Needed to make this class non-abstract. It is explicitly declared abstract in
Expand Down