Skip to content

Commit 2ce3aba

Browse files
authored
[ty] Use annotated parameters as type context (#20635)
## Summary Use the type annotation of function parameters as bidirectional type context when inferring the argument expression. For example, the following example now type-checks: ```py class TD(TypedDict): x: int def f(_: TD): ... f({ "x": 1 }) ``` Part of astral-sh/ty#168.
1 parent b83ac5e commit 2ce3aba

File tree

13 files changed

+519
-95
lines changed

13 files changed

+519
-95
lines changed

crates/ruff_benchmark/benches/ty_walltime.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ static COLOUR_SCIENCE: std::sync::LazyLock<Benchmark<'static>> = std::sync::Lazy
117117
max_dep_date: "2025-06-17",
118118
python_version: PythonVersion::PY310,
119119
},
120-
500,
120+
600,
121121
)
122122
});
123123

crates/ty_python_semantic/resources/mdtest/call/overloads.md

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,3 +1662,67 @@ def _(arg: tuple[A | B, Any]):
16621662
reveal_type(f(arg)) # revealed: Unknown
16631663
reveal_type(f(*(arg,))) # revealed: Unknown
16641664
```
1665+
1666+
## Bidirectional Type Inference
1667+
1668+
```toml
1669+
[environment]
1670+
python-version = "3.12"
1671+
```
1672+
1673+
Type inference accounts for parameter type annotations across all overloads.
1674+
1675+
```py
1676+
from typing import TypedDict, overload
1677+
1678+
class T(TypedDict):
1679+
x: int
1680+
1681+
@overload
1682+
def f(a: list[T], b: int) -> int: ...
1683+
@overload
1684+
def f(a: list[dict[str, int]], b: str) -> str: ...
1685+
def f(a: list[dict[str, int]] | list[T], b: int | str) -> int | str:
1686+
return 1
1687+
1688+
def int_or_str() -> int | str:
1689+
return 1
1690+
1691+
x = f([{"x": 1}], int_or_str())
1692+
reveal_type(x) # revealed: int | str
1693+
1694+
# TODO: error: [no-matching-overload] "No overload of function `f` matches arguments"
1695+
# we currently incorrectly consider `list[dict[str, int]]` a subtype of `list[T]`
1696+
f([{"y": 1}], int_or_str())
1697+
```
1698+
1699+
Non-matching overloads do not produce diagnostics:
1700+
1701+
```py
1702+
from typing import TypedDict, overload
1703+
1704+
class T(TypedDict):
1705+
x: int
1706+
1707+
@overload
1708+
def f(a: T, b: int) -> int: ...
1709+
@overload
1710+
def f(a: dict[str, int], b: str) -> str: ...
1711+
def f(a: T | dict[str, int], b: int | str) -> int | str:
1712+
return 1
1713+
1714+
x = f({"y": 1}, "a")
1715+
reveal_type(x) # revealed: str
1716+
```
1717+
1718+
```py
1719+
from typing import SupportsRound, overload
1720+
1721+
@overload
1722+
def takes_str_or_float(x: str): ...
1723+
@overload
1724+
def takes_str_or_float(x: float): ...
1725+
def takes_str_or_float(x: float | str): ...
1726+
1727+
takes_str_or_float(round(1.0))
1728+
```

crates/ty_python_semantic/resources/mdtest/call/union.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,59 @@ from ty_extensions import Intersection, Not
251251
def _(x: Union[Intersection[Any, Not[int]], Intersection[Any, Not[int]]]):
252252
reveal_type(x) # revealed: Any & ~int
253253
```
254+
255+
## Bidirectional Type Inference
256+
257+
```toml
258+
[environment]
259+
python-version = "3.12"
260+
```
261+
262+
Type inference accounts for parameter type annotations across all signatures in a union.
263+
264+
```py
265+
from typing import TypedDict, overload
266+
267+
class T(TypedDict):
268+
x: int
269+
270+
def _(flag: bool):
271+
if flag:
272+
def f(x: T) -> int:
273+
return 1
274+
else:
275+
def f(x: dict[str, int]) -> int:
276+
return 1
277+
x = f({"x": 1})
278+
reveal_type(x) # revealed: int
279+
280+
# TODO: error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `T`, found `dict[str, int]`"
281+
# we currently consider `TypedDict` instances to be subtypes of `dict`
282+
f({"y": 1})
283+
```
284+
285+
Diagnostics unrelated to the type-context are only reported once:
286+
287+
```py
288+
def f[T](x: T) -> list[T]:
289+
return [x]
290+
291+
def a(x: list[bool], y: list[bool]): ...
292+
def b(x: list[int], y: list[int]): ...
293+
def c(x: list[int], y: list[int]): ...
294+
def _(x: int):
295+
if x == 0:
296+
y = a
297+
elif x == 1:
298+
y = b
299+
else:
300+
y = c
301+
302+
if x == 0:
303+
z = True
304+
305+
y(f(True), [True])
306+
307+
# error: [possibly-unresolved-reference] "Name `z` used when possibly not defined"
308+
y(f(True), [z])
309+
```

