Skip to content

Commit c3cc492

Browse files
committed
Add foundation for TypeVar defaults (PEP 696)
1 parent 267d376 commit c3cc492

23 files changed

+338
-87
lines changed

mypy/checker.py

+1
Original file line numberDiff line numberDiff line change
@@ -7084,6 +7084,7 @@ def detach_callable(typ: CallableType) -> CallableType:
70847084
id=var.id,
70857085
values=var.values,
70867086
upper_bound=var.upper_bound,
7087+
default=var.default,
70877088
variance=var.variance,
70887089
)
70897090
)

mypy/checkexpr.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -4137,7 +4137,9 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag:
41374137
# Used for list and set expressions, as well as for tuples
41384138
# containing star expressions that don't refer to a
41394139
# Tuple. (Note: "lst" stands for list-set-tuple. :-)
4140-
tv = TypeVarType("T", "T", -1, [], self.object_type())
4140+
tv = TypeVarType(
4141+
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4142+
)
41414143
constructor = CallableType(
41424144
[tv],
41434145
[nodes.ARG_STAR],
@@ -4320,8 +4322,12 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
43204322
tup.column = value.column
43214323
args.append(tup)
43224324
# Define type variables (used in constructors below).
4323-
kt = TypeVarType("KT", "KT", -1, [], self.object_type())
4324-
vt = TypeVarType("VT", "VT", -2, [], self.object_type())
4325+
kt = TypeVarType(
4326+
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4327+
)
4328+
vt = TypeVarType(
4329+
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4330+
)
43254331
rv = None
43264332
# Call dict(*args), unless it's empty and stargs is not.
43274333
if args or not stargs:
@@ -4684,7 +4690,9 @@ def check_generator_or_comprehension(
46844690

46854691
# Infer the type of the list comprehension by using a synthetic generic
46864692
# callable type.
4687-
tv = TypeVarType("T", "T", -1, [], self.object_type())
4693+
tv = TypeVarType(
4694+
"T", "T", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4695+
)
46884696
tv_list: list[Type] = [tv]
46894697
constructor = CallableType(
46904698
tv_list,
@@ -4704,8 +4712,12 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:
47044712

47054713
# Infer the type of the list comprehension by using a synthetic generic
47064714
# callable type.
4707-
ktdef = TypeVarType("KT", "KT", -1, [], self.object_type())
4708-
vtdef = TypeVarType("VT", "VT", -2, [], self.object_type())
4715+
ktdef = TypeVarType(
4716+
"KT", "KT", -1, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4717+
)
4718+
vtdef = TypeVarType(
4719+
"VT", "VT", -2, [], self.object_type(), AnyType(TypeOfAny.from_omitted_generics)
4720+
)
47094721
constructor = CallableType(
47104722
[ktdef, vtdef],
47114723
[nodes.ARG_POS, nodes.ARG_POS],
@@ -5233,6 +5245,18 @@ def visit_callable_type(self, t: CallableType) -> bool:
52335245
return False
52345246
return super().visit_callable_type(t)
52355247

5248+
def visit_type_var(self, t: TypeVarType) -> bool:
5249+
default = [t.default] if t.has_default() else []
5250+
return self.query_types([t.upper_bound, *default] + t.values)
5251+
5252+
def visit_param_spec(self, t: ParamSpecType) -> bool:
5253+
default = [t.default] if t.has_default() else []
5254+
return self.query_types([t.upper_bound, *default])
5255+
5256+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
5257+
default = [t.default] if t.has_default() else []
5258+
return self.query_types([t.upper_bound, *default])
5259+
52365260

52375261
def has_coroutine_decorator(t: Type) -> bool:
52385262
"""Whether t came from a function decorated with `@coroutine`."""

mypy/copytype.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,15 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
7575
t.id,
7676
values=t.values,
7777
upper_bound=t.upper_bound,
78+
default=t.default,
7879
variance=t.variance,
7980
)
8081
return self.copy_common(t, dup)
8182

8283
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
83-
dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix)
84+
dup = ParamSpecType(
85+
t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
86+
)
8487
return self.copy_common(t, dup)
8588

8689
def visit_parameters(self, t: Parameters) -> ProperType:
@@ -94,7 +97,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
9497
return self.copy_common(t, dup)
9598

9699
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
97-
dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback)
100+
dup = TypeVarTupleType(
101+
t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
102+
)
98103
return self.copy_common(t, dup)
99104

100105
def visit_unpack_type(self, t: UnpackType) -> ProperType:

mypy/expandtype.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
TypedDictType,
2828
TypeType,
2929
TypeVarId,
30-
TypeVarLikeType,
3130
TypeVarTupleType,
3231
TypeVarType,
3332
TypeVisitor,
@@ -123,14 +122,7 @@ def freshen_function_type_vars(callee: F) -> F:
123122
tvs = []
124123
tvmap: dict[TypeVarId, Type] = {}
125124
for v in callee.variables:
126-
if isinstance(v, TypeVarType):
127-
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
128-
elif isinstance(v, TypeVarTupleType):
129-
assert isinstance(v, TypeVarTupleType)
130-
tv = TypeVarTupleType.new_unification_variable(v)
131-
else:
132-
assert isinstance(v, ParamSpecType)
133-
tv = ParamSpecType.new_unification_variable(v)
125+
tv = v.new_unification_variable(v)
134126
tvs.append(tv)
135127
tvmap[v.id] = tv
136128
fresh = cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvs)

