Skip to content
13 changes: 12 additions & 1 deletion crates/red_knot_python_semantic/resources/mdtest/overloads.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,18 +309,29 @@ reveal_type(func("")) # revealed: Literal[""]

### At least two overloads

<!-- snapshot-diagnostics -->

At least two `@overload`-decorated definitions must be present.

```py
from typing import overload

# TODO: error
@overload
def func(x: int) -> int: ...

# error: [invalid-overload]
def func(x: int | str) -> int | str:
return x
```

```pyi
from typing import overload

@overload
# error: [invalid-overload]
def func(x: int) -> int: ...
```

### Overload without an implementation

#### Regular modules
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
---
source: crates/red_knot_test/src/lib.rs
expression: snapshot
---
---
mdtest name: overloads.md - Overloads - Invalid - At least two overloads
mdtest path: crates/red_knot_python_semantic/resources/mdtest/overloads.md
---

# Python source files

## mdtest_snippet.py

```
1 | from typing import overload
2 |
3 | @overload
4 | def func(x: int) -> int: ...
5 |
6 | # error: [invalid-overload]
7 | def func(x: int | str) -> int | str:
8 | return x
```

## mdtest_snippet.pyi

```
1 | from typing import overload
2 |
3 | @overload
4 | # error: [invalid-overload]
5 | def func(x: int) -> int: ...
```

# Diagnostics

```
error: lint:invalid-overload: Overloaded function `func` requires at least two overloads
--> src/mdtest_snippet.py:4:5
|
3 | @overload
4 | def func(x: int) -> int: ...
| ---- Only one overload defined here
5 |
6 | # error: [invalid-overload]
7 | def func(x: int | str) -> int | str:
| ^^^^
8 | return x
|

```

```
error: lint:invalid-overload: Overloaded function `func` requires at least two overloads
--> src/mdtest_snippet.pyi:5:5
|
3 | @overload
4 | # error: [invalid-overload]
5 | def func(x: int) -> int: ...
| ----
| |
| Only one overload defined here
|

```
31 changes: 29 additions & 2 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6375,6 +6375,13 @@ pub struct FunctionType<'db> {

