Skip to content

Commit bcb3747

Browse files
Implement TypeIs (PEP 742) (#16898)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
1 parent 3c87af2 commit bcb3747

19 files changed

+962
-19
lines changed

docs/source/error_code_list2.rst

+16
Original file line numberDiff line numberDiff line change
@@ -555,3 +555,19 @@ Correct usage:
555555
556556
When this code is enabled, using ``reveal_locals`` is always an error,
557557
because there's no way one can import it.
558+
559+
.. _code-narrowed-type-not-subtype:
560+
561+
Check that ``TypeIs`` narrows types [narrowed-type-not-subtype]
562+
---------------------------------------------------------------
563+
564+
:pep:`742` requires that when ``TypeIs`` is used, the narrowed
565+
type must be a subtype of the original type::
566+
567+
from typing_extensions import TypeIs
568+
569+
def f(x: int) -> TypeIs[str]: # Error, str is not a subtype of int
570+
...
571+
572+
def g(x: object) -> TypeIs[str]: # OK
573+
...

mypy/applytype.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,15 @@ def apply_generic_arguments(
137137
arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
138138
)
139139

140-
# Apply arguments to TypeGuard if any.
140+
# Apply arguments to TypeGuard and TypeIs if any.
141141
if callable.type_guard is not None:
142142
type_guard = expand_type(callable.type_guard, id_to_type)
143143
else:
144144
type_guard = None
145+
if callable.type_is is not None:
146+
type_is = expand_type(callable.type_is, id_to_type)
147+
else:
148+
type_is = None
145149

146150
# The callable may retain some type vars if only some were applied.
147151
# TODO: move apply_poly() logic from checkexpr.py here when new inference
@@ -164,4 +168,5 @@ def apply_generic_arguments(
164168
ret_type=expand_type(callable.ret_type, id_to_type),
165169
variables=remaining_tvars,
166170
type_guard=type_guard,
171+
type_is=type_is,
167172
)

mypy/checker.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,22 @@ def check_func_def(
12031203
# visible from *inside* of this function/method.
12041204
ref_type: Type | None = self.scope.active_self_type()
12051205

1206+
if typ.type_is:
1207+
arg_index = 0
1208+
# For methods and classmethods, we want the second parameter
1209+
if ref_type is not None and (not defn.is_static or defn.name == "__new__"):
1210+
arg_index = 1
1211+
if arg_index < len(typ.arg_types) and not is_subtype(
1212+
typ.type_is, typ.arg_types[arg_index]
1213+
):
1214+
self.fail(
1215+
message_registry.NARROWED_TYPE_NOT_SUBTYPE.format(
1216+
format_type(typ.type_is, self.options),
1217+
format_type(typ.arg_types[arg_index], self.options),
1218+
),
1219+
item,
1220+
)
1221+
12061222
# Store argument types.
12071223
for i in range(len(typ.arg_types)):
12081224
arg_type = typ.arg_types[i]
@@ -2178,6 +2194,8 @@ def check_override(
21782194
elif isinstance(original, CallableType) and isinstance(override, CallableType):
21792195
if original.type_guard is not None and override.type_guard is None:
21802196
fail = True
2197+
if original.type_is is not None and override.type_is is None:
2198+
fail = True
21812199

21822200
if is_private(name):
21832201
fail = False
@@ -5643,7 +5661,7 @@ def combine_maps(list_maps: list[TypeMap]) -> TypeMap:
56435661
def find_isinstance_check(self, node: Expression) -> tuple[TypeMap, TypeMap]:
56445662
"""Find any isinstance checks (within a chain of ands). Includes
56455663
implicit and explicit checks for None and calls to callable.
5646-
Also includes TypeGuard functions.
5664+
Also includes TypeGuard and TypeIs functions.
56475665
56485666
Return value is a map of variables to their types if the condition
56495667
is true and a map of variables to their types if the condition is false.
@@ -5695,7 +5713,7 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
56955713
if literal(expr) == LITERAL_TYPE and attr and len(attr) == 1:
56965714
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
56975715
elif isinstance(node.callee, RefExpr):
5698-
if node.callee.type_guard is not None:
5716+
if node.callee.type_guard is not None or node.callee.type_is is not None:
56995717
# TODO: Follow *args, **kwargs
57005718
if node.arg_kinds[0] != nodes.ARG_POS:
57015719
# the first argument might be used as a kwarg
@@ -5721,15 +5739,31 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
57215739
# we want the idx-th variable to be narrowed
57225740
expr = collapse_walrus(node.args[idx])
57235741
else:
5724-
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
5742+
kind = (
5743+
"guard" if node.callee.type_guard is not None else "narrower"
5744+
)
5745+
self.fail(
5746+
message_registry.TYPE_GUARD_POS_ARG_REQUIRED.format(kind), node
5747+
)
57255748
return {}, {}
57265749
if literal(expr) == LITERAL_TYPE:
57275750
# Note: we wrap the target type, so that we can special case later.
57285751
# Namely, for isinstance() we use a normal meet, while TypeGuard is
57295752
# considered "always right" (i.e. even if the types are not overlapping).
57305753
# Also note that a care must be taken to unwrap this back at read places
57315754
# where we use this to narrow down declared type.
5732-
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
5755+
if node.callee.type_guard is not None:
5756+
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
5757+
else:
5758+
assert node.callee.type_is is not None
5759+
return conditional_types_to_typemaps(
5760+
expr,
5761+
*self.conditional_types_with_intersection(
5762+
self.lookup_type(expr),
5763+
[TypeRange(node.callee.type_is, is_upper_bound=False)],
5764+
expr,
5765+
),
5766+
)
57335767
elif isinstance(node, ComparisonExpr):
57345768
# Step 1: Obtain the types of each operand and whether or not we can
57355769
# narrow their types. (For example, we shouldn't try narrowing the

mypy/checkexpr.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -1451,13 +1451,12 @@ def check_call_expr_with_callee_type(
14511451
object_type=object_type,
14521452
)
14531453
proper_callee = get_proper_type(callee_type)
1454-
if (
1455-
isinstance(e.callee, RefExpr)
1456-
and isinstance(proper_callee, CallableType)
1457-
and proper_callee.type_guard is not None
1458-
):
1454+
if isinstance(e.callee, RefExpr) and isinstance(proper_callee, CallableType):
14591455
# Cache it for find_isinstance_check()
1460-
e.callee.type_guard = proper_callee.type_guard
1456+
if proper_callee.type_guard is not None:
1457+
e.callee.type_guard = proper_callee.type_guard
1458+
if proper_callee.type_is is not None:
1459+
e.callee.type_is = proper_callee.type_is
14611460
return ret_type
14621461

14631462
def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type:
@@ -5283,7 +5282,7 @@ def infer_lambda_type_using_context(
52835282
# is a constructor -- but this fallback doesn't make sense for lambdas.
52845283
callable_ctx = callable_ctx.copy_modified(fallback=self.named_type("builtins.function"))
52855284

5286-
if callable_ctx.type_guard is not None:
5285+
if callable_ctx.type_guard is not None or callable_ctx.type_is is not None:
52875286
# Lambda's return type cannot be treated as a `TypeGuard`,
52885287
# because it is implicit. And `TypeGuard`s must be explicit.
52895288
# See https://github.com/python/mypy/issues/9927

mypy/constraints.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1018,10 +1018,22 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
10181018
param_spec = template.param_spec()
10191019

10201020
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
1021-
if template.type_guard is not None:
1021+
if template.type_guard is not None and cactual.type_guard is not None:
10221022
template_ret_type = template.type_guard
1023-
if cactual.type_guard is not None:
10241023
cactual_ret_type = cactual.type_guard
1024+
elif template.type_guard is not None:
1025+
template_ret_type = AnyType(TypeOfAny.special_form)
1026+
elif cactual.type_guard is not None:
1027+
cactual_ret_type = AnyType(TypeOfAny.special_form)
1028+
1029+
if template.type_is is not None and cactual.type_is is not None:
1030+
template_ret_type = template.type_is
1031+
cactual_ret_type = cactual.type_is
1032+
elif template.type_is is not None:
1033+
template_ret_type = AnyType(TypeOfAny.special_form)
1034+
elif cactual.type_is is not None:
1035+
cactual_ret_type = AnyType(TypeOfAny.special_form)
1036+
10251037
res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))
10261038

10271039
if param_spec is None:

mypy/errorcodes.py

+6
Original file line numberDiff line numberDiff line change
@@ -281,5 +281,11 @@ def __hash__(self) -> int:
281281
sub_code_of=MISC,
282282
)
283283

284+
NARROWED_TYPE_NOT_SUBTYPE: Final[ErrorCode] = ErrorCode(
285+
"narrowed-type-not-subtype",
286+
"Warn if a TypeIs function's narrowed type is not a subtype of the original type",
287+
"General",
288+
)
289+
284290
# This copy will not include any error codes defined later in the plugins.
285291
mypy_error_codes = error_codes.copy()

mypy/expandtype.py

+2
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
351351
arg_names=t.arg_names[:-2] + repl.arg_names,
352352
ret_type=t.ret_type.accept(self),
353353
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
354+
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
354355
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
355356
variables=[*repl.variables, *t.variables],
356357
)
@@ -384,6 +385,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
384385
arg_types=arg_types,
385386
ret_type=t.ret_type.accept(self),
386387
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
388+
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
387389
)
388390
if needs_normalization:
389391
return expanded.with_normalized_var_args()

mypy/fixup.py

+2
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def visit_callable_type(self, ct: CallableType) -> None:
270270
arg.accept(self)
271271
if ct.type_guard is not None:
272272
ct.type_guard.accept(self)
273+
if ct.type_is is not None:
274+
ct.type_is.accept(self)
273275

274276
def visit_overloaded(self, t: Overloaded) -> None:
275277
for ct in t.items:

mypy/message_registry.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
262262

263263
CONTIGUOUS_ITERABLE_EXPECTED: Final = ErrorMessage("Contiguous iterable with same type expected")
264264
ITERABLE_TYPE_EXPECTED: Final = ErrorMessage("Invalid type '{}' for *expr (iterable expected)")
265-
TYPE_GUARD_POS_ARG_REQUIRED: Final = ErrorMessage("Type guard requires positional argument")
265+
TYPE_GUARD_POS_ARG_REQUIRED: Final = ErrorMessage("Type {} requires positional argument")
266266

267267
# Match Statement
268268
MISSING_MATCH_ARGS: Final = 'Class "{}" doesn\'t define "__match_args__"'
@@ -324,3 +324,6 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
324324
ARG_NAME_EXPECTED_STRING_LITERAL: Final = ErrorMessage(
325325
"Expected string literal for argument name, got {}", codes.SYNTAX
326326
)
327+
NARROWED_TYPE_NOT_SUBTYPE: Final = ErrorMessage(
328+
"Narrowed type {} is not a subtype of input type {}", codes.NARROWED_TYPE_NOT_SUBTYPE
329+
)

mypy/messages.py

+4
Original file line numberDiff line numberDiff line change
@@ -2643,6 +2643,8 @@ def format_literal_value(typ: LiteralType) -> str:
26432643
elif isinstance(func, CallableType):
26442644
if func.type_guard is not None:
26452645
return_type = f"TypeGuard[{format(func.type_guard)}]"
2646+
elif func.type_is is not None:
2647+
return_type = f"TypeIs[{format(func.type_is)}]"
26462648
else:
26472649
return_type = format(func.ret_type)
26482650
if func.is_ellipsis_args:
@@ -2859,6 +2861,8 @@ def [T <: int] f(self, x: int, y: T) -> None
28592861
s += " -> "
28602862
if tp.type_guard is not None:
28612863
s += f"TypeGuard[{format_type_bare(tp.type_guard, options)}]"
2864+
elif tp.type_is is not None:
2865+
s += f"TypeIs[{format_type_bare(tp.type_is, options)}]"
28622866
else:
28632867
s += format_type_bare(tp.ret_type, options)
28642868

mypy/nodes.py

+3
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,7 @@ class RefExpr(Expression):
17551755
"is_inferred_def",
17561756
"is_alias_rvalue",
17571757
"type_guard",
1758+
"type_is",
17581759
)
17591760

17601761
def __init__(self) -> None:
@@ -1776,6 +1777,8 @@ def __init__(self) -> None:
17761777
self.is_alias_rvalue = False
17771778
# Cache type guard from callable_type.type_guard
17781779
self.type_guard: mypy.types.Type | None = None
1780+
# And same for TypeIs
1781+
self.type_is: mypy.types.Type | None = None
17791782

17801783
@property
17811784
def fullname(self) -> str:

mypy/semanal.py

+7
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,13 @@ def analyze_func_def(self, defn: FuncDef) -> None:
881881
)
882882
# in this case, we just kind of just ... remove the type guard.
883883
result = result.copy_modified(type_guard=None)
884+
if result.type_is and ARG_POS not in result.arg_kinds[skip_self:]:
885+
self.fail(
886+
'"TypeIs" functions must have a positional argument',
887+
result,
888+
code=codes.VALID_TYPE,
889+
)
890+
result = result.copy_modified(type_is=None)
884891

885892
result = self.remove_unpack_kwargs(defn, result)
886893
if has_self_type and self.type is not None:

mypy/subtypes.py

+13
Original file line numberDiff line numberDiff line change
@@ -683,10 +683,23 @@ def visit_callable_type(self, left: CallableType) -> bool:
683683
if left.type_guard is not None and right.type_guard is not None:
684684
if not self._is_subtype(left.type_guard, right.type_guard):
685685
return False
686+
elif left.type_is is not None and right.type_is is not None:
687+
# For TypeIs we have to check both ways; it is unsafe to pass
688+
# a TypeIs[Child] when a TypeIs[Parent] is expected, because
689+
# if the narrower returns False, we assume that the narrowed value is
690+
# *not* a Parent.
691+
if not self._is_subtype(left.type_is, right.type_is) or not self._is_subtype(
692+
right.type_is, left.type_is
693+
):
694+
return False
686695
elif right.type_guard is not None and left.type_guard is None:
687696
# This means that one function has `TypeGuard` and other does not.
688697
# They are not compatible. See https://github.com/python/mypy/issues/11307
689698
return False
699+
elif right.type_is is not None and left.type_is is None:
700+
# Similarly, if one function has `TypeIs` and the other does not,
701+
# they are not compatible.
702+
return False
690703
return is_callable_compatible(
691704
left,
692705
right,

mypy/typeanal.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,10 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
668668
)
669669
return AnyType(TypeOfAny.from_error)
670670
return RequiredType(self.anal_type(t.args[0]), required=False)
671-
elif self.anal_type_guard_arg(t, fullname) is not None:
671+
elif (
672+
self.anal_type_guard_arg(t, fullname) is not None
673+
or self.anal_type_is_arg(t, fullname) is not None
674+
):
672675
# In most contexts, TypeGuard[...] acts as an alias for bool (ignoring its args)
673676
return self.named_type("builtins.bool")
674677
elif fullname in ("typing.Unpack", "typing_extensions.Unpack"):
@@ -986,7 +989,8 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
986989
variables = t.variables
987990
else:
988991
variables, _ = self.bind_function_type_variables(t, t)
989-
special = self.anal_type_guard(t.ret_type)
992+
type_guard = self.anal_type_guard(t.ret_type)
993+
type_is = self.anal_type_is(t.ret_type)
990994
arg_kinds = t.arg_kinds
991995
if len(arg_kinds) >= 2 and arg_kinds[-2] == ARG_STAR and arg_kinds[-1] == ARG_STAR2:
992996
arg_types = self.anal_array(t.arg_types[:-2], nested=nested) + [
@@ -1041,7 +1045,8 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
10411045
# its type will be the falsey FakeInfo
10421046
fallback=(t.fallback if t.fallback.type else self.named_type("builtins.function")),
10431047
variables=self.anal_var_defs(variables),
1044-
type_guard=special,
1048+
type_guard=type_guard,
1049+
type_is=type_is,
10451050
unpack_kwargs=unpacked_kwargs,
10461051
)
10471052
return ret
@@ -1064,6 +1069,22 @@ def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Type | None:
10641069
return self.anal_type(t.args[0])
10651070
return None
10661071

1072+
def anal_type_is(self, t: Type) -> Type | None:
1073+
if isinstance(t, UnboundType):
1074+
sym = self.lookup_qualified(t.name, t)
1075+
if sym is not None and sym.node is not None:
1076+
return self.anal_type_is_arg(t, sym.node.fullname)
1077+
# TODO: What if it's an Instance? Then use t.type.fullname?
1078+
return None
1079+
1080+
def anal_type_is_arg(self, t: UnboundType, fullname: str) -> Type | None:
1081+
if fullname in ("typing_extensions.TypeIs", "typing.TypeIs"):
1082+
if len(t.args) != 1:
1083+
self.fail("TypeIs must have exactly one type argument", t, code=codes.VALID_TYPE)
1084+
return AnyType(TypeOfAny.from_error)
1085+
return self.anal_type(t.args[0])
1086+
return None
1087+
10671088
def anal_star_arg_type(self, t: Type, kind: ArgKind, nested: bool) -> Type:
10681089
"""Analyze signature argument type for *args and **kwargs argument."""
10691090
if isinstance(t, UnboundType) and t.name and "." in t.name and not t.args:

0 commit comments

Comments
 (0)