Skip to content

Commit e8b8f6d

Browse files
committed
avoid type errors on first non-matching overload
1 parent 0d8d064 commit e8b8f6d

File tree

6 files changed

+171
-191
lines changed

6 files changed

+171
-191
lines changed

crates/ty_python_semantic/resources/mdtest/call/overloads.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,11 @@ def _(arg: tuple[A | B, Any]):
16651665

16661666
## Bidirectional Type Inference
16671667

1668+
```toml
1669+
[environment]
1670+
python-version = "3.12"
1671+
```
1672+
16681673
Type inference accounts for parameter type annotations across all overloads.
16691674

16701675
```py
@@ -1709,3 +1714,15 @@ def f(a: T | dict[str, int], b: int | str) -> int | str:
17091714
x = f({"y": 1}, "a")
17101715
reveal_type(x) # revealed: str
17111716
```
1717+
1718+
```py
1719+
from typing import SupportsRound, overload
1720+
1721+
@overload
1722+
def takes_str_or_float(x: str): ...
1723+
@overload
1724+
def takes_str_or_float(x: float): ...
1725+
def takes_str_or_float(x: float | str): ...
1726+
1727+
takes_str_or_float(round(1.0))
1728+
```

crates/ty_python_semantic/resources/mdtest/call/union.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ def _(flag: bool):
277277
x = f({"x": 1})
278278
reveal_type(x) # revealed: int
279279

280-
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `T` constructor"
281-
# error: [invalid-key] "Invalid key access on TypedDict `T`: Unknown key "y""
280+
# TODO: error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `T`, found `dict[str, int]`"
281+
# we currently consider `TypedDict` instances to be subtypes of `dict`
282282
f({"y": 1})
283283
```
284284

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,16 +2048,6 @@ pub(crate) enum MatchingOverloadIndex {
20482048
Multiple(Vec<usize>),
20492049
}
20502050

2051-
impl MatchingOverloadIndex {
2052-
pub(crate) fn count(self) -> usize {
2053-
match self {
2054-
MatchingOverloadIndex::None => 0,
2055-
MatchingOverloadIndex::Single(_) => 1,
2056-
MatchingOverloadIndex::Multiple(items) => items.len(),
2057-
}
2058-
}
2059-
}
2060-
20612051
#[derive(Default, Debug)]
20622052
struct ArgumentForms {
20632053
values: Vec<Option<ParameterForm>>,
@@ -2535,8 +2525,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25352525
Type::TypeVar(_) => {}
25362526

25372527
_ => {
2538-
// Ignore any specialization errors here, because the type context is only used to
2539-
// optionally widen the return type.
2528+
// Ignore any specialization errors here, because the type context is only used as a hint
2529+
// to infer a more assignable return type.
25402530
let _ = builder.infer(return_ty, call_expression_tcx);
25412531
}
25422532
}

crates/ty_python_semantic/src/types/context.rs

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -133,21 +133,7 @@ impl<'db, 'ast> InferContext<'db, 'ast> {
133133
lint: &'static LintMetadata,
134134
ranged: T,
135135
) -> Option<LintDiagnosticGuardBuilder<'ctx, 'db>> {
136-
LintDiagnosticGuardBuilder::new(self, lint, ranged.range(), false)
137-
}
138-
139-
/// Similar to `report_lint`, except forces the diagnostic to be
140-
/// reported even when inferring an expression multiple times.
141-
///
142-
/// This should be used for diagnostics that are affected by bidirectional
143-
/// type context, which may change across multiple inferences of a function
144-
/// argument expression.
145-
pub(super) fn report_bidirectional_lint<'ctx, T: Ranged>(
146-
&'ctx self,
147-
lint: &'static LintMetadata,
148-
ranged: T,
149-
) -> Option<LintDiagnosticGuardBuilder<'ctx, 'db>> {
150-
LintDiagnosticGuardBuilder::new(self, lint, ranged.range(), true)
136+
LintDiagnosticGuardBuilder::new(self, lint, ranged.range())
151137
}
152138

153139
/// Optionally return a builder for a diagnostic guard.
@@ -179,8 +165,9 @@ impl<'db, 'ast> InferContext<'db, 'ast> {
179165
self.multi_inference
180166
}
181167

182-
pub(super) fn set_multi_inference(&mut self, multi_inference: bool) {
183-
self.multi_inference = multi_inference;
168+
/// Set the multi-inference state, returning the previous value.
169+
pub(super) fn set_multi_inference(&mut self, multi_inference: bool) -> bool {
170+
std::mem::replace(&mut self.multi_inference, multi_inference)
184171
}
185172

186173
pub(super) fn set_in_no_type_check(&mut self, no_type_check: InNoTypeCheck) {
@@ -413,7 +400,6 @@ impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> {
413400
ctx: &'ctx InferContext<'db, 'ctx>,
414401
lint: &'static LintMetadata,
415402
range: TextRange,
416-
bidirectional: bool,
417403
) -> Option<LintDiagnosticGuardBuilder<'db, 'ctx>> {
418404
// The comment below was copied from the original
419405
// implementation of diagnostic reporting. The code
@@ -439,9 +425,8 @@ impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> {
439425
return None;
440426
}
441427
// If this lint is being reported as part of multi-inference of a given expression,
442-
// silence it to avoid duplicated diagnostics, unless it may have been affected by
443-
// the bidirectional type context.
444-
if ctx.is_in_multi_inference() && !bidirectional {
428+
// silence it to avoid duplicated diagnostics.
429+
if ctx.is_in_multi_inference() {
445430
return None;
446431
}
447432
let id = DiagnosticId::Lint(lint.name());

0 commit comments

Comments
 (0)