mypy/fixup.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -171,17 +171,27 @@ def visit_class_def(self, c: ClassDef) -> None:
171171
for value in v.values:
172172
value.accept(self.type_fixer)
173173
v.upper_bound.accept(self.type_fixer)
174+
v.default.accept(self.type_fixer)
175+
if isinstance(v, ParamSpecType):
176+
v.upper_bound.accept(self.type_fixer)
177+
v.default.accept(self.type_fixer)
178+
if isinstance(v, TypeVarTupleType):
179+
v.upper_bound.accept(self.type_fixer)
180+
v.default.accept(self.type_fixer)
174181

175182
def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
176183
for value in tv.values:
177184
value.accept(self.type_fixer)
178185
tv.upper_bound.accept(self.type_fixer)
186+
tv.default.accept(self.type_fixer)
179187

180188
def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
181189
p.upper_bound.accept(self.type_fixer)
190+
p.default.accept(self.type_fixer)
182191

183192
def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
184193
tv.upper_bound.accept(self.type_fixer)
194+
tv.default.accept(self.type_fixer)
185195

186196
def visit_var(self, v: Var) -> None:
187197
if self.current_info is not None:
@@ -303,14 +313,16 @@ def visit_type_var(self, tvt: TypeVarType) -> None:
303313
if tvt.values:
304314
for vt in tvt.values:
305315
vt.accept(self)
306-
if tvt.upper_bound is not None:
307-
tvt.upper_bound.accept(self)
316+
tvt.upper_bound.accept(self)
317+
tvt.default.accept(self)
308318

309319
def visit_param_spec(self, p: ParamSpecType) -> None:
310320
p.upper_bound.accept(self)
321+
p.default.accept(self)
311322

312323
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
313324
t.upper_bound.accept(self)
325+
t.default.accept(self)
314326

315327
def visit_unpack_type(self, u: UnpackType) -> None:
316328
u.type.accept(self)

mypy/indirection.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,13 @@ def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
6464
return set()
6565

6666
def visit_type_var(self, t: types.TypeVarType) -> set[str]:
67-
return self._visit(t.values) | self._visit(t.upper_bound)
67+
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)
6868

6969
def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
70-
return set()
70+
return self._visit(t.upper_bound) | self._visit(t.default)
7171

7272
def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
73-
return self._visit(t.upper_bound)
73+
return self._visit(t.upper_bound) | self._visit(t.default)
7474

7575
def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
7676
return t.type.accept(self)

