Skip to content

Commit 3b2264c

Browse files
committed
avoid duplicated diagnostics during function argument inference
1 parent c45569d commit 3b2264c

File tree

6 files changed

+228
-164
lines changed

6 files changed

+228
-164
lines changed

crates/ty_python_semantic/resources/mdtest/call/overloads.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1663,7 +1663,12 @@ def _(arg: tuple[A | B, Any]):
16631663
reveal_type(f(*(arg,))) # revealed: Unknown
16641664
```
16651665

1666-
## Bi-directional Type Inference
1666+
## Bidirectional Type Inference
1667+
1668+
```toml
1669+
[environment]
1670+
python-version = "3.12"
1671+
```
16671672

16681673
Type inference accounts for parameter type annotations across all overloads.
16691674

@@ -1709,3 +1714,15 @@ def f(a: T | dict[str, int], b: int | str) -> int | str:
17091714
x = f({"y": 1}, "a")
17101715
reveal_type(x) # revealed: str
17111716
```
1717+
1718+
```py
1719+
from typing import SupportsRound, overload
1720+
1721+
@overload
1722+
def takes_str_or_float(x: str): ...
1723+
@overload
1724+
def takes_str_or_float(x: float): ...
1725+
def takes_str_or_float(x: float | str): ...
1726+
1727+
takes_str_or_float(round(1.0))
1728+
```

crates/ty_python_semantic/resources/mdtest/call/union.md

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,12 @@ def _(x: Union[Intersection[Any, Not[int]], Intersection[Any, Not[int]]]):
252252
reveal_type(x) # revealed: Any & ~int
253253
```
254254

255-
## Bi-directional Type Inference
255+
## Bidirectional Type Inference
256+
257+
```toml
258+
[environment]
259+
python-version = "3.12"
260+
```
256261

257262
Type inference accounts for parameter type annotations across all signatures in a union.
258263

@@ -272,7 +277,33 @@ def _(flag: bool):
272277
x = f({"x": 1})
273278
reveal_type(x) # revealed: int
274279

275-
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `T` constructor"
276-
# error: [invalid-key] "Invalid key access on TypedDict `T`: Unknown key "y""
280+
# TODO: error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `T`, found `dict[str, int]`"
281+
# we currently consider `TypedDict` instances to be subtypes of `dict`
277282
f({"y": 1})
278283
```
284+
285+
Diagnostics unrelated to the type-context are only reported once:
286+
287+
```py
288+
def f[T](x: T) -> list[T]:
289+
return [x]
290+
291+
def a(x: list[bool], y: list[bool]): ...
292+
def b(x: list[int], y: list[int]): ...
293+
def c(x: list[int], y: list[int]): ...
294+
def _(x: int):
295+
if x == 0:
296+
y = a
297+
elif x == 1:
298+
y = b
299+
else:
300+
y = c
301+
302+
if x == 0:
303+
z = True
304+
305+
y(f(True), [True])
306+
307+
# error: [possibly-unresolved-reference] "Name `z` used when possibly not defined"
308+
y(f(True), [z])
309+
```

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,16 +2048,6 @@ pub(crate) enum MatchingOverloadIndex {
20482048
Multiple(Vec<usize>),
20492049
}
20502050

2051-
impl MatchingOverloadIndex {
2052-
pub(crate) fn count(self) -> usize {
2053-
match self {
2054-
MatchingOverloadIndex::None => 0,
2055-
MatchingOverloadIndex::Single(_) => 1,
2056-
MatchingOverloadIndex::Multiple(items) => items.len(),
2057-
}
2058-
}
2059-
}
2060-
20612051
#[derive(Default, Debug)]
20622052
struct ArgumentForms {
20632053
values: Vec<Option<ParameterForm>>,
@@ -2535,8 +2525,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25352525
Type::TypeVar(_) => {}
25362526

25372527
_ => {
2538-
// Ignore any specialization errors here, because the type context is only used to
2539-
// optionally widen the return type.
2528+
// Ignore any specialization errors here, because the type context is only used as a hint
2529+
// to infer a more assignable return type.
25402530
let _ = builder.infer(return_ty, call_expression_tcx);
25412531
}
25422532
}

crates/ty_python_semantic/src/types/context.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub(crate) struct InferContext<'db, 'ast> {
4040
module: &'ast ParsedModuleRef,
4141
diagnostics: std::cell::RefCell<TypeCheckDiagnostics>,
4242
no_type_check: InNoTypeCheck,
43+
multi_inference: bool,
4344
bomb: DebugDropBomb,
4445
}
4546

@@ -50,6 +51,7 @@ impl<'db, 'ast> InferContext<'db, 'ast> {
5051
scope,
5152
module,
5253
file: scope.file(db),
54+
multi_inference: false,
5355
diagnostics: std::cell::RefCell::new(TypeCheckDiagnostics::default()),
5456
no_type_check: InNoTypeCheck::default(),
5557
bomb: DebugDropBomb::new(
@@ -156,6 +158,18 @@ impl<'db, 'ast> InferContext<'db, 'ast> {
156158
DiagnosticGuardBuilder::new(self, id, severity)
157159
}
158160

161+
/// Returns `true` if the current expression is being inferred for a second
162+
/// (or subsequent) time, with a potentially different bidirectional type
163+
/// context.
164+
pub(super) fn is_in_multi_inference(&self) -> bool {
165+
self.multi_inference
166+
}
167+
168+
/// Set the multi-inference state, returning the previous value.
169+
pub(super) fn set_multi_inference(&mut self, multi_inference: bool) -> bool {
170+
std::mem::replace(&mut self.multi_inference, multi_inference)
171+
}
172+
159173
pub(super) fn set_in_no_type_check(&mut self, no_type_check: InNoTypeCheck) {
160174
self.no_type_check = no_type_check;
161175
}
@@ -410,6 +424,11 @@ impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> {
410424
if ctx.is_in_no_type_check() {
411425
return None;
412426
}
427+
// If this lint is being reported as part of multi-inference of a given expression,
428+
// silence it to avoid duplicated diagnostics.
429+
if ctx.is_in_multi_inference() {
430+
return None;
431+
}
413432
let id = DiagnosticId::Lint(lint.name());
414433

415434
let suppressions = suppressions(ctx.db(), ctx.file());

0 commit comments

Comments
 (0)