diff --git a/crates/red_knot_python_semantic/resources/mdtest/overloads.md b/crates/red_knot_python_semantic/resources/mdtest/overloads.md index 764b6a44123ed..4bbef31b1d735 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/overloads.md +++ b/crates/red_knot_python_semantic/resources/mdtest/overloads.md @@ -309,18 +309,29 @@ reveal_type(func("")) # revealed: Literal[""] ### At least two overloads + + At least two `@overload`-decorated definitions must be present. ```py from typing import overload -# TODO: error @overload def func(x: int) -> int: ... + +# error: [invalid-overload] def func(x: int | str) -> int | str: return x ``` +```pyi +from typing import overload + +@overload +# error: [invalid-overload] +def func(x: int) -> int: ... +``` + ### Overload without an implementation #### Regular modules diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_At_least_two_overloads.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_At_least_two_overloads.snap new file mode 100644 index 0000000000000..246dff3c952ba --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/overloads.md_-_Overloads_-_Invalid_-_At_least_two_overloads.snap @@ -0,0 +1,65 @@ +--- +source: crates/red_knot_test/src/lib.rs +expression: snapshot +--- +--- +mdtest name: overloads.md - Overloads - Invalid - At least two overloads +mdtest path: crates/red_knot_python_semantic/resources/mdtest/overloads.md +--- + +# Python source files + +## mdtest_snippet.py + +``` +1 | from typing import overload +2 | +3 | @overload +4 | def func(x: int) -> int: ... +5 | +6 | # error: [invalid-overload] +7 | def func(x: int | str) -> int | str: +8 | return x +``` + +## mdtest_snippet.pyi + +``` +1 | from typing import overload +2 | +3 | @overload +4 | # error: [invalid-overload] +5 | def func(x: int) -> int: ... +``` + +# Diagnostics + +``` +error: lint:invalid-overload: Overloaded function `func` requires at least two overloads + --> src/mdtest_snippet.py:4:5 + | +3 | @overload +4 | def func(x: int) -> int: ... + | ---- Only one overload defined here +5 | +6 | # error: [invalid-overload] +7 | def func(x: int | str) -> int | str: + | ^^^^ +8 | return x + | + +``` + +``` +error: lint:invalid-overload: Overloaded function `func` requires at least two overloads + --> src/mdtest_snippet.pyi:5:5 + | +3 | @overload +4 | # error: [invalid-overload] +5 | def func(x: int) -> int: ... + | ---- + | | + | Only one overload defined here + | + +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 335e4d8e1c13f..df018024fc625 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -6375,6 +6375,13 @@ pub struct FunctionType<'db> { #[salsa::tracked] impl<'db> FunctionType<'db> { + /// Returns the [`File`] in which this function is defined. + pub(crate) fn file(self, db: &'db dyn Db) -> File { + // NOTE: Do not use `self.definition(db).file(db)` here, as that could create a + // cross-module dependency on the full AST. + self.body_scope(db).file(db) + } + pub(crate) fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool { self.decorators(db).contains(decorator) } @@ -6396,21 +6403,41 @@ impl<'db> FunctionType<'db> { Type::BoundMethod(BoundMethodType::new(db, self, self_instance)) } + /// Returns the AST node for this function. + pub(crate) fn node(self, db: &'db dyn Db, file: File) -> &'db ast::StmtFunctionDef { + debug_assert_eq!( + file, + self.file(db), + "FunctionType::node() must be called with the same file as the one where \ + the function is defined." + ); + + self.body_scope(db).node(db).expect_function() + } + /// Returns the [`FileRange`] of the function's name. pub fn focus_range(self, db: &dyn Db) -> FileRange { FileRange::new( - self.body_scope(db).file(db), + self.file(db), self.body_scope(db).node(db).expect_function().name.range, ) } pub fn full_range(self, db: &dyn Db) -> FileRange { FileRange::new( - self.body_scope(db).file(db), + self.file(db), self.body_scope(db).node(db).expect_function().range, ) } + /// Returns the [`Definition`] of this function. + /// + /// ## Warning + /// + /// This uses the semantic index to find the definition of the function. This means that if the + /// calling query is not in the same file as this function is defined in, then this will create + /// a cross-module dependency directly on the full AST which will lead to cache + /// over-invalidation. pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { let body_scope = self.body_scope(db); let index = semantic_index(db, body_scope.file(db)); diff --git a/crates/red_knot_python_semantic/src/types/diagnostic.rs b/crates/red_knot_python_semantic/src/types/diagnostic.rs index 63d5a18e4baef..f56300489085d 100644 --- a/crates/red_knot_python_semantic/src/types/diagnostic.rs +++ b/crates/red_knot_python_semantic/src/types/diagnostic.rs @@ -37,6 +37,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) { registry.register_lint(&INVALID_EXCEPTION_CAUGHT); registry.register_lint(&INVALID_LEGACY_TYPE_VARIABLE); registry.register_lint(&INVALID_METACLASS); + registry.register_lint(&INVALID_OVERLOAD); registry.register_lint(&INVALID_PARAMETER_DEFAULT); registry.register_lint(&INVALID_PROTOCOL); registry.register_lint(&INVALID_RAISE); @@ -447,6 +448,49 @@ declare_lint! { } } +declare_lint! { + /// ## What it does + /// Checks for various invalid `@overload` usages. + /// + /// ## Why is this bad? + /// The `@overload` decorator is used to define functions and methods that accepts different + /// combinations of arguments and return different types based on the arguments passed. This is + /// mainly beneficial for type checkers. But, if the `@overload` usage is invalid, the type + /// checker may not be able to provide correct type information. + /// + /// ## Example + /// + /// Defining only one overload: + /// + /// ```py + /// from typing import overload + /// + /// @overload + /// def foo(x: int) -> int: ... + /// def foo(x: int | None) -> int | None: + /// return x + /// ``` + /// + /// Or, not providing an implementation for the overloaded definition: + /// + /// ```py + /// from typing import overload + /// + /// @overload + /// def foo() -> None: ... + /// @overload + /// def foo(x: int) -> int: ... + /// ``` + /// + /// ## References + /// - [Python documentation: `@overload`](https://docs.python.org/3/library/typing.html#typing.overload) + pub(crate) static INVALID_OVERLOAD = { + summary: "detects invalid `@overload` usages", + status: LintStatus::preview("1.0.0"), + default_level: Level::Error, + } +} + declare_lint! { /// ## What it does /// Checks for default values that can't be assigned to the parameter's annotated type. diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index db5b2a53197f0..f6bfe29b48c74 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -101,8 +101,8 @@ use super::diagnostic::{ report_invalid_exception_raised, report_invalid_type_checking_constant, report_non_subscriptable, report_possibly_unresolved_reference, report_runtime_check_against_non_runtime_checkable_protocol, report_slice_step_size_zero, - report_unresolved_reference, INVALID_METACLASS, INVALID_PROTOCOL, REDUNDANT_CAST, - STATIC_ASSERT_ERROR, SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE, + report_unresolved_reference, INVALID_METACLASS, INVALID_OVERLOAD, INVALID_PROTOCOL, + REDUNDANT_CAST, STATIC_ASSERT_ERROR, SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE, }; use super::slots::check_class_slots; use super::string_annotation::{ @@ -418,7 +418,7 @@ impl<'db> TypeInference<'db> { .copied() .or(self.cycle_fallback_type) .expect( - "definition should belong to this TypeInference region and + "definition should belong to this TypeInference region and \ TypeInferenceBuilder should have inferred a type for it", ) } @@ -430,7 +430,7 @@ impl<'db> TypeInference<'db> { .copied() .or(self.cycle_fallback_type.map(Into::into)) .expect( - "definition should belong to this TypeInference region and + "definition should belong to this TypeInference region and \ TypeInferenceBuilder should have inferred a type for it", ) } @@ -524,6 +524,31 @@ pub(super) struct TypeInferenceBuilder<'db> { /// The returned types and their corresponding ranges of the region, if it is a function body. return_types_and_ranges: Vec>, + /// A set of functions that have been defined **and** called in this region. + /// + /// This is a set because the same function could be called multiple times in the same region. + /// This is mainly used in [`check_overloaded_functions`] to check an overloaded function that + /// is shadowed by a function with the same name in this scope but has been called before. For + /// example: + /// + /// ```py + /// from typing import overload + /// + /// @overload + /// def foo() -> None: ... + /// @overload + /// def foo(x: int) -> int: ... + /// def foo(x: int | None) -> int | None: return x + /// + /// foo() # An overloaded function that was defined in this scope have been called + /// + /// def foo(x: int) -> int: + /// return x + /// ``` + /// + /// [`check_overloaded_functions`]: TypeInferenceBuilder::check_overloaded_functions + called_functions: FxHashSet>, + /// The deferred state of inferring types of certain expressions within the region. /// /// This is different from [`InferenceRegion::Deferred`] which works on the entire definition @@ -556,6 +581,7 @@ impl<'db> TypeInferenceBuilder<'db> { index, region, return_types_and_ranges: vec![], + called_functions: FxHashSet::default(), deferred_state: DeferredExpressionState::None, types: TypeInference::empty(scope), } @@ -718,6 +744,7 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO: Only call this function when diagnostics are enabled. self.check_class_definitions(); + self.check_overloaded_functions(); } /// Iterate over all class definitions to check that the definition will not cause an exception @@ -952,6 +979,86 @@ impl<'db> TypeInferenceBuilder<'db> { } } + /// Check the overloaded functions in this scope. + /// + /// This only checks the overloaded functions that are: + /// 1. Visible publicly at the end of this scope + /// 2. Or, defined and called in this scope + /// + /// For (1), this has the consequence of not checking an overloaded function that is being + /// shadowed by another function with the same name in this scope. + fn check_overloaded_functions(&mut self) { + // Collect all the unique overloaded function symbols in this scope. This requires a set + // because an overloaded function uses the same symbol for each of the overloads and the + // implementation. + let overloaded_function_symbols: FxHashSet<_> = self + .types + .declarations + .iter() + .filter_map(|(definition, ty)| { + // Filter out function literals that result from anything other than a function + // definition e.g., imports which would create a cross-module AST dependency. + if !matches!(definition.kind(self.db()), DefinitionKind::Function(_)) { + return None; + } + let function = ty.inner_type().into_function_literal()?; + if function.has_known_decorator(self.db(), FunctionDecorators::OVERLOAD) { + Some(definition.symbol(self.db())) + } else { + None + } + }) + .collect(); + + let use_def = self + .index + .use_def_map(self.scope().file_scope_id(self.db())); + + let mut public_functions = FxHashSet::default(); + + for symbol in overloaded_function_symbols { + if let Symbol::Type(Type::FunctionLiteral(function), Boundness::Bound) = + symbol_from_bindings(self.db(), use_def.public_bindings(symbol)) + { + if function.file(self.db()) != self.file() { + // If the function is not in this file, we don't need to check it. + // https://github.com/astral-sh/ruff/pull/17609#issuecomment-2839445740 + continue; + } + + // Extend the functions that we need to check with the publicly visible overloaded + // function. This is always going to be either the implementation or the last + // overload if the implementation doesn't exists. + public_functions.insert(function); + } + } + + for function in self.called_functions.union(&public_functions) { + let Some(overloaded) = function.to_overloaded(self.db()) else { + continue; + }; + + // Check that the overloaded function has at least two overloads + if let [single_overload] = overloaded.overloads.as_slice() { + let function_node = function.node(self.db(), self.file()); + if let Some(builder) = self + .context + .report_lint(&INVALID_OVERLOAD, &function_node.name) + { + let mut diagnostic = builder.into_diagnostic(format_args!( + "Overloaded function `{}` requires at least two overloads", + &function_node.name + )); + diagnostic.annotate( + self.context + .secondary(single_overload.focus_range(self.db())) + .message(format_args!("Only one overload defined here")), + ); + } + } + } + } + fn infer_region_definition(&mut self, definition: Definition<'db>) { match definition.kind(self.db()) { DefinitionKind::Function(function) => { @@ -4298,6 +4405,18 @@ impl<'db> TypeInferenceBuilder<'db> { let mut call_arguments = Self::parse_arguments(arguments); let callable_type = self.infer_expression(func); + if let Type::FunctionLiteral(function) = callable_type { + // Make sure that the `function.definition` is only called when the function is defined + // in the same file as the one we're currently inferring the types for. This is because + // the `definition` method accesses the semantic index, which could create a + // cross-module AST dependency. + if function.file(self.db()) == self.file() + && function.definition(self.db()).scope(self.db()) == self.scope() + { + self.called_functions.insert(function); + } + } + // It might look odd here that we emit an error for class-literals but not `type[]` types. // But it's deliberate! The typing spec explicitly mandates that `type[]` types can be called // even though class-literals cannot. This is because even though a protocol class `SomeProtocol` diff --git a/knot.schema.json b/knot.schema.json index 66e0a5b2e304a..aeda31fd530bc 100644 --- a/knot.schema.json +++ b/knot.schema.json @@ -470,6 +470,16 @@ } ] }, + "invalid-overload": { + "title": "detects invalid `@overload` usages", + "description": "## What it does\nChecks for various invalid `@overload` usages.\n\n## Why is this bad?\nThe `@overload` decorator is used to define functions and methods that accepts different\ncombinations of arguments and return different types based on the arguments passed. This is\nmainly beneficial for type checkers. But, if the `@overload` usage is invalid, the type\nchecker may not be able to provide correct type information.\n\n## Example\n\nDefining only one overload:\n\n```py\nfrom typing import overload\n\n@overload\ndef foo(x: int) -> int: ...\ndef foo(x: int | None) -> int | None:\n return x\n```\n\nOr, not providing an implementation for the overloaded definition:\n\n```py\nfrom typing import overload\n\n@overload\ndef foo() -> None: ...\n@overload\ndef foo(x: int) -> int: ...\n```\n\n## References\n- [Python documentation: `@overload`](https://docs.python.org/3/library/typing.html#typing.overload)", + "default": "error", + "oneOf": [ + { + "$ref": "#/definitions/Level" + } + ] + }, "invalid-parameter-default": { "title": "detects default values that can't be assigned to the parameter's annotated type", "description": "## What it does\nChecks for default values that can't be assigned to the parameter's annotated type.\n\n## Why is this bad?\nTODO #14889",