Skip to content

Commit 9ff4772

Browse files
authored
[red-knot] Correctly identify protocol classes (#17487)
1 parent c077b10 commit 9ff4772

File tree

5 files changed

+90
-36
lines changed

5 files changed

+90
-36
lines changed

crates/red_knot_python_semantic/resources/mdtest/function/return_type.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@ class Baz(Bar):
7474
T = TypeVar("T")
7575

7676
class Qux(Protocol[T]):
77-
# TODO: no error
78-
# error: [invalid-return-type]
7977
def f(self) -> int: ...
8078

8179
class Foo(Protocol):

crates/red_knot_python_semantic/resources/mdtest/protocols.md

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,27 +40,63 @@ class Foo(Protocol, Protocol): ... # error: [inconsistent-mro]
4040
reveal_type(Foo.__mro__) # revealed: tuple[Literal[Foo], Unknown, Literal[object]]
4141
```
4242

43+
Protocols can also be generic, either by including `Generic[]` in the bases list, subscripting
44+
`Protocol` directly in the bases list, using PEP-695 type parameters, or some combination of the
45+
above:
46+
47+
```py
48+
from typing import TypeVar, Generic
49+
50+
T = TypeVar("T")
51+
52+
class Bar0(Protocol[T]):
53+
x: T
54+
55+
class Bar1(Protocol[T], Generic[T]):
56+
x: T
57+
58+
class Bar2[T](Protocol):
59+
x: T
60+
61+
class Bar3[T](Protocol[T]):
62+
x: T
63+
```
64+
65+
It's an error to include both bare `Protocol` and subscripted `Protocol[]` in the bases list
66+
simultaneously:
67+
68+
```py
69+
# TODO: should emit a `[duplicate-bases]` error here:
70+
class DuplicateBases(Protocol, Protocol[T]):
71+
x: T
72+
73+
# TODO: should not have `Generic` multiple times and `Protocol` multiple times
74+
# revealed: tuple[Literal[DuplicateBases], typing.Protocol, typing.Generic, @Todo(`Protocol[]` subscript), @Todo(`Generic[]` subscript), Literal[object]]
75+
reveal_type(DuplicateBases.__mro__)
76+
```
77+
4378
The introspection helper `typing(_extensions).is_protocol` can be used to verify whether a class is
4479
a protocol class or not:
4580

4681
```py
4782
from typing_extensions import is_protocol
4883

49-
# TODO: should be `Literal[True]`
50-
reveal_type(is_protocol(MyProtocol)) # revealed: bool
84+
reveal_type(is_protocol(MyProtocol)) # revealed: Literal[True]
85+
reveal_type(is_protocol(Bar0)) # revealed: Literal[True]
86+
reveal_type(is_protocol(Bar1)) # revealed: Literal[True]
87+
reveal_type(is_protocol(Bar2)) # revealed: Literal[True]
88+
reveal_type(is_protocol(Bar3)) # revealed: Literal[True]
5189

5290
class NotAProtocol: ...
5391

54-
# TODO: should be `Literal[False]`
55-
reveal_type(is_protocol(NotAProtocol)) # revealed: bool
92+
reveal_type(is_protocol(NotAProtocol)) # revealed: Literal[False]
5693
```
5794

5895
A type checker should follow the typeshed stubs if a non-class is passed in, and typeshed's stubs
59-
indicate that the argument passed in must be an instance of `type`. `Literal[False]` should be
60-
inferred as the return type, however.
96+
indicate that the argument passed in must be an instance of `type`.
6197

6298
```py
63-
# TODO: the diagnostic is correct, but should infer `Literal[False]`
99+
# We could also reasonably infer `Literal[False]` here, but it probably doesn't matter that much:
64100
# error: [invalid-argument-type]
65101
reveal_type(is_protocol("not a class")) # revealed: bool
66102
```
@@ -74,8 +110,7 @@ class SubclassOfMyProtocol(MyProtocol): ...
74110
# revealed: tuple[Literal[SubclassOfMyProtocol], Literal[MyProtocol], typing.Protocol, typing.Generic, Literal[object]]
75111
reveal_type(SubclassOfMyProtocol.__mro__)
76112

77-
# TODO: should be `Literal[False]`
78-
reveal_type(is_protocol(SubclassOfMyProtocol)) # revealed: bool
113+
reveal_type(is_protocol(SubclassOfMyProtocol)) # revealed: Literal[False]
79114
```
80115

81116
A protocol class may inherit from other protocols, however, as long as it re-inherits from
@@ -84,8 +119,7 @@ A protocol class may inherit from other protocols, however, as long as it re-inh
84119
```py
85120
class SubProtocol(MyProtocol, Protocol): ...
86121

87-
# TODO: should be `Literal[True]`
88-
reveal_type(is_protocol(SubProtocol)) # revealed: bool
122+
reveal_type(is_protocol(SubProtocol)) # revealed: Literal[True]
89123

90124
class OtherProtocol(Protocol):
91125
some_attribute: str
@@ -95,8 +129,7 @@ class ComplexInheritance(SubProtocol, OtherProtocol, Protocol): ...
95129
# revealed: tuple[Literal[ComplexInheritance], Literal[SubProtocol], Literal[MyProtocol], Literal[OtherProtocol], typing.Protocol, typing.Generic, Literal[object]]
96130
reveal_type(ComplexInheritance.__mro__)
97131

98-
# TODO: should be `Literal[True]`
99-
reveal_type(is_protocol(ComplexInheritance)) # revealed: bool
132+
reveal_type(is_protocol(ComplexInheritance)) # revealed: Literal[True]
100133
```
101134

102135
If `Protocol` is present in the bases tuple, all other bases in the tuple must be protocol classes,
@@ -134,6 +167,8 @@ reveal_type(Fine.__mro__) # revealed: tuple[Literal[Fine], typing.Protocol, typ
134167

135168
class StillFine(Protocol, Generic[T], object): ...
136169
class EvenThis[T](Protocol, object): ...
170+
class OrThis(Protocol[T], Generic[T]): ...
171+
class AndThis(Protocol[T], Generic[T], object): ...
137172
```
138173

139174
And multiple inheritance from a mix of protocol and non-protocol classes is fine as long as
@@ -150,8 +185,7 @@ But if `Protocol` is not present in the bases list, the resulting class doesn't
150185
class anymore:
151186

152187
```py
153-
# TODO: should reveal `Literal[False]`
154-
reveal_type(is_protocol(FineAndDandy)) # revealed: bool
188+
reveal_type(is_protocol(FineAndDandy)) # revealed: Literal[False]
155189
```
156190

157191
A class does not *have* to inherit from a protocol class in order for it to be considered a subtype
@@ -230,9 +264,10 @@ class Foo(typing.Protocol):
230264
class Bar(typing_extensions.Protocol):
231265
x: int
232266

233-
# TODO: these should pass
234-
static_assert(typing_extensions.is_protocol(Foo)) # error: [static-assert-error]
235-
static_assert(typing_extensions.is_protocol(Bar)) # error: [static-assert-error]
267+
static_assert(typing_extensions.is_protocol(Foo))
268+
static_assert(typing_extensions.is_protocol(Bar))
269+
270+
# TODO: should pass
236271
static_assert(is_equivalent_to(Foo, Bar)) # error: [static-assert-error]
237272
```
238273

@@ -247,9 +282,10 @@ class RuntimeCheckableFoo(typing.Protocol):
247282
class RuntimeCheckableBar(typing_extensions.Protocol):
248283
x: int
249284

250-
# TODO: these should pass
251-
static_assert(typing_extensions.is_protocol(RuntimeCheckableFoo)) # error: [static-assert-error]
252-
static_assert(typing_extensions.is_protocol(RuntimeCheckableBar)) # error: [static-assert-error]
285+
static_assert(typing_extensions.is_protocol(RuntimeCheckableFoo))
286+
static_assert(typing_extensions.is_protocol(RuntimeCheckableBar))
287+
288+
# TODO: should pass
253289
static_assert(is_equivalent_to(RuntimeCheckableFoo, RuntimeCheckableBar)) # error: [static-assert-error]
254290

255291
# These should not error because the protocols are decorated with `@runtime_checkable`

crates/red_knot_python_semantic/src/types/call/bind.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,15 @@ impl<'db> Bindings<'db> {
535535
}
536536
}
537537

