10
10
from mypy .argmap import map_actuals_to_formals
11
11
from mypy .nodes import ARG_POS , ARG_STAR2 , ArgKind , Argument , CallExpr , FuncItem , Var
12
12
from mypy .plugins .common import add_method_to_class
13
+ from mypy .typeops import get_all_type_vars
13
14
from mypy .types import (
14
15
AnyType ,
15
16
CallableType ,
16
17
Instance ,
17
18
Overloaded ,
18
19
Type ,
19
20
TypeOfAny ,
21
+ TypeVarType ,
20
22
UnboundType ,
21
23
UnionType ,
22
24
get_proper_type ,
@@ -164,21 +166,6 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
164
166
ctx .api .type_context [- 1 ] = None
165
167
wrapped_return = False
166
168
167
- defaulted = fn_type .copy_modified (
168
- arg_kinds = [
169
- (
170
- ArgKind .ARG_OPT
171
- if k == ArgKind .ARG_POS
172
- else (ArgKind .ARG_NAMED_OPT if k == ArgKind .ARG_NAMED else k )
173
- )
174
- for k in fn_type .arg_kinds
175
- ],
176
- ret_type = ret_type ,
177
- )
178
- if defaulted .line < 0 :
179
- # Make up a line number if we don't have one
180
- defaulted .set_line (ctx .default_return_type )
181
-
182
169
# Flatten actual to formal mapping, since this is what check_call() expects.
183
170
actual_args = []
184
171
actual_arg_kinds = []
@@ -199,6 +186,43 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
199
186
actual_arg_names .append (ctx .arg_names [i ][j ])
200
187
actual_types .append (ctx .arg_types [i ][j ])
201
188
189
+ formal_to_actual = map_actuals_to_formals (
190
+ actual_kinds = actual_arg_kinds ,
191
+ actual_names = actual_arg_names ,
192
+ formal_kinds = fn_type .arg_kinds ,
193
+ formal_names = fn_type .arg_names ,
194
+ actual_arg_type = lambda i : actual_types [i ],
195
+ )
196
+
197
+ # We need to remove any type variables that appear only in formals that have
198
+ # no actuals, to avoid eagerly binding them in check_call() below.
199
+ can_infer_ids = set ()
200
+ for i , arg_type in enumerate (fn_type .arg_types ):
201
+ if not formal_to_actual [i ]:
202
+ continue
203
+ can_infer_ids .update ({tv .id for tv in get_all_type_vars (arg_type )})
204
+
205
+ defaulted = fn_type .copy_modified (
206
+ arg_kinds = [
207
+ (
208
+ ArgKind .ARG_OPT
209
+ if k == ArgKind .ARG_POS
210
+ else (ArgKind .ARG_NAMED_OPT if k == ArgKind .ARG_NAMED else k )
211
+ )
212
+ for k in fn_type .arg_kinds
213
+ ],
214
+ ret_type = ret_type ,
215
+ variables = [
216
+ tv
217
+ for tv in fn_type .variables
218
+ # Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
219
+ if tv .id in can_infer_ids or not isinstance (tv , TypeVarType )
220
+ ],
221
+ )
222
+ if defaulted .line < 0 :
223
+ # Make up a line number if we don't have one
224
+ defaulted .set_line (ctx .default_return_type )
225
+
202
226
# Create a valid context for various ad-hoc inspections in check_call().
203
227
call_expr = CallExpr (
204
228
callee = ctx .args [0 ][0 ],
@@ -231,14 +255,6 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
231
255
return ctx .default_return_type
232
256
bound = bound .copy_modified (ret_type = ret_type .args [0 ])
233
257
234
- formal_to_actual = map_actuals_to_formals (
235
- actual_kinds = actual_arg_kinds ,
236
- actual_names = actual_arg_names ,
237
- formal_kinds = fn_type .arg_kinds ,
238
- formal_names = fn_type .arg_names ,
239
- actual_arg_type = lambda i : actual_types [i ],
240
- )
241
-
242
258
partial_kinds = []
243
259
partial_types = []
244
260
partial_names = []
0 commit comments