crates/ty_python_semantic/resources/mdtest/directives/assert_type.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from typing_extensions import assert_type
1010
def _(x: int):
1111
assert_type(x, int) # fine
1212
assert_type(x, str) # error: [type-assertion-failure]
13+
assert_type(assert_type(x, int), int)
1314
```
1415

1516
## Narrowing

crates/ty_python_semantic/resources/mdtest/snapshots/assert_type.md_-_`assert_type`_-_Basic_(c507788da2659ec9).snap

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/directives/assert_type.m
1717
3 | def _(x: int):
1818
4 | assert_type(x, int) # fine
1919
5 | assert_type(x, str) # error: [type-assertion-failure]
20+
6 | assert_type(assert_type(x, int), int)
2021
```
2122

2223
# Diagnostics
@@ -31,6 +32,7 @@ error[type-assertion-failure]: Argument does not have asserted type `str`
3132
| ^^^^^^^^^^^^-^^^^^^
3233
| |
3334
| Inferred type of argument is `int`
35+
6 | assert_type(assert_type(x, int), int)
3436
|
3537
info: `str` and `int` are not equivalent types
3638
info: rule `type-assertion-failure` is enabled by default

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ Person(name="Alice")
152152
# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor"
153153
Person({"name": "Alice"})
154154

155-
# TODO: this should be an error, similar to the above
155+
# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor"
156156
accepts_person({"name": "Alice"})
157157
# TODO: this should be an error, similar to the above
158158
house.owner = {"name": "Alice"}
@@ -171,7 +171,7 @@ Person(name=None, age=30)
171171
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`"
172172
Person({"name": None, "age": 30})
173173

174-
# TODO: this should be an error, similar to the above
174+
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`"
175175
accepts_person({"name": None, "age": 30})
176176
# TODO: this should be an error, similar to the above
177177
house.owner = {"name": None, "age": 30}
@@ -190,7 +190,7 @@ Person(name="Alice", age=30, extra=True)
190190
# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra""
191191
Person({"name": "Alice", "age": 30, "extra": True})
192192

193-
# TODO: this should be an error
193+
# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra""
194194
accepts_person({"name": "Alice", "age": 30, "extra": True})
195195
# TODO: this should be an error
196196
house.owner = {"name": "Alice", "age": 30, "extra": True}

crates/ty_python_semantic/src/types.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4194,20 +4194,26 @@ impl<'db> Type<'db> {
41944194
.into()
41954195
}
41964196

4197-
Some(KnownFunction::AssertType) => Binding::single(
4198-
self,
4199-
Signature::new(
4200-
Parameters::new([
4201-
Parameter::positional_only(Some(Name::new_static("value")))
4202-
.with_annotated_type(Type::any()),
4203-
Parameter::positional_only(Some(Name::new_static("type")))
4204-
.type_form()
4205-
.with_annotated_type(Type::any()),
4206-
]),
4207-
Some(Type::none(db)),
4208-
),
4209-
)
4210-
.into(),
4197+
Some(KnownFunction::AssertType) => {
4198+
let val_ty =
4199+
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Invariant);
4200+
4201+
Binding::single(
4202+
self,
4203+
Signature::new_generic(
4204+
Some(GenericContext::from_typevar_instances(db, [val_ty])),
4205+
Parameters::new([
4206+
Parameter::positional_only(Some(Name::new_static("value")))
4207+
.with_annotated_type(Type::TypeVar(val_ty)),
4208+
Parameter::positional_only(Some(Name::new_static("type")))
4209+
.type_form()
4210+
.with_annotated_type(Type::any()),
4211+
]),
4212+
Some(Type::TypeVar(val_ty)),
4213+
),
4214+
)
4215+
.into()
4216+
}
42114217