mypy/nodes.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -2422,26 +2422,33 @@ class TypeVarLikeExpr(SymbolNode, Expression):
24222422
Note that they are constructed by the semantic analyzer.
24232423
"""
24242424

2425-
__slots__ = ("_name", "_fullname", "upper_bound", "variance")
2425+
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")
24262426

24272427
_name: str
24282428
_fullname: str
24292429
# Upper bound: only subtypes of upper_bound are valid as values. By default
24302430
# this is 'object', meaning no restriction.
24312431
upper_bound: mypy.types.Type
2432+
default: mypy.types.Type
24322433
# Variance of the type variable. Invariant is the default.
24332434
# TypeVar(..., covariant=True) defines a covariant type variable.
24342435
# TypeVar(..., contravariant=True) defines a contravariant type
24352436
# variable.
24362437
variance: int
24372438

24382439
def __init__(
2439-
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
2440+
self,
2441+
name: str,
2442+
fullname: str,
2443+
upper_bound: mypy.types.Type,
2444+
default: mypy.types.Type,
2445+
variance: int = INVARIANT,
24402446
) -> None:
24412447
super().__init__()
24422448
self._name = name
24432449
self._fullname = fullname
24442450
self.upper_bound = upper_bound
2451+
self.default = default
24452452
self.variance = variance
24462453

24472454
@property
@@ -2479,9 +2486,10 @@ def __init__(
24792486
fullname: str,
24802487
values: list[mypy.types.Type],
24812488
upper_bound: mypy.types.Type,
2489+
default: mypy.types.Type,
24822490
variance: int = INVARIANT,
24832491
) -> None:
2484-
super().__init__(name, fullname, upper_bound, variance)
2492+
super().__init__(name, fullname, upper_bound, default, variance)
24852493
self.values = values
24862494

24872495
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2494,6 +2502,7 @@ def serialize(self) -> JsonDict:
24942502
"fullname": self._fullname,
24952503
"values": [t.serialize() for t in self.values],
24962504
"upper_bound": self.upper_bound.serialize(),
2505+
"default": self.default.serialize(),
24972506
"variance": self.variance,
24982507
}
24992508

@@ -2505,6 +2514,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr:
25052514
data["fullname"],
25062515
[mypy.types.deserialize_type(v) for v in data["values"]],
25072516
mypy.types.deserialize_type(data["upper_bound"]),
2517+
mypy.types.deserialize_type(data["default"]),
25082518
data["variance"],
25092519
)
25102520

@@ -2523,6 +2533,7 @@ def serialize(self) -> JsonDict:
25232533
"name": self._name,
25242534
"fullname": self._fullname,
25252535
"upper_bound": self.upper_bound.serialize(),
2536+
"default": self.default.serialize(),
25262537
"variance": self.variance,
25272538
}
25282539

@@ -2533,6 +2544,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr:
25332544
data["name"],
25342545
data["fullname"],
25352546
mypy.types.deserialize_type(data["upper_bound"]),
2547+
mypy.types.deserialize_type(data["default"]),
25362548
data["variance"],
25372549
)
25382550

@@ -2552,9 +2564,10 @@ def __init__(
25522564
fullname: str,
25532565
upper_bound: mypy.types.Type,
25542566
tuple_fallback: mypy.types.Instance,
2567+
default: mypy.types.Type,
25552568
variance: int = INVARIANT,
25562569
) -> None:
2557-
super().__init__(name, fullname, upper_bound, variance)
2570+
super().__init__(name, fullname, upper_bound, default, variance)
25582571
self.tuple_fallback = tuple_fallback
25592572

25602573
def accept(self, visitor: ExpressionVisitor[T]) -> T:
@@ -2567,6 +2580,7 @@ def serialize(self) -> JsonDict:
25672580
"fullname": self._fullname,
25682581
"upper_bound": self.upper_bound.serialize(),
25692582
"tuple_fallback": self.tuple_fallback.serialize(),
2583+
"default": self.default.serialize(),
25702584
"variance": self.variance,
25712585
}
25722586

@@ -2578,6 +2592,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
25782592
data["fullname"],
25792593
mypy.types.deserialize_type(data["upper_bound"]),
25802594
mypy.types.Instance.deserialize(data["tuple_fallback"]),
2595+
mypy.types.deserialize_type(data["default"]),
25812596
data["variance"],
25822597
)
25832598

mypy/plugins/attrs.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -726,10 +726,19 @@ def _add_order(ctx: mypy.plugin.ClassDefContext, adder: MethodAdder) -> None:
726726
# def __lt__(self: AT, other: AT) -> bool
727727
# This way comparisons with subclasses will work correctly.
728728
tvd = TypeVarType(
729-
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, -1, [], object_type
729+
SELF_TVAR_NAME,
730+
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
731+
-1,
732+
[],
733+
object_type,
734+
AnyType(TypeOfAny.from_omitted_generics),
730735
)
731736
self_tvar_expr = TypeVarExpr(
732-
SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, [], object_type
737+
SELF_TVAR_NAME,
738+
ctx.cls.info.fullname + "." + SELF_TVAR_NAME,
739+
[],
740+
object_type,
741+
AnyType(TypeOfAny.from_omitted_generics),
733742
)
734743
ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
735744

mypy/plugins/dataclasses.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,11 @@ def transform(self) -> bool:
250250
# Type variable for self types in generated methods.
251251
obj_type = self._api.named_type("builtins.object")
252252
self_tvar_expr = TypeVarExpr(
253-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type
253+
SELF_TVAR_NAME,
254+
info.fullname + "." + SELF_TVAR_NAME,
255+
[],
256+
obj_type,
257+
AnyType(TypeOfAny.from_omitted_generics),
254258
)
255259
info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr)
256260

@@ -264,7 +268,12 @@ def transform(self) -> bool:
264268
# the self type.
265269
obj_type = self._api.named_type("builtins.object")
266270
order_tvar_def = TypeVarType(
267-
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], obj_type
271+
SELF_TVAR_NAME,
272+
info.fullname + "." + SELF_TVAR_NAME,
273+
-1,
274+
[],
275+
obj_type,
276+
AnyType(TypeOfAny.from_omitted_generics),
268277
)
269278
order_return_type = self._api.named_type("builtins.bool")
270279
order_args = [

0 commit comments

Comments
 (0)