538+
Some(KnownFunction::IsProtocol) => {
539+
if let [Some(ty)] = overload.parameter_types() {
540+
overload.set_return_type(Type::BooleanLiteral(
541+
ty.into_class_literal()
542+
.is_some_and(|class| class.is_protocol(db)),
543+
));
544+
}
545+
}
546+
538547
Some(KnownFunction::Overload) => {
539548
// TODO: This can be removed once we understand legacy generics because the
540549
// typeshed definition for `typing.overload` is an identity function.

crates/red_knot_python_semantic/src/types/class.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,17 @@ impl<'db> ClassLiteralType<'db> {
582582
.collect()
583583
}
584584

585+
/// Determine if this class is a protocol.
586+
pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool {
587+
self.explicit_bases(db).iter().any(|base| {
588+
matches!(
589+
base,
590+
Type::KnownInstance(KnownInstanceType::Protocol)
591+
| Type::Dynamic(DynamicType::SubscriptedProtocol)
592+
)
593+
})
594+
}
595+
585596
/// Return the types of the decorators on this class
586597
#[salsa::tracked(return_ref)]
587598
fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> {

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ use crate::types::generics::GenericContext;
8181
use crate::types::mro::MroErrorKind;
8282
use crate::types::unpacker::{UnpackResult, Unpacker};
8383
use crate::types::{
84-
todo_type, CallDunderError, CallableSignature, CallableType, Class, ClassLiteralType,
85-
ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, GenericAlias,
86-
GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
84+
binding_type, todo_type, CallDunderError, CallableSignature, CallableType, Class,
85+
ClassLiteralType, ClassType, DataclassMetadata, DynamicType, FunctionDecorators, FunctionType,
86+
GenericAlias, GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
8787
KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter,
8888
ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType,
8989
SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType,
@@ -1224,7 +1224,7 @@ impl<'db> TypeInferenceBuilder<'db> {
12241224

12251225
/// Returns `true` if the current scope is the function body scope of a method of a protocol
12261226
/// (that is, a class which directly inherits `typing.Protocol`.)
1227-
fn in_class_that_inherits_protocol_directly(&self) -> bool {
1227+
fn in_protocol_class(&self) -> bool {
12281228
let current_scope_id = self.scope().file_scope_id(self.db());
12291229
let current_scope = self.index.scope(current_scope_id);
12301230
let Some(parent_scope_id) = current_scope.parent() else {
@@ -1252,13 +1252,13 @@ impl<'db> TypeInferenceBuilder<'db> {
12521252
return false;
12531253
};
12541254

1255-
// TODO move this to `Class` once we add proper `Protocol` support
1256-
node_ref.bases().iter().any(|base| {
1257-
matches!(
1258-
self.file_expression_type(base),
1259-
Type::KnownInstance(KnownInstanceType::Protocol)
1260-
)
1261-
})
1255+
let class_definition = self.index.expect_single_definition(node_ref.node());
1256+
1257+
let Type::ClassLiteral(class) = binding_type(self.db(), class_definition) else {
1258+
return false;
1259+
};
1260+
1261+
class.is_protocol(self.db())
12621262
}
12631263

12641264
/// Returns `true` if the current scope is the function body scope of a function overload (that
@@ -1322,7 +1322,7 @@ impl<'db> TypeInferenceBuilder<'db> {
13221322

13231323
if (self.in_stub()
13241324
|| self.in_function_overload_or_abstractmethod()
1325-
|| self.in_class_that_inherits_protocol_directly())
1325+
|| self.in_protocol_class())
13261326
&& self.return_types_and_ranges.is_empty()
13271327
&& is_stub_suite(&function.body)
13281328
{
@@ -1625,7 +1625,7 @@ impl<'db> TypeInferenceBuilder<'db> {
16251625
}
16261626
} else if (self.in_stub()
16271627
|| self.in_function_overload_or_abstractmethod()
1628-
|| self.in_class_that_inherits_protocol_directly())
1628+
|| self.in_protocol_class())
16291629
&& default
16301630
.as_ref()
16311631
.is_some_and(|d| d.is_ellipsis_literal_expr())

0 commit comments

Comments
 (0)