Skip to content
25 changes: 23 additions & 2 deletions mypy/argmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,25 @@ def f(x: int, *args: str) -> None: ...
needs a separate instance since instances have per-call state.
"""

def __init__(self, context: ArgumentInferContext) -> None:
import mypy.checker
import mypy.nodes

def __init__(
self,
context: ArgumentInferContext,
context_check: mypy.nodes.Context,
checker: mypy.checker.TypeChecker | None = None,
) -> None:
# Next tuple *args index to use.
self.tuple_index = 0
# Keyword arguments in TypedDict **kwargs used.
self.kwargs_used: set[str] = set()
# Type context for `*` and `**` arg kinds.
self.context = context
# TypeChecker to check true argument types for iterables
self.checker = checker

self.context_check = context_check

def expand_actual_type(
self,
Expand Down Expand Up @@ -235,7 +247,16 @@ def expand_actual_type(
# ParamSpec is valid in *args but it can't be unpacked.
return actual_type
else:
return AnyType(TypeOfAny.from_error)
if self.checker is not None:
# get the true type of the arguments of the iterable
iterable_item_type = (
self.checker.analyze_iterable_item_type_without_expression(
actual_type, self.context_check
)[1]
)
return iterable_item_type
else:
return AnyType(TypeOfAny.from_error)
elif actual_kind == nodes.ARG_STAR2:
from mypy.subtypes import is_subtype

Expand Down
8 changes: 7 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2026,6 +2026,8 @@ def infer_function_type_arguments(
pass1_args.append(arg)

inferred_args, _ = infer_function_type_arguments(
context,
self.chk,
callee_type,
pass1_args,
arg_kinds,
Expand Down Expand Up @@ -2087,6 +2089,8 @@ def infer_function_type_arguments(
# potentially involving free variables.
# TODO: support the similar inference for return type context.
poly_inferred_args, free_vars = infer_function_type_arguments(
context,
self.chk,
callee_type,
arg_types,
arg_kinds,
Expand Down Expand Up @@ -2167,6 +2171,8 @@ def infer_function_type_arguments_pass2(
arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual)

inferred_args, _ = infer_function_type_arguments(
context,
self.chk,
callee_type,
arg_types,
arg_kinds,
Expand Down Expand Up @@ -2407,7 +2413,7 @@ def check_argument_types(
"""
check_arg = check_arg or self.check_arg
# Keep track of consumed tuple *arg items.
mapper = ArgTypeExpander(self.argument_infer_context())
mapper = ArgTypeExpander(self.argument_infer_context(), context, self.chk)
for i, actuals in enumerate(formal_to_actual):
orig_callee_arg_type = get_proper_type(callee.arg_types[i])

Expand Down
7 changes: 6 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ def __eq__(self, other: object) -> bool:
return (self.type_var, self.op, self.target) == (other.type_var, other.op, other.target)


import mypy


def infer_constraints_for_callable(
context_check: mypy.nodes.Context,
checker: mypy.checker.TypeChecker,
callee: CallableType,
arg_types: Sequence[Type | None],
arg_kinds: list[ArgKind],
Expand All @@ -116,7 +121,7 @@ def infer_constraints_for_callable(
Return a list of constraints.
"""
constraints: list[Constraint] = []
mapper = ArgTypeExpander(context)
mapper = ArgTypeExpander(context, context_check, checker)

param_spec = callee.param_spec()
param_spec_arg_types = []
Expand Down
16 changes: 15 additions & 1 deletion mypy/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ class ArgumentInferContext(NamedTuple):
iterable_type: Instance


import mypy.checker
import mypy.nodes


def infer_function_type_arguments(
context_check: mypy.nodes.Context,
checker: mypy.checker.TypeChecker,
callee_type: CallableType,
arg_types: Sequence[Type | None],
arg_kinds: list[ArgKind],
Expand All @@ -53,8 +59,16 @@ def infer_function_type_arguments(
formal_to_actual: mapping from formal to actual variable indices
"""
# Infer constraints.
# pass Context into this
constraints = infer_constraints_for_callable(
callee_type, arg_types, arg_kinds, arg_names, formal_to_actual, context
context_check,
checker,
callee_type,
arg_types,
arg_kinds,
arg_names,
formal_to_actual,
context,
)

# Solve constraints.
Expand Down
28 changes: 28 additions & 0 deletions test-data/unit/check-unpack_iterable.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
[case checkiter]
# flags: --python-version 3.11
import typing

class Spam:
def __iter__(self, /) -> typing.Iterator[int]:
yield 1

a = Spam()

# list[int] TODO: fix
# reveal_type(list(a))

# list[int]
reveal_type([i for i in a]) # N: Revealed type is "builtins.list[builtins.int]"

# list[int]
reveal_type([*(i for i in a)]) # N: Revealed type is "builtins.list[builtins.int]"

# list[int]
reveal_type([*a.__iter__()]) # N: Revealed type is "builtins.list[builtins.int]"

# list[Any] ???
reveal_type([*a]) # N: Revealed type is "builtins.list[builtins.int]"

b, = a
# int
reveal_type(b) # N: Revealed type is "builtins.int"