Skip to content

Commit

Permalink
Fix typevar tuple handling to expect unpack in class def (#13630)
Browse files Browse the repository at this point in the history
Originally this PR was intended to add some test cases from PEP646.
However it became immediately apparent that there was a major bug in the
implementation where we expected the definition to look like:

```
class Foo(Generic[Ts])
```

When it is supposed to be

```
class Foo(Generic[Unpack[Ts]])
```

This fixes that.

Also improve constraints solving involving typevar tuples.
  • Loading branch information
jhance authored Sep 21, 2022
1 parent 6a50192 commit 0a720ed
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 13 deletions.
61 changes: 58 additions & 3 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,9 +583,60 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
if self.direction == SUBTYPE_OF and template.type.has_base(instance.type.fullname):
mapped = map_instance_to_supertype(template, instance.type)
tvars = mapped.type.defn.type_vars

if instance.type.has_type_var_tuple_type:
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
instance_prefix, instance_middle, instance_suffix = split_with_instance(
instance
)

# Add a constraint for the type var tuple, and then
# remove it for the case below.
instance_unpack = extract_unpack(instance_middle)
if instance_unpack is not None:
if isinstance(instance_unpack, TypeVarTupleType):
res.append(
Constraint(
instance_unpack, SUBTYPE_OF, TypeList(list(mapped_middle))
)
)
elif (
isinstance(instance_unpack, Instance)
and instance_unpack.type.fullname == "builtins.tuple"
):
for item in mapped_middle:
res.extend(
infer_constraints(
instance_unpack.args[0], item, self.direction
)
)
elif isinstance(instance_unpack, TupleType):
if len(instance_unpack.items) == len(mapped_middle):
for instance_arg, item in zip(
instance_unpack.items, mapped_middle
):
res.extend(
infer_constraints(instance_arg, item, self.direction)
)

mapped_args = mapped_prefix + mapped_suffix
instance_args = instance_prefix + instance_suffix

assert instance.type.type_var_tuple_prefix is not None
assert instance.type.type_var_tuple_suffix is not None
tvars_prefix, _, tvars_suffix = split_with_prefix_and_suffix(
tuple(tvars),
instance.type.type_var_tuple_prefix,
instance.type.type_var_tuple_suffix,
)
tvars = list(tvars_prefix + tvars_suffix)
else:
mapped_args = mapped.args
instance_args = instance.args

# N.B: We use zip instead of indexing because the lengths might have
# mismatches during daemon reprocessing.
for tvar, mapped_arg, instance_arg in zip(tvars, mapped.args, instance.args):
for tvar, mapped_arg, instance_arg in zip(tvars, mapped_args, instance_args):
# TODO(PEP612): More ParamSpec work (or is Parameters the only thing accepted)
if isinstance(tvar, TypeVarType):
# The constraints for generic type parameters depend on variance.
Expand Down Expand Up @@ -617,8 +668,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
elif isinstance(tvar, TypeVarTupleType):
raise NotImplementedError
else:
# This case should have been handled above.
assert not isinstance(tvar, TypeVarTupleType)

return res
elif self.direction == SUPERTYPE_OF and instance.type.has_base(template.type.fullname):
Expand Down Expand Up @@ -710,6 +762,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
else:
# This case should have been handled above.
assert not isinstance(tvar, TypeVarTupleType)
return res
if (
template.type.is_protocol
Expand Down
1 change: 1 addition & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2858,6 +2858,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None
self.metadata = {}

def add_type_vars(self) -> None:
self.has_type_var_tuple_type = False
if self.defn.type_vars:
for i, vd in enumerate(self.defn.type_vars):
if isinstance(vd, mypy.types.ParamSpecType):
Expand Down
22 changes: 18 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,10 +1684,16 @@ def analyze_class_typevar_declaration(self, base: Type) -> tuple[TypeVarLikeList
):
is_proto = sym.node.fullname != "typing.Generic"
tvars: TypeVarLikeList = []
have_type_var_tuple = False
for arg in unbound.args:
tag = self.track_incomplete_refs()
tvar = self.analyze_unbound_tvar(arg)
if tvar:
if isinstance(tvar[1], TypeVarTupleExpr):
if have_type_var_tuple:
self.fail("Can only use one type var tuple in a class def", base)
continue
have_type_var_tuple = True
tvars.append(tvar)
elif not self.found_incomplete_ref(tag):
self.fail("Free type variable expected in %s[...]" % sym.node.name, base)
Expand All @@ -1706,11 +1712,19 @@ def analyze_unbound_tvar(self, t: Type) -> tuple[str, TypeVarLikeExpr] | None:
# It's bound by our type variable scope
return None
return unbound.name, sym.node
if sym and isinstance(sym.node, TypeVarTupleExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
if sym and sym.fullname == "typing_extensions.Unpack":
inner_t = unbound.args[0]
if not isinstance(inner_t, UnboundType):
return None
return unbound.name, sym.node
inner_unbound = inner_t
inner_sym = self.lookup_qualified(inner_unbound.name, inner_unbound)
if inner_sym and isinstance(inner_sym.node, PlaceholderNode):
self.record_incomplete_ref()
if inner_sym and isinstance(inner_sym.node, TypeVarTupleExpr):
if inner_sym.fullname and not self.tvar_scope.allow_binding(inner_sym.fullname):
# It's bound by our type variable scope
return None
return inner_unbound.name, inner_sym.node
if sym is None or not isinstance(sym.node, TypeVarExpr):
return None
elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
Expand Down
99 changes: 93 additions & 6 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,18 @@ reveal_type(h(args)) # N: Revealed type is "Tuple[builtins.str, builtins.str, b

[case testTypeVarTupleGenericClassDefn]
from typing import Generic, TypeVar, Tuple
from typing_extensions import TypeVarTuple
from typing_extensions import TypeVarTuple, Unpack

T = TypeVar("T")
Ts = TypeVarTuple("Ts")

class Variadic(Generic[Ts]):
class Variadic(Generic[Unpack[Ts]]):
pass

class Mixed1(Generic[T, Ts]):
class Mixed1(Generic[T, Unpack[Ts]]):
pass

class Mixed2(Generic[Ts, T]):
class Mixed2(Generic[Unpack[Ts], T]):
pass

variadic: Variadic[int, str]
Expand All @@ -133,7 +133,7 @@ Ts = TypeVarTuple("Ts")
T = TypeVar("T")
S = TypeVar("S")

class Variadic(Generic[T, Ts, S]):
class Variadic(Generic[T, Unpack[Ts], S]):
pass

def foo(t: Variadic[int, Unpack[Ts], object]) -> Tuple[int, Unpack[Ts]]:
Expand All @@ -152,7 +152,7 @@ Ts = TypeVarTuple("Ts")
T = TypeVar("T")
S = TypeVar("S")

class Variadic(Generic[T, Ts, S]):
class Variadic(Generic[T, Unpack[Ts], S]):
def __init__(self, t: Tuple[Unpack[Ts]]) -> None:
...

Expand All @@ -170,3 +170,90 @@ from typing_extensions import TypeVarTuple
Ts = TypeVarTuple("Ts")
B = Ts # E: Type variable "__main__.Ts" is invalid as target for type alias
[builtins fixtures/tuple.pyi]

[case testPep646ArrayExample]
from typing import Generic, Tuple, TypeVar, Protocol, NewType
from typing_extensions import TypeVarTuple, Unpack

Shape = TypeVarTuple('Shape')

Height = NewType('Height', int)
Width = NewType('Width', int)

T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T")

class SupportsAbs(Protocol[T_co]):
def __abs__(self) -> T_co: pass

def abs(a: SupportsAbs[T]) -> T:
...

class Array(Generic[Unpack[Shape]]):
def __init__(self, shape: Tuple[Unpack[Shape]]):
self._shape: Tuple[Unpack[Shape]] = shape

def get_shape(self) -> Tuple[Unpack[Shape]]:
return self._shape

def __abs__(self) -> Array[Unpack[Shape]]: ...

def __add__(self, other: Array[Unpack[Shape]]) -> Array[Unpack[Shape]]: ...

shape = (Height(480), Width(640))
x: Array[Height, Width] = Array(shape)
reveal_type(abs(x)) # N: Revealed type is "__main__.Array[__main__.Height, __main__.Width]"
reveal_type(x + x) # N: Revealed type is "__main__.Array[__main__.Height, __main__.Width]"

[builtins fixtures/tuple.pyi]
[case testPep646ArrayExampleWithDType]
from typing import Generic, Tuple, TypeVar, Protocol, NewType
from typing_extensions import TypeVarTuple, Unpack

DType = TypeVar("DType")
Shape = TypeVarTuple('Shape')

Height = NewType('Height', int)
Width = NewType('Width', int)

T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T")

class SupportsAbs(Protocol[T_co]):
def __abs__(self) -> T_co: pass

def abs(a: SupportsAbs[T]) -> T:
...

class Array(Generic[DType, Unpack[Shape]]):
def __init__(self, shape: Tuple[Unpack[Shape]]):
self._shape: Tuple[Unpack[Shape]] = shape

def get_shape(self) -> Tuple[Unpack[Shape]]:
return self._shape

def __abs__(self) -> Array[DType, Unpack[Shape]]: ...

def __add__(self, other: Array[DType, Unpack[Shape]]) -> Array[DType, Unpack[Shape]]: ...

shape = (Height(480), Width(640))
x: Array[float, Height, Width] = Array(shape)
reveal_type(abs(x)) # N: Revealed type is "__main__.Array[builtins.float, __main__.Height, __main__.Width]"
reveal_type(x + x) # N: Revealed type is "__main__.Array[builtins.float, __main__.Height, __main__.Width]"

[builtins fixtures/tuple.pyi]

[case testPep646ArrayExampleInfer]
from typing import Generic, Tuple, TypeVar, NewType
from typing_extensions import TypeVarTuple, Unpack

Shape = TypeVarTuple('Shape')

Height = NewType('Height', int)
Width = NewType('Width', int)

class Array(Generic[Unpack[Shape]]):
pass

x: Array[float, Height, Width] = Array()
[builtins fixtures/tuple.pyi]
6 changes: 6 additions & 0 deletions test-data/unit/semanal-errors.test
Original file line number Diff line number Diff line change
Expand Up @@ -1456,9 +1456,11 @@ bad: Tuple[Unpack[int]] # E: builtins.int cannot be unpacked (must be tuple or
[builtins fixtures/tuple.pyi]

[case testTypeVarTuple]
from typing import Generic
from typing_extensions import TypeVarTuple, Unpack

TVariadic = TypeVarTuple('TVariadic')
TVariadic2 = TypeVarTuple('TVariadic2')
TP = TypeVarTuple('?') # E: String argument 1 "?" to TypeVarTuple(...) does not match variable name "TP"
TP2: int = TypeVarTuple('TP2') # E: Cannot declare the type of a TypeVar or similar construct
TP3 = TypeVarTuple() # E: Too few arguments for TypeVarTuple()
Expand All @@ -1467,3 +1469,7 @@ TP5 = TypeVarTuple(t='TP5') # E: TypeVarTuple() expects a string literal as fir

x: TVariadic # E: TypeVarTuple "TVariadic" is unbound
y: Unpack[TVariadic] # E: TypeVarTuple "TVariadic" is unbound


class Variadic(Generic[Unpack[TVariadic], Unpack[TVariadic2]]): # E: Can only use one type var tuple in a class def
pass

0 comments on commit 0a720ed

Please sign in to comment.