1
1
from __future__ import annotations
2
2
3
- from typing import Callable , Sequence
3
+ from typing import Callable , Iterable , Sequence
4
4
5
5
import mypy .subtypes
6
6
from mypy .erasetype import erase_typevars
7
7
from mypy .expandtype import expand_type
8
- from mypy .nodes import Context
8
+ from mypy .nodes import Context , TypeInfo
9
+ from mypy .type_visitor import TypeTranslator
10
+ from mypy .typeops import get_all_type_vars
9
11
from mypy .types import (
10
12
AnyType ,
11
13
CallableType ,
14
+ Instance ,
15
+ Parameters ,
16
+ ParamSpecFlavor ,
12
17
ParamSpecType ,
13
18
PartialType ,
19
+ ProperType ,
14
20
Type ,
21
+ TypeAliasType ,
15
22
TypeVarId ,
16
23
TypeVarLikeType ,
17
24
TypeVarTupleType ,
18
25
TypeVarType ,
19
26
UninhabitedType ,
20
27
UnpackType ,
21
28
get_proper_type ,
29
+ remove_dups ,
22
30
)
23
31
24
32
@@ -93,8 +101,7 @@ def apply_generic_arguments(
93
101
bound or constraints, instead of giving an error.
94
102
"""
95
103
tvars = callable .variables
96
- min_arg_count = sum (not tv .has_default () for tv in tvars )
97
- assert min_arg_count <= len (orig_types ) <= len (tvars )
104
+ assert len (orig_types ) <= len (tvars )
98
105
# Check that inferred type variable values are compatible with allowed
99
106
# values and bounds. Also, promote subtype values to allowed values.
100
107
# Create a map from type variable id to target type.
@@ -148,7 +155,7 @@ def apply_generic_arguments(
148
155
type_is = None
149
156
150
157
# The callable may retain some type vars if only some were applied.
151
- # TODO: move apply_poly() logic from checkexpr.py here when new inference
158
+ # TODO: move apply_poly() logic here when new inference
152
159
# becomes universally used (i.e. in all passes + in unification).
153
160
# With this new logic we can actually *add* some new free variables.
154
161
remaining_tvars : list [TypeVarLikeType ] = []
@@ -170,3 +177,126 @@ def apply_generic_arguments(
170
177
type_guard = type_guard ,
171
178
type_is = type_is ,
172
179
)
180
+
181
+
182
+ def apply_poly (tp : CallableType , poly_tvars : Sequence [TypeVarLikeType ]) -> CallableType | None :
183
+ """Make free type variables generic in the type if possible.
184
+
185
+ This will translate the type `tp` while trying to create valid bindings for
186
+ type variables `poly_tvars` while traversing the type. This follows the same rules
187
+ as we do during semantic analysis phase, examples:
188
+ * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
189
+ * Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
190
+ * List[T] -> None (not possible)
191
+ """
192
+ try :
193
+ return tp .copy_modified (
194
+ arg_types = [t .accept (PolyTranslator (poly_tvars )) for t in tp .arg_types ],
195
+ ret_type = tp .ret_type .accept (PolyTranslator (poly_tvars )),
196
+ variables = [],
197
+ )
198
+ except PolyTranslationError :
199
+ return None
200
+
201
+
202
+ class PolyTranslationError (Exception ):
203
+ pass
204
+
205
+
206
+ class PolyTranslator (TypeTranslator ):
207
+ """Make free type variables generic in the type if possible.
208
+
209
+ See docstring for apply_poly() for details.
210
+ """
211
+
212
+ def __init__ (
213
+ self ,
214
+ poly_tvars : Iterable [TypeVarLikeType ],
215
+ bound_tvars : frozenset [TypeVarLikeType ] = frozenset (),
216
+ seen_aliases : frozenset [TypeInfo ] = frozenset (),
217
+ ) -> None :
218
+ self .poly_tvars = set (poly_tvars )
219
+ # This is a simplified version of TypeVarScope used during semantic analysis.
220
+ self .bound_tvars = bound_tvars
221
+ self .seen_aliases = seen_aliases
222
+
223
+ def collect_vars (self , t : CallableType | Parameters ) -> list [TypeVarLikeType ]:
224
+ found_vars = []
225
+ for arg in t .arg_types :
226
+ for tv in get_all_type_vars (arg ):
227
+ if isinstance (tv , ParamSpecType ):
228
+ normalized : TypeVarLikeType = tv .copy_modified (
229
+ flavor = ParamSpecFlavor .BARE , prefix = Parameters ([], [], [])
230
+ )
231
+ else :
232
+ normalized = tv
233
+ if normalized in self .poly_tvars and normalized not in self .bound_tvars :
234
+ found_vars .append (normalized )
235
+ return remove_dups (found_vars )
236
+
237
+ def visit_callable_type (self , t : CallableType ) -> Type :
238
+ found_vars = self .collect_vars (t )
239
+ self .bound_tvars |= set (found_vars )
240
+ result = super ().visit_callable_type (t )
241
+ self .bound_tvars -= set (found_vars )
242
+
243
+ assert isinstance (result , ProperType ) and isinstance (result , CallableType )
244
+ result .variables = list (result .variables ) + found_vars
245
+ return result
246
+
247
+ def visit_type_var (self , t : TypeVarType ) -> Type :
248
+ if t in self .poly_tvars and t not in self .bound_tvars :
249
+ raise PolyTranslationError ()
250
+ return super ().visit_type_var (t )
251
+
252
+ def visit_param_spec (self , t : ParamSpecType ) -> Type :
253
+ if t in self .poly_tvars and t not in self .bound_tvars :
254
+ raise PolyTranslationError ()
255
+ return super ().visit_param_spec (t )
256
+
257
+ def visit_type_var_tuple (self , t : TypeVarTupleType ) -> Type :
258
+ if t in self .poly_tvars and t not in self .bound_tvars :
259
+ raise PolyTranslationError ()
260
+ return super ().visit_type_var_tuple (t )
261
+
262
+ def visit_type_alias_type (self , t : TypeAliasType ) -> Type :
263
+ if not t .args :
264
+ return t .copy_modified ()
265
+ if not t .is_recursive :
266
+ return get_proper_type (t ).accept (self )
267
+ # We can't handle polymorphic application for recursive generic aliases
268
+ # without risking an infinite recursion, just give up for now.
269
+ raise PolyTranslationError ()
270
+
271
+ def visit_instance (self , t : Instance ) -> Type :
272
+ if t .type .has_param_spec_type :
273
+ # We need this special-casing to preserve the possibility to store a
274
+ # generic function in an instance type. Things like
275
+ # forall T . Foo[[x: T], T]
276
+ # are not really expressible in current type system, but this looks like
277
+ # a useful feature, so let's keep it.
278
+ param_spec_index = next (
279
+ i for (i , tv ) in enumerate (t .type .defn .type_vars ) if isinstance (tv , ParamSpecType )
280
+ )
281
+ p = get_proper_type (t .args [param_spec_index ])
282
+ if isinstance (p , Parameters ):
283
+ found_vars = self .collect_vars (p )
284
+ self .bound_tvars |= set (found_vars )
285
+ new_args = [a .accept (self ) for a in t .args ]
286
+ self .bound_tvars -= set (found_vars )
287
+
288
+ repl = new_args [param_spec_index ]
289
+ assert isinstance (repl , ProperType ) and isinstance (repl , Parameters )
290
+ repl .variables = list (repl .variables ) + list (found_vars )
291
+ return t .copy_modified (args = new_args )
292
+ # There is the same problem with callback protocols as with aliases
293
+ # (callback protocols are essentially more flexible aliases to callables).
294
+ if t .args and t .type .is_protocol and t .type .protocol_members == ["__call__" ]:
295
+ if t .type in self .seen_aliases :
296
+ raise PolyTranslationError ()
297
+ call = mypy .subtypes .find_member ("__call__" , t , t , is_operator = True )
298
+ assert call is not None
299
+ return call .accept (
300
+ PolyTranslator (self .poly_tvars , self .bound_tvars , self .seen_aliases | {t .type })
301
+ )
302
+ return super ().visit_instance (t )
0 commit comments