42124218
Some(KnownFunction::AssertNever) => {
42134219
Binding::single(

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,11 @@ impl<'db> InnerIntersectionBuilder<'db> {
10771077
// don't need to worry about finding any particular constraint more than once.
10781078
let constraints = constraints.elements(db);
10791079
let mut positive_constraint_count = 0;
1080-
for positive in &self.positive {
1080+
for (i, positive) in self.positive.iter().enumerate() {
1081+
if i == typevar_index {
1082+
continue;
1083+
}
1084+
10811085
// This linear search should be fine as long as we don't encounter typevars with
10821086
// thousands of constraints.
10831087
positive_constraint_count += constraints

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ use crate::types::{
3333
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
3434
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
3535
TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType,
36-
WrapperDescriptorKind, enums, ide_support, todo_type,
36+
WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type,
3737
};
3838
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
39-
use ruff_python_ast::{self as ast, PythonVersion};
39+
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
4040

4141
/// Binding information for a possible union of callables. At a call site, the arguments must be
4242
/// compatible with _all_ of the types in the union for the call to be valid.
@@ -1776,7 +1776,7 @@ impl<'db> CallableBinding<'db> {
17761776
}
17771777

17781778
/// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`].
1779-
fn matching_overload_index(&self) -> MatchingOverloadIndex {
1779+
pub(crate) fn matching_overload_index(&self) -> MatchingOverloadIndex {
17801780
let mut matching_overloads = self.matching_overloads();
17811781
match matching_overloads.next() {
17821782
None => MatchingOverloadIndex::None,
@@ -1794,8 +1794,15 @@ impl<'db> CallableBinding<'db> {
17941794
}
17951795
}
17961796

1797+
/// Returns all overloads for this call binding, including overloads that did not match.
1798+
pub(crate) fn overloads(&self) -> &[Binding<'db>] {
1799+
self.overloads.as_slice()
1800+
}
1801+
17971802
/// Returns an iterator over all the overloads that matched for this call binding.
1798-
pub(crate) fn matching_overloads(&self) -> impl Iterator<Item = (usize, &Binding<'db>)> {
1803+
pub(crate) fn matching_overloads(
1804+
&self,
1805+
) -> impl Iterator<Item = (usize, &Binding<'db>)> + Clone {
17991806
self.overloads
18001807
.iter()
18011808
.enumerate()
@@ -2026,7 +2033,7 @@ enum OverloadCallReturnType<'db> {
20262033
}
20272034

20282035
#[derive(Debug)]
2029-
enum MatchingOverloadIndex {
2036+
pub(crate) enum MatchingOverloadIndex {
20302037
/// No matching overloads found.
20312038
None,
20322039

@@ -2504,9 +2511,17 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25042511
if let Some(return_ty) = self.signature.return_ty
25052512
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
25062513
{
2507-
// Ignore any specialization errors here, because the type context is only used to
2508-
// optionally widen the return type.
2509-
let _ = builder.infer(return_ty, call_expression_tcx);
2514+
match call_expression_tcx {
2515+
// A type variable is not a useful type-context for expression inference, and applying it
2516+
// to the return type can lead to confusing unions in nested generic calls.
2517+
Type::TypeVar(_) => {}
2518+
2519+
_ => {
2520+
// Ignore any specialization errors here, because the type context is only used as a hint
2521+
// to infer a more assignable return type.
2522+
let _ = builder.infer(return_ty, call_expression_tcx);
2523+
}
2524+
}
25102525
}
25112526

25122527
let parameters = self.signature.parameters();
@@ -3289,6 +3304,23 @@ impl<'db> BindingError<'db> {
32893304
return;
32903305
};
32913306

3307+
// Re-infer the argument type of call expressions, ignoring the type context for more
3308+
// precise error messages.
3309+
let provided_ty = match Self::get_argument_node(node, *argument_index) {
3310+
None => *provided_ty,
3311+
3312+
// Ignore starred arguments, as those are difficult to re-infer.
3313+
Some(
3314+
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
3315+
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }),
3316+
) => *provided_ty,
3317+
3318+
Some(
3319+
ast::ArgOrKeyword::Arg(value)
3320+
| ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }),
3321+
) => infer_isolated_expression(context.db(), context.scope(), value),
3322+
};
3323+
32923324
let provided_ty_display = provided_ty.display(context.db());
32933325
let expected_ty_display = expected_ty.display(context.db());
32943326

@@ -3624,22 +3656,29 @@ impl<'db> BindingError<'db> {
36243656
}
36253657
}
36263658

3627-
fn get_node(node: ast::AnyNodeRef, argument_index: Option<usize>) -> ast::AnyNodeRef {
3659+
fn get_node(node: ast::AnyNodeRef<'_>, argument_index: Option<usize>) -> ast::AnyNodeRef<'_> {
36283660
// If we have a Call node and an argument index, report the diagnostic on the correct
36293661
// argument node; otherwise, report it on the entire provided node.
3662+
match Self::get_argument_node(node, argument_index) {
3663+
Some(ast::ArgOrKeyword::Arg(expr)) => expr.into(),
3664+
Some(ast::ArgOrKeyword::Keyword(expr)) => expr.into(),
3665+
None => node,
3666+
}
3667+
}
3668+
3669+
fn get_argument_node(
3670+
node: ast::AnyNodeRef<'_>,
3671+
argument_index: Option<usize>,
3672+
) -> Option<ArgOrKeyword<'_>> {
36303673
match (node, argument_index) {
3631-
(ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => {
3632-
match call_node
3674+
(ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => Some(
3675+
call_node
36333676
.arguments
36343677
.arguments_source_order()
36353678
.nth(argument_index)
3636-
.expect("argument index should not be out of range")
3637-
{
3638-
ast::ArgOrKeyword::Arg(expr) => expr.into(),
3639-
ast::ArgOrKeyword::Keyword(keyword) => keyword.into(),
3640-
}
3641-
}
3642-
_ => node,
3679+
.expect("argument index should not be out of range"),
3680+
),
3681+
_ => None,
36433682
}
36443683
}
36453684
}

0 commit comments

Comments
 (0)