Skip to content

Commit cc1c679

Browse files
ilevkivskyiJukkaL
authored andcommitted
Better handling of generic functions in partial plugin (#17925)
Fixes #17411 The fix is that we remove type variables that can never be inferred from the initial `check_call()` call. Actual diff is tiny, I just moved a bunch of code, since I need formal to actual mapping sooner now.
1 parent d65a013 commit cc1c679

File tree

2 files changed

+68
-25
lines changed

2 files changed

+68
-25
lines changed

mypy/plugins/functools.py

+39-23
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
from mypy.argmap import map_actuals_to_formals
1111
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
1212
from mypy.plugins.common import add_method_to_class
13+
from mypy.typeops import get_all_type_vars
1314
from mypy.types import (
1415
AnyType,
1516
CallableType,
1617
Instance,
1718
Overloaded,
1819
Type,
1920
TypeOfAny,
21+
TypeVarType,
2022
UnboundType,
2123
UnionType,
2224
get_proper_type,
@@ -164,21 +166,6 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
164166
ctx.api.type_context[-1] = None
165167
wrapped_return = False
166168

167-
defaulted = fn_type.copy_modified(
168-
arg_kinds=[
169-
(
170-
ArgKind.ARG_OPT
171-
if k == ArgKind.ARG_POS
172-
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
173-
)
174-
for k in fn_type.arg_kinds
175-
],
176-
ret_type=ret_type,
177-
)
178-
if defaulted.line < 0:
179-
# Make up a line number if we don't have one
180-
defaulted.set_line(ctx.default_return_type)
181-
182169
# Flatten actual to formal mapping, since this is what check_call() expects.
183170
actual_args = []
184171
actual_arg_kinds = []
@@ -199,6 +186,43 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
199186
actual_arg_names.append(ctx.arg_names[i][j])
200187
actual_types.append(ctx.arg_types[i][j])
201188

189+
formal_to_actual = map_actuals_to_formals(
190+
actual_kinds=actual_arg_kinds,
191+
actual_names=actual_arg_names,
192+
formal_kinds=fn_type.arg_kinds,
193+
formal_names=fn_type.arg_names,
194+
actual_arg_type=lambda i: actual_types[i],
195+
)
196+
197+
# We need to remove any type variables that appear only in formals that have
198+
# no actuals, to avoid eagerly binding them in check_call() below.
199+
can_infer_ids = set()
200+
for i, arg_type in enumerate(fn_type.arg_types):
201+
if not formal_to_actual[i]:
202+
continue
203+
can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})
204+
205+
defaulted = fn_type.copy_modified(
206+
arg_kinds=[
207+
(
208+
ArgKind.ARG_OPT
209+
if k == ArgKind.ARG_POS
210+
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
211+
)
212+
for k in fn_type.arg_kinds
213+
],
214+
ret_type=ret_type,
215+
variables=[
216+
tv
217+
for tv in fn_type.variables
218+
# Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
219+
if tv.id in can_infer_ids or not isinstance(tv, TypeVarType)
220+
],
221+
)
222+
if defaulted.line < 0:
223+
# Make up a line number if we don't have one
224+
defaulted.set_line(ctx.default_return_type)
225+
202226
# Create a valid context for various ad-hoc inspections in check_call().
203227
call_expr = CallExpr(
204228
callee=ctx.args[0][0],
@@ -231,14 +255,6 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
231255
return ctx.default_return_type
232256
bound = bound.copy_modified(ret_type=ret_type.args[0])
233257

234-
formal_to_actual = map_actuals_to_formals(
235-
actual_kinds=actual_arg_kinds,
236-
actual_names=actual_arg_names,
237-
formal_kinds=fn_type.arg_kinds,
238-
formal_names=fn_type.arg_names,
239-
actual_arg_type=lambda i: actual_types[i],
240-
)
241-
242258
partial_kinds = []
243259
partial_types = []
244260
partial_names = []

test-data/unit/check-functools.test

+29-2
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,6 @@ def bar(f: S) -> S:
575575
return f
576576
[builtins fixtures/primitives.pyi]
577577

578-
579578
[case testFunctoolsPartialAbstractType]
580579
# flags: --python-version 3.9
581580
from abc import ABC, abstractmethod
@@ -597,7 +596,6 @@ def f2() -> None:
597596
partial_cls() # E: Cannot instantiate abstract class "A" with abstract attribute "method"
598597
[builtins fixtures/tuple.pyi]
599598

600-
601599
[case testFunctoolsPartialSelfType]
602600
from functools import partial
603601
from typing_extensions import Self
@@ -610,3 +608,32 @@ class A:
610608
factory = partial(cls, ts=0)
611609
return factory(msg=msg)
612610
[builtins fixtures/tuple.pyi]
611+
612+
[case testFunctoolsPartialTypeVarValues]
613+
from functools import partial
614+
from typing import TypeVar
615+
616+
T = TypeVar("T", int, str)
617+
618+
def f(x: int, y: T) -> T:
619+
return y
620+
621+
def g(x: T, y: int) -> T:
622+
return x
623+
624+
def h(x: T, y: T) -> T:
625+
return x
626+
627+
fp = partial(f, 1)
628+
reveal_type(fp(1)) # N: Revealed type is "builtins.int"
629+
reveal_type(fp("a")) # N: Revealed type is "builtins.str"
630+
fp(object()) # E: Value of type variable "T" of "f" cannot be "object"
631+
632+
gp = partial(g, 1)
633+
reveal_type(gp(1)) # N: Revealed type is "builtins.int"
634+
gp("a") # E: Argument 1 to "g" has incompatible type "str"; expected "int"
635+
636+
hp = partial(h, 1)
637+
reveal_type(hp(1)) # N: Revealed type is "builtins.int"
638+
hp("a") # E: Argument 1 to "h" has incompatible type "str"; expected "int"
639+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)