Skip to content

Commit

Permalink
Add foundation for TypeVar defaults (PEP 696)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Apr 16, 2023
1 parent 4276308 commit ccc0eef
Show file tree
Hide file tree
Showing 23 changed files with 331 additions and 87 deletions.
1 change: 1 addition & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7093,6 +7093,7 @@ def detach_callable(typ: CallableType) -> CallableType:
id=var.id,
values=var.values,
upper_bound=var.upper_bound,
default=var.default,
variance=var.variance,
)
)
Expand Down
36 changes: 30 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4138,7 +4138,9 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
# Used for list and set expressions, as well as for tuples
# containing star expressions that don't refer to a
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
tv = TypeVarType("T", "T", -1, [], self.object_type())
tv = TypeVarType(
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
)
constructor = CallableType(
[tv],
[nodes.ARG_STAR],
Expand Down Expand Up @@ -4321,8 +4323,12 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
tup.column = value.column
args.append(tup)
# Define type variables (used in constructors below).
kt = TypeVarType("KT", "KT", -1, [], self.object_type())
vt = TypeVarType("VT", "VT", -2, [], self.object_type())
kt = TypeVarType(
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
)
vt = TypeVarType(
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
)
rv = None
# Call dict(*args), unless it's empty and stargs is not.
if args or not stargs:
Expand Down Expand Up @@ -4693,7 +4699,9 @@ def check_generator_or_comprehension(

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
tv = TypeVarType("T", "T", -1, [], self.object_type())
tv = TypeVarType(
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
)
tv_list: list[Type] = [tv]
constructor = CallableType(
tv_list,
Expand All @@ -4713,8 +4721,12 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
ktdef = TypeVarType("KT", "KT", -1, [], self.object_type())
vtdef = TypeVarType("VT", "VT", -2, [], self.object_type())
ktdef = TypeVarType(
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
)
vtdef = TypeVarType(
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
)
constructor = CallableType(
[ktdef, vtdef],
[nodes.ARG_POS, nodes.ARG_POS],
Expand Down Expand Up @@ -5242,6 +5254,18 @@ def visit_callable_type(self, t: CallableType) -> bool:
return False
return super().visit_callable_type(t)

def visit_type_var(self, t: TypeVarType) -> bool:
default = [t.default] if t.has_default() else []
return self.query_types([t.upper_bound, *default] + t.values)

def visit_param_spec(self, t: ParamSpecType) -> bool:
default = [t.default] if t.has_default() else []
return self.query_types([t.upper_bound, *default])

def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
default = [t.default] if t.has_default() else []
return self.query_types([t.upper_bound, *default])


def has_coroutine_decorator(t: Type) -> bool:
"""Whether t came from a function decorated with `@coroutine`."""
Expand Down
9 changes: 7 additions & 2 deletions mypy/copytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,15 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
t.id,
values=t.values,
upper_bound=t.upper_bound,
default=t.default,
variance=t.variance,
)
return self.copy_common(t, dup)

def visit_param_spec(self, t: ParamSpecType) -> ProperType:
dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix)
dup = ParamSpecType(
t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
)
return self.copy_common(t, dup)

def visit_parameters(self, t: Parameters) -> ProperType:
Expand All @@ -94,7 +97,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
return self.copy_common(t, dup)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback)
dup = TypeVarTupleType(
t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
)
return self.copy_common(t, dup)

def visit_unpack_type(self, t: UnpackType) -> ProperType:
Expand Down
10 changes: 1 addition & 9 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
TypedDictType,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
TypeVisitor,
Expand Down Expand Up @@ -135,14 +134,7 @@ def freshen_function_type_vars(callee: F) -> F:
tvs = []
tvmap: dict[TypeVarId, Type] = {}
for v in callee.variables:
if isinstance(v, TypeVarType):
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
elif isinstance(v, TypeVarTupleType):
assert isinstance(v, TypeVarTupleType)
tv = TypeVarTupleType.new_unification_variable(v)
else:
assert isinstance(v, ParamSpecType)
tv = ParamSpecType.new_unification_variable(v)
tv = v.new_unification_variable(v)
tvs.append(tv)
tvmap[v.id] = tv
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)
Expand Down
10 changes: 8 additions & 2 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,21 @@ def visit_class_def(self, c: ClassDef) -> None:
for value in v.values:
value.accept(self.type_fixer)
v.upper_bound.accept(self.type_fixer)
v.default.accept(self.type_fixer)

def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
for value in tv.values:
value.accept(self.type_fixer)
tv.upper_bound.accept(self.type_fixer)
tv.default.accept(self.type_fixer)

def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
p.upper_bound.accept(self.type_fixer)
p.default.accept(self.type_fixer)

def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
tv.upper_bound.accept(self.type_fixer)
tv.default.accept(self.type_fixer)

def visit_var(self, v: Var) -> None:
if self.current_info is not None:
Expand Down Expand Up @@ -303,14 +307,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None:
if tvt.values:
for vt in tvt.values:
vt.accept(self)
if tvt.upper_bound is not None:
tvt.upper_bound.accept(self)
tvt.upper_bound.accept(self)
tvt.default.accept(self)

def visit_param_spec(self, p: ParamSpecType) -> None:
p.upper_bound.accept(self)
p.default.accept(self)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.upper_bound.accept(self)
t.default.accept(self)

