Skip to content

Commit d301713

Browse files
committed
[red-knot] Detect (some) invalid protocols
1 parent 9ff4772 commit d301713

File tree

4 files changed

+76
-19
lines changed

4 files changed

+76
-19
lines changed

crates/red_knot_python_semantic/resources/mdtest/protocols.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,13 @@ If `Protocol` is present in the bases tuple, all other bases in the tuple must b
136136
or `TypeError` is raised at runtime when the class is created.
137137

138138
```py
139-
# TODO: should emit `[invalid-protocol]`
139+
# error: [invalid-protocol] "Protocol class `Invalid` cannot inherit from non-protocol class `NotAProtocol`"
140140
class Invalid(NotAProtocol, Protocol): ...
141141

142142
# revealed: tuple[Literal[Invalid], Literal[NotAProtocol], typing.Protocol, typing.Generic, Literal[object]]
143143
reveal_type(Invalid.__mro__)
144144

145-
# TODO: should emit an `[invalid-protocol`] error
145+
# error: [invalid-protocol] "Protocol class `AlsoInvalid` cannot inherit from non-protocol class `NotAProtocol`"
146146
class AlsoInvalid(MyProtocol, OtherProtocol, NotAProtocol, Protocol): ...
147147

148148
# revealed: tuple[Literal[AlsoInvalid], Literal[MyProtocol], Literal[OtherProtocol], Literal[NotAProtocol], typing.Protocol, typing.Generic, Literal[object]]

crates/red_knot_python_semantic/src/types/diagnostic.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
3636
registry.register_lint(&INVALID_EXCEPTION_CAUGHT);
3737
registry.register_lint(&INVALID_METACLASS);
3838
registry.register_lint(&INVALID_PARAMETER_DEFAULT);
39+
registry.register_lint(&INVALID_PROTOCOL);
3940
registry.register_lint(&INVALID_RAISE);
4041
registry.register_lint(&INVALID_SUPER_ARGUMENT);
4142
registry.register_lint(&INVALID_TYPE_CHECKING_CONSTANT);
@@ -230,6 +231,34 @@ declare_lint! {
230231
}
231232
}
232233

234+
declare_lint! {
235+
/// ## What it does
236+
/// Checks for invalidly defined protocol classes.
237+
///
238+
/// ## Why is this bad?
239+
/// An invalidly defined protocol class may lead to the type checker inferring
240+
/// unexpected things. It may also lead to `TypeError`s at runtime.
241+
///
242+
/// ## Examples
243+
/// A `Protocol` class cannot inherit from a non-`Protocol` class;
244+
/// this raises a `TypeError` at runtime:
245+
///
246+
/// ```pycon
247+
/// >>> from typing import Protocol
248+
/// >>> class Foo(int, Protocol): ...
249+
/// ...
250+
/// Traceback (most recent call last):
251+
/// File "<python-input-1>", line 1, in <module>
252+
/// class Foo(int, Protocol): ...
253+
/// TypeError: Protocols can only inherit from other protocols, got <class 'int'>
254+
/// ```
255+
pub(crate) static INVALID_PROTOCOL = {
256+
summary: "detects invalid protocol class definitions",
257+
status: LintStatus::preview("1.0.0"),
258+
default_level: Level::Error,
259+
}
260+
}
261+
233262
declare_lint! {
234263
/// TODO #14889
235264
pub(crate) static INCONSISTENT_MRO = {

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ use super::diagnostic::{
9999
report_index_out_of_bounds, report_invalid_exception_caught, report_invalid_exception_cause,
100100
report_invalid_exception_raised, report_invalid_type_checking_constant,
101101
report_non_subscriptable, report_possibly_unresolved_reference, report_slice_step_size_zero,
102-
report_unresolved_reference, INVALID_METACLASS, REDUNDANT_CAST, STATIC_ASSERT_ERROR,
103-
SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE,
102+
report_unresolved_reference, INVALID_METACLASS, INVALID_PROTOCOL, REDUNDANT_CAST,
103+
STATIC_ASSERT_ERROR, SUBCLASS_OF_FINAL_CLASS, TYPE_ASSERTION_FAILURE,
104104
};
105105
use super::slots::check_class_slots;
106106
use super::string_annotation::{
@@ -763,17 +763,21 @@ impl<'db> TypeInferenceBuilder<'db> {
763763
continue;
764764
}
765765

766-
// (2) Check for inheritance from plain `Generic`,
767-
// and from classes that inherit from `@final` classes
766+
let is_protocol = class.is_protocol(self.db());
767+
768+
// (2) Iterate through the class's explicit bases to check for various possible errors:
769+
// - Check for inheritance from plain `Generic`,
770+
// - Check for inheritance from a `@final` classes
771+
// - If the class is a protocol class: check for inheritance from a non-protocol class
768772
for (i, base_class) in class.explicit_bases(self.db()).iter().enumerate() {
769773
let base_class = match base_class {
770774
Type::KnownInstance(KnownInstanceType::Generic) => {
771-
// `Generic` can appear in the MRO of many classes,
775+
// Unsubscripted `Generic` can appear in the MRO of many classes,
772776
// but it is never valid as an explicit base class in user code.
773777
self.context.report_lint_old(
774778
&INVALID_BASE,
775779
&class_node.bases()[i],
776-
format_args!("Cannot inherit from plain `Generic`",),
780+
format_args!("Cannot inherit from plain `Generic`"),
777781
);
778782
continue;
779783
}
@@ -782,18 +786,32 @@ impl<'db> TypeInferenceBuilder<'db> {
782786
_ => continue,
783787
};
784788

785-
if !base_class.is_final(self.db()) {
786-
continue;
789+
if is_protocol
790+
&& !(base_class.is_protocol(self.db())
791+
|| base_class.is_known(self.db(), KnownClass::Object))
792+
{
793+
self.context.report_lint_old(
794+
&INVALID_PROTOCOL,
795+
&class_node.bases()[i],
796+
format_args!(
797+
"Protocol class `{}` cannot inherit from non-protocol class `{}`",
798+
class.name(self.db()),
799+
base_class.name(self.db()),
800+
),
801+
);
802+
}
803+
804+
if base_class.is_final(self.db()) {
805+
self.context.report_lint_old(
806+
&SUBCLASS_OF_FINAL_CLASS,
807+
&class_node.bases()[i],
808+
format_args!(
809+
"Class `{}` cannot inherit from final class `{}`",
810+
class.name(self.db()),
811+
base_class.name(self.db()),
812+
),
813+
);
787814
}
788-
self.context.report_lint_old(
789-
&SUBCLASS_OF_FINAL_CLASS,
790-
&class_node.bases()[i],
791-
format_args!(
792-
"Class `{}` cannot inherit from final class `{}`",
793-
class.name(self.db()),
794-
base_class.name(self.db()),
795-
),
796-
);
797815
}
798816

799817
// (3) Check that the class's MRO is resolvable

knot.schema.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,16 @@
460460
}
461461
]
462462
},
463+
"invalid-protocol": {
464+
"title": "detects invalid protocol class definitions",
465+
"description": "## What it does\nChecks for invalidly defined protocol classes.\n\n## Why is this bad?\nAn invalidly defined protocol class may lead to the type checker inferring\nunexpected things. It may also lead to `TypeError`s at runtime.\n\n## Examples\nA `Protocol` class cannot inherit from a non-`Protocol` class;\nthis raises a `TypeError` at runtime:\n\n```pycon\n>>> from typing import Protocol\n>>> class Foo(int, Protocol): ...\n...\nTraceback (most recent call last):\n File \"<python-input-1>\", line 1, in <module>\n class Foo(int, Protocol): ...\nTypeError: Protocols can only inherit from other protocols, got <class 'int'>\n```",
466+
"default": "error",
467+
"oneOf": [
468+
{
469+
"$ref": "#/definitions/Level"
470+
}
471+
]
472+
},
463473
"invalid-raise": {
464474
"title": "detects `raise` statements that raise invalid exceptions or use invalid causes",
465475
"description": "Checks for `raise` statements that raise non-exceptions or use invalid\ncauses for their raised exceptions.\n\n## Why is this bad?\nOnly subclasses or instances of `BaseException` can be raised.\nFor an exception's cause, the same rules apply, except that `None` is also\npermitted. Violating these rules results in a `TypeError` at runtime.\n\n## Examples\n```python\ndef f():\n try:\n something()\n except NameError:\n raise \"oops!\" from f\n\ndef g():\n raise NotImplemented from 42\n```\n\nUse instead:\n```python\ndef f():\n try:\n something()\n except NameError as e:\n raise RuntimeError(\"oops!\") from e\n\ndef g():\n raise NotImplementedError from None\n```\n\n## References\n- [Python documentation: The `raise` statement](https://docs.python.org/3/reference/simple_stmts.html#raise)\n- [Python documentation: Built-in Exceptions](https://docs.python.org/3/library/exceptions.html#built-in-exceptions)",

0 commit comments

Comments
 (0)