Skip to content

Commit

Permalink
Some cleanup in partial plugin (#17423)
Browse files Browse the repository at this point in the history
Fixes #17405

Apart from fixing the crash I fix two obvious bugs I noticed while
making this PR.
  • Loading branch information
ilevkivskyi authored Jun 22, 2024
1 parent cc3492e commit 9012fc9
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 7 deletions.
2 changes: 2 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,8 @@ def apply_function_plugin(
formal_arg_exprs[formal].append(args[actual])
if arg_names:
formal_arg_names[formal].append(arg_names[actual])
else:
formal_arg_names[formal].append(None)
formal_arg_kinds[formal].append(arg_kinds[actual])

if object_type is None:
Expand Down
31 changes: 24 additions & 7 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Type,
TypeOfAny,
UnboundType,
UninhabitedType,
get_proper_type,
)

Expand Down Expand Up @@ -132,6 +131,9 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
if fn_type is None:
return ctx.default_return_type

# We must normalize from the start to have coherent view together with TypeChecker.
fn_type = fn_type.with_unpacked_kwargs().with_normalized_var_args()

defaulted = fn_type.copy_modified(
arg_kinds=[
(
Expand All @@ -146,10 +148,25 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
# Make up a line number if we don't have one
defaulted.set_line(ctx.default_return_type)

actual_args = [a for param in ctx.args[1:] for a in param]
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
actual_types = [a for param in ctx.arg_types[1:] for a in param]
# Flatten actual to formal mapping, since this is what check_call() expects.
actual_args = []
actual_arg_kinds = []
actual_arg_names = []
actual_types = []
seen_args = set()
for i, param in enumerate(ctx.args[1:], start=1):
for j, a in enumerate(param):
if a in seen_args:
# Same actual arg can map to multiple formals, but we need to include
# each one only once.
continue
# Here we rely on the fact that expressions are essentially immutable, so
# they can be compared by identity.
seen_args.add(a)
actual_args.append(a)
actual_arg_kinds.append(ctx.arg_kinds[i][j])
actual_arg_names.append(ctx.arg_names[i][j])
actual_types.append(ctx.arg_types[i][j])

# Create a valid context for various ad-hoc inspections in check_call().
call_expr = CallExpr(
Expand Down Expand Up @@ -188,7 +205,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
for i, actuals in enumerate(formal_to_actual):
if len(bound.arg_types) == len(fn_type.arg_types):
arg_type = bound.arg_types[i]
if isinstance(get_proper_type(arg_type), UninhabitedType):
if not mypy.checker.is_valid_inferred_type(arg_type):
arg_type = fn_type.arg_types[i] # bit of a hack
else:
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
Expand All @@ -210,7 +227,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
partial_names.append(fn_type.arg_names[i])

ret_type = bound.ret_type
if isinstance(get_proper_type(ret_type), UninhabitedType):
if not mypy.checker.is_valid_inferred_type(ret_type):
ret_type = fn_type.ret_type # same kind of hack as above

partially_applied = fn_type.copy_modified(
Expand Down
52 changes: 52 additions & 0 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,55 @@ def foo(cls3: Type[B[T]]):
reveal_type(functools.partial(cls3, 2)()) # N: Revealed type is "__main__.B[T`-1]" \
# E: Argument 1 to "B" has incompatible type "int"; expected "T"
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialTypedDictUnpack]
from typing_extensions import TypedDict, Unpack
from functools import partial

class Data(TypedDict, total=False):
x: int

def f(**kwargs: Unpack[Data]) -> None: ...
def g(**kwargs: Unpack[Data]) -> None:
partial(f, **kwargs)()

class MoreData(TypedDict, total=False):
x: int
y: int

def f_more(**kwargs: Unpack[MoreData]) -> None: ...
def g_more(**kwargs: Unpack[MoreData]) -> None:
partial(f_more, **kwargs)()

class Good(TypedDict, total=False):
y: int
class Bad(TypedDict, total=False):
y: str

def h(**kwargs: Unpack[Data]) -> None:
bad: Bad
partial(f_more, **kwargs)(**bad) # E: Argument "y" to "f_more" has incompatible type "str"; expected "int"
good: Good
partial(f_more, **kwargs)(**good)
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialNestedGeneric]
from functools import partial
from typing import Generic, TypeVar, List

T = TypeVar("T")
def get(n: int, args: List[T]) -> T: ...
first = partial(get, 0)

x: List[str]
reveal_type(first(x)) # N: Revealed type is "builtins.str"
reveal_type(first([1])) # N: Revealed type is "builtins.int"

first_kw = partial(get, n=0)
reveal_type(first_kw(args=[1])) # N: Revealed type is "builtins.int"

# TODO: this is indeed invalid, but the error is incomprehensible.
first_kw([1]) # E: Too many positional arguments for "get" \
# E: Too few arguments for "get" \
# E: Argument 1 to "get" has incompatible type "List[int]"; expected "int"
[builtins fixtures/list.pyi]

0 comments on commit 9012fc9

Please sign in to comment.