def visit_unpack_type(self, u: UnpackType) -> None:
u.type.accept(self)
Expand Down
6 changes: 3 additions & 3 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
return set()

def visit_type_var(self, t: types.TypeVarType) -> set[str]:
return self._visit(t.values) | self._visit(t.upper_bound)
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)

def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
return set()
return self._visit(t.upper_bound) | self._visit(t.default)

def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
return self._visit(t.upper_bound)
return self._visit(t.upper_bound) | self._visit(t.default)

def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
return t.type.accept(self)
Expand Down
23 changes: 19 additions & 4 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2427,26 +2427,33 @@ class TypeVarLikeExpr(SymbolNode, Expression):
Note that they are constructed by the semantic analyzer.
"""

__slots__ = ("_name", "_fullname", "upper_bound", "variance")
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")

_name: str
_fullname: str
# Upper bound: only subtypes of upper_bound are valid as values. By default
# this is 'object', meaning no restriction.
upper_bound: mypy.types.Type
default: mypy.types.Type
# Variance of the type variable. Invariant is the default.
# TypeVar(..., covariant=True) defines a covariant type variable.
# TypeVar(..., contravariant=True) defines a contravariant type
# variable.
variance: int

def __init__(
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
self,
name: str,
fullname: str,
upper_bound: mypy.types.Type,
default: mypy.types.Type,
variance: int = INVARIANT,
) -> None:
super().__init__()
self._name = name
self._fullname = fullname
self.upper_bound = upper_bound
self.default = default
self.variance = variance

@property
Expand Down Expand Up @@ -2484,9 +2491,10 @@ def __init__(
fullname: str,
values: list[mypy.types.Type],
upper_bound: mypy.types.Type,
default: mypy.types.Type,
variance: int = INVARIANT,
) -> None:
super().__init__(name, fullname, upper_bound, variance)
super().__init__(name, fullname, upper_bound, default, variance)
self.values = values

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand All @@ -2499,6 +2507,7 @@ def serialize(self) -> JsonDict:
"fullname": self._fullname,
"values": [t.serialize() for t in self.values],
"upper_bound": self.upper_bound.serialize(),
"default": self.default.serialize(),
"variance": self.variance,
}

Expand All @@ -2510,6 +2519,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr:
data["fullname"],
[mypy.types.deserialize_type(v) for v in data["values"]],
mypy.types.deserialize_type(data["upper_bound"]),
mypy.types.deserialize_type(data["default"]),
data["variance"],
)

Expand All @@ -2528,6 +2538,7 @@ def serialize(self) -> JsonDict:
"name": self._name,
"fullname": self._fullname,
"upper_bound": self.upper_bound.serialize(),
"default": self.default.serialize(),
"variance": self.variance,
}

Expand All @@ -2538,6 +2549,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr:
data["name"],
data["fullname"],
mypy.types.deserialize_type(data["upper_bound"]),
mypy.types.deserialize_type(data["default"]),
data["variance"],
)

Expand All @@ -2557,9 +2569,10 @@ def __init__(
fullname: str,
upper_bound: mypy.types.Type,
tuple_fallback: mypy.types.Instance,
default: mypy.types.Type,
variance: int = INVARIANT,
) -> None:
super().__init__(name, fullname, upper_bound, variance)
super().__init__(name, fullname, upper_bound, default, variance)
self.tuple_fallback = tuple_fallback

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand All @@ -2572,6 +2585,7 @@ def serialize(self) -> JsonDict:
"fullname": self._fullname,
"upper_bound": self.upper_bound.serialize(),
"tuple_fallback": self.tuple_fallback.serialize(),
"default": self.default.serialize(),
"variance": self.variance,
}

Expand All @@ -2583,6 +2597,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
data["fullname"],
mypy.types.deserialize_type(data["upper_bound"]),
mypy.types.Instance.deserialize(data["tuple_fallback"]),
mypy.types.deserialize_type(data["default"]),
data["variance"],
)

Expand Down
13 changes: 11 additions & 2 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,10 +762,19 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
# def __lt__(self: AT, other: AT) -> bool
# This way comparisons with subclasses will work correctly.
tvd = TypeVarType(
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, -1, [], object_type
SELF_TVAR_NAME,
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
-1,
[],
object_type,
AnyType(TypeOfAny.from_omitted_generics),
)
self_tvar_expr = TypeVarExpr(
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, [], object_type
SELF_TVAR_NAME,
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
[],
object_type,
AnyType(TypeOfAny.from_omitted_generics),
)
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

Expand Down
13 changes: 11 additions & 2 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,11 @@ def transform(self) -> bool:
# Type variable for self types in generated methods.
obj_type = self._api.named_type("builtins.object")
self_tvar_expr = TypeVarExpr(
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type
SELF_TVAR_NAME,
info.fullname + "." + SELF_TVAR_NAME,
[],
obj_type,
AnyType(TypeOfAny.from_omitted_generics),
)
info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)

Expand All @@ -268,7 +272,12 @@ def transform(self) -> bool:
# the self type.
obj_type = self._api.named_type("builtins.object")
order_tvar_def = TypeVarType(
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], obj_type
SELF_TVAR_NAME,
info.fullname + "." + SELF_TVAR_NAME,
-1,
[],
obj_type,
AnyType(TypeOfAny.from_omitted_generics),
)
order_return_type = self._api.named_type("builtins.bool")
order_args = [
Expand Down
Loading

0 comments on commit ccc0eef

Please sign in to comment.