#[salsa::tracked]
impl<'db> FunctionType<'db> {
/// Returns the [`File`] in which this function is defined.
pub(crate) fn file(self, db: &'db dyn Db) -> File {
// NOTE: Do not use `self.definition(db).file(db)` here, as that could create a
// cross-module dependency on the full AST.
self.body_scope(db).file(db)
}

pub(crate) fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool {
self.decorators(db).contains(decorator)
}
Expand All @@ -6396,21 +6403,41 @@ impl<'db> FunctionType<'db> {
Type::BoundMethod(BoundMethodType::new(db, self, self_instance))
}

/// Returns the AST node for this function.
pub(crate) fn node(self, db: &'db dyn Db, file: File) -> &'db ast::StmtFunctionDef {
debug_assert_eq!(
file,
self.file(db),
"FunctionType::node() must be called with the same file as the one where \
the function is defined."
);

self.body_scope(db).node(db).expect_function()
}

/// Returns the [`FileRange`] of the function's name.
pub fn focus_range(self, db: &dyn Db) -> FileRange {
FileRange::new(
self.body_scope(db).file(db),
self.file(db),
self.body_scope(db).node(db).expect_function().name.range,
)
}

pub fn full_range(self, db: &dyn Db) -> FileRange {
FileRange::new(
self.body_scope(db).file(db),
self.file(db),
self.body_scope(db).node(db).expect_function().range,
)
}

/// Returns the [`Definition`] of this function.
///
/// ## Warning
///
/// This uses the semantic index to find the definition of the function. This means that if the
/// calling query is not in the same file as this function is defined in, then this will create
/// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation.
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
let body_scope = self.body_scope(db);
let index = semantic_index(db, body_scope.file(db));
Expand Down
44 changes: 44 additions & 0 deletions crates/red_knot_python_semantic/src/types/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
registry.register_lint(&INVALID_EXCEPTION_CAUGHT);
registry.register_lint(&INVALID_LEGACY_TYPE_VARIABLE);
registry.register_lint(&INVALID_METACLASS);
registry.register_lint(&INVALID_OVERLOAD);
registry.register_lint(&INVALID_PARAMETER_DEFAULT);
registry.register_lint(&INVALID_PROTOCOL);
registry.register_lint(&INVALID_RAISE);
Expand Down Expand Up @@ -447,6 +448,49 @@ declare_lint! {
}
}

declare_lint! {
/// ## What it does
/// Checks for various invalid `@overload` usages.
///
/// ## Why is this bad?
/// The `@overload` decorator is used to define functions and methods that accepts different
/// combinations of arguments and return different types based on the arguments passed. This is
/// mainly beneficial for type checkers. But, if the `@overload` usage is invalid, the type
/// checker may not be able to provide correct type information.
///
/// ## Example
///
/// Defining only one overload:
///
/// ```py
/// from typing import overload
///
/// @overload
/// def foo(x: int) -> int: ...
/// def foo(x: int | None) -> int | None:
/// return x
/// ```
///
/// Or, not providing an implementation for the overloaded definition:
///
/// ```py
/// from typing import overload
///
/// @overload
/// def foo() -> None: ...
/// @overload
/// def foo(x: int) -> int: ...
/// ```
///
/// ## References
/// - [Python documentation: `@overload`](https://docs.python.org/3/library/typing.html#typing.overload)
pub(crate) static INVALID_OVERLOAD = {
summary: "detects invalid `@overload` usages",
status: LintStatus::preview("1.0.0"),
default_level: Level::Error,
}
}

declare_lint! {
/// ## What it does
/// Checks for default values that can't be assigned to the parameter's annotated type.
Expand Down
127 changes: 123 additions & 4 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ use super::diagnostic::{
report_invalid_exception_raised, report_invalid_type_checking_constant,
report_non_subscriptable, report_possibly_unresolved_reference,
report_runtime_check_against_non_runtime_checkable_protocol, report_slice_step_size_zero,
report_unresolved_reference, INVALID_METACLASS, INVALID_PROTOCOL, REDUNDANT_CAST,
STATIC_ASSERT_ERROR, SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE,
report_unresolved_reference, INVALID_METACLASS, INVALID_OVERLOAD, INVALID_PROTOCOL,
REDUNDANT_CAST, STATIC_ASSERT_ERROR, SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE,
};
use super::slots::check_class_slots;
use super::string_annotation::{
Expand Down Expand Up @@ -418,7 +418,7 @@ impl<'db> TypeInference<'db> {
.copied()
.or(self.cycle_fallback_type)
.expect(
"definition should belong to this TypeInference region and
"definition should belong to this TypeInference region and \
TypeInferenceBuilder should have inferred a type for it",
)
}
Expand All @@ -430,7 +430,7 @@ impl<'db> TypeInference<'db> {
.copied()
.or(self.cycle_fallback_type.map(Into::into))
.expect(
"definition should belong to this TypeInference region and
"definition should belong to this TypeInference region and \
TypeInferenceBuilder should have inferred a type for it",
)
}
Expand Down Expand Up @@ -524,6 +524,31 @@ pub(super) struct TypeInferenceBuilder<'db> {
/// The returned types and their corresponding ranges of the region, if it is a function body.
return_types_and_ranges: Vec<TypeAndRange<'db>>,

/// A set of functions that have been defined **and** called in this region.
///
/// This is a set because the same function could be called multiple times in the same region.
/// This is mainly used in [`check_overloaded_functions`] to check an overloaded function that
/// is shadowed by a function with the same name in this scope but has been called before. For
/// example:
///
/// ```py
/// from typing import overload
///
/// @overload
/// def foo() -> None: ...
/// @overload
/// def foo(x: int) -> int: ...
/// def foo(x: int | None) -> int | None: return x
///
/// foo() # An overloaded function that was defined in this scope have been called
///
/// def foo(x: int) -> int:
/// return x
/// ```
///
/// [`check_overloaded_functions`]: TypeInferenceBuilder::check_overloaded_functions
called_functions: FxHashSet<FunctionType<'db>>,

/// The deferred state of inferring types of certain expressions within the region.
///
/// This is different from [`InferenceRegion::Deferred`] which works on the entire definition
Expand Down Expand Up @@ -556,6 +581,7 @@ impl<'db> TypeInferenceBuilder<'db> {
index,
region,
return_types_and_ranges: vec![],
called_functions: FxHashSet::default(),
deferred_state: DeferredExpressionState::None,
types: TypeInference::empty(scope),
}
Expand Down Expand Up @@ -718,6 +744,7 @@ impl<'db> TypeInferenceBuilder<'db> {

// TODO: Only call this function when diagnostics are enabled.
self.check_class_definitions();
self.check_overloaded_functions();
}

/// Iterate over all class definitions to check that the definition will not cause an exception
Expand Down Expand Up @@ -952,6 +979,86 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}

/// Check the overloaded functions in this scope.
///
/// This only checks the overloaded functions that are:
/// 1. Visible publicly at the end of this scope
/// 2. Or, defined and called in this scope
///
/// For (1), this has the consequence of not checking an overloaded function that is being
/// shadowed by another function with the same name in this scope.
fn check_overloaded_functions(&mut self) {
// Collect all the unique overloaded function symbols in this scope. This requires a set
// because an overloaded function uses the same symbol for each of the overloads and the
// implementation.
let overloaded_function_symbols: FxHashSet<_> = self
.types
.declarations
.iter()
.filter_map(|(definition, ty)| {
// Filter out function literals that result from anything other than a function
// definition e.g., imports which would create a cross-module AST dependency.
if !matches!(definition.kind(self.db()), DefinitionKind::Function(_)) {
return None;
}
let function = ty.inner_type().into_function_literal()?;
if function.has_known_decorator(self.db(), FunctionDecorators::OVERLOAD) {
Some(definition.symbol(self.db()))
} else {
None
}
})
.collect();

let use_def = self
.index
.use_def_map(self.scope().file_scope_id(self.db()));

let mut public_functions = FxHashSet::default();

for symbol in overloaded_function_symbols {
if let Symbol::Type(Type::FunctionLiteral(function), Boundness::Bound) =
symbol_from_bindings(self.db(), use_def.public_bindings(symbol))
{
if function.file(self.db()) != self.file() {
// If the function is not in this file, we don't need to check it.
// https://github.com/astral-sh/ruff/pull/17609#issuecomment-2839445740
continue;
}

// Extend the functions that we need to check with the publicly visible overloaded
// function. This is always going to be either the implementation or the last
// overload if the implementation doesn't exists.
public_functions.insert(function);
}
}

for function in self.called_functions.union(&public_functions) {
let Some(overloaded) = function.to_overloaded(self.db()) else {
continue;
};

// Check that the overloaded function has at least two overloads
if let [single_overload] = overloaded.overloads.as_slice() {
let function_node = function.node(self.db(), self.file());
if let Some(builder) = self
.context
.report_lint(&INVALID_OVERLOAD, &function_node.name)
{
let mut diagnostic = builder.into_diagnostic(format_args!(
"Overloaded function `{}` requires at least two overloads",
&function_node.name
));
diagnostic.annotate(
self.context
.secondary(single_overload.focus_range(self.db()))
.message(format_args!("Only one overload defined here")),
);
}
}
}
}

fn infer_region_definition(&mut self, definition: Definition<'db>) {
match definition.kind(self.db()) {
DefinitionKind::Function(function) => {
Expand Down Expand Up @@ -4298,6 +4405,18 @@ impl<'db> TypeInferenceBuilder<'db> {
let mut call_arguments = Self::parse_arguments(arguments);
let callable_type = self.infer_expression(func);

if let Type::FunctionLiteral(function) = callable_type {
// Make sure that the `function.definition` is only called when the function is defined
// in the same file as the one we're currently inferring the types for. This is because
// the `definition` method accesses the semantic index, which could create a
// cross-module AST dependency.
if function.file(self.db()) == self.file()
&& function.definition(self.db()).scope(self.db()) == self.scope()
{
self.called_functions.insert(function);
}
}

// It might look odd here that we emit an error for class-literals but not `type[]` types.
// But it's deliberate! The typing spec explicitly mandates that `type[]` types can be called
// even though class-literals cannot. This is because even though a protocol class `SomeProtocol`
Expand Down
Loading
Loading