Skip to content

Commit afb9ea0

Browse files
committed
[red-knot] Infer the members of a protocol class
1 parent 0d68ffd commit afb9ea0

File tree

3 files changed

+115
-22
lines changed

3 files changed

+115
-22
lines changed

crates/red_knot_python_semantic/resources/mdtest/protocols.md

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -368,16 +368,8 @@ class Foo(Protocol):
368368
def method_member(self) -> bytes:
369369
return b"foo"
370370

371-
# TODO: at runtime, `get_protocol_members` returns a `frozenset`,
372-
# but for now we might pretend it returns a `tuple`, as we support heterogeneous `tuple` types
373-
# but not yet generic `frozenset`s
374-
#
375-
# So this should either be
376-
#
377-
# `tuple[Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]`
378-
#
379-
# `frozenset[Literal["x", "y", "z", "method_member"]]`
380-
reveal_type(get_protocol_members(Foo)) # revealed: @Todo(specialized non-generic class)
371+
# TODO: actually a frozenset (requires support for legacy generics)
372+
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["method_member"], Literal["x"], Literal["y"], Literal["z"]]
381373
```
382374

383375
Calling `get_protocol_members` on a non-protocol class raises an error at runtime:
@@ -413,8 +405,8 @@ class Lumberjack(Protocol):
413405
def __init__(self, x: int) -> None:
414406
self.x = x
415407

416-
# TODO: `tuple[Literal["x"]]` or `frozenset[Literal["x"]]`
417-
reveal_type(get_protocol_members(Lumberjack)) # revealed: @Todo(specialized non-generic class)
408+
# TODO: actually a frozenset
409+
reveal_type(get_protocol_members(Lumberjack)) # revealed: tuple[Literal["x"]]
418410
```
419411

420412
A sub-protocol inherits and extends the members of its superclass protocol(s):
@@ -426,15 +418,14 @@ class Bar(Protocol):
426418
class Baz(Bar, Protocol):
427419
ham: memoryview
428420

429-
# TODO: `tuple[Literal["spam", "ham"]]` or `frozenset[Literal["spam", "ham"]]`
430-
reveal_type(get_protocol_members(Baz)) # revealed: @Todo(specialized non-generic class)
421+
# TODO: actually a frozenset
422+
reveal_type(get_protocol_members(Baz)) # revealed: tuple[Literal["ham"], Literal["spam"]]
431423

432424
class Baz2(Bar, Foo, Protocol): ...
433425

434-
# TODO: either
435-
# `tuple[Literal["spam"], Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]`
436-
# or `frozenset[Literal["spam", "x", "y", "z", "method_member"]]`
437-
reveal_type(get_protocol_members(Baz2)) # revealed: @Todo(specialized non-generic class)
426+
# TODO: actually a frozenset
427+
# revealed: tuple[Literal["method_member"], Literal["spam"], Literal["x"], Literal["y"], Literal["z"]]
428+
reveal_type(get_protocol_members(Baz2))
438429
```
439430

440431
## Subtyping of protocols with attribute members
@@ -630,9 +621,9 @@ class LotsOfBindings(Protocol):
630621
case l: # TODO: this should error with `[invalid-protocol]` (`l` is not declared)
631622
...
632623

633-
# TODO: all bindings in the above class should be understood as protocol members,
634-
# even those that we complained about with a diagnostic
635-
reveal_type(get_protocol_members(LotsOfBindings)) # revealed: @Todo(specialized non-generic class)
624+
# TODO: actually a frozenset
625+
# revealed: tuple[Literal["Nested"], Literal["NestedProtocol"], Literal["a"], Literal["b"], Literal["c"], Literal["d"], Literal["e"], Literal["f"], Literal["g"], Literal["h"], Literal["i"], Literal["j"], Literal["k"], Literal["l"]]
626+
reveal_type(get_protocol_members(LotsOfBindings))
636627
```
637628

638629
Attribute members are allowed to have assignments in methods on the protocol class, just like

crates/red_knot_python_semantic/src/types/call/bind.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::types::signatures::{Parameter, ParameterForm};
2121
use crate::types::{
2222
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, FunctionType,
2323
KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType,
24-
UnionType, WrapperDescriptorKind,
24+
TupleType, UnionType, WrapperDescriptorKind,
2525
};
2626
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
2727
use ruff_python_ast as ast;
@@ -568,6 +568,21 @@ impl<'db> Bindings<'db> {
568568
}
569569
}
570570

571+
Some(KnownFunction::GetProtocolMembers) => {
572+
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() {
573+
if let Some(protocol_class) = class.into_protocol_class(db) {
574+
overload.set_return_type(Type::Tuple(TupleType::new(
575+
db,
576+
protocol_class
577+
.members(db)
578+
.iter()
579+
.map(|member| Type::string_literal(db, member))
580+
.collect::<Box<[Type<'db>]>>(),
581+
)));
582+
}
583+
}
584+
}
585+
571586
Some(KnownFunction::Overload) => {
572587
// TODO: This can be removed once we understand legacy generics because the
573588
// typeshed definition for `typing.overload` is an identity function.

crates/red_knot_python_semantic/src/types/class.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::hash::BuildHasherDefault;
2+
use std::ops::Deref;
23
use std::sync::{LazyLock, Mutex};
34

45
use super::{
@@ -13,6 +14,7 @@ use crate::types::signatures::{Parameter, Parameters};
1314
use crate::types::{
1415
CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, Signature,
1516
};
17+
use crate::FxOrderSet;
1618
use crate::{
1719
module_resolver::file_to_module,
1820
semantic_index::{
@@ -1665,6 +1667,11 @@ impl<'db> ClassLiteralType<'db> {
16651667
Some(InheritanceCycle::Inherited)
16661668
}
16671669
}
1670+
1671+
/// Returns `Some` if this is a protocol class, `None` otherwise.
1672+
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
1673+
self.is_protocol(db).then_some(ProtocolClassLiteral(self))
1674+
}
16681675
}
16691676

16701677
impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
@@ -1676,6 +1683,86 @@ impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
16761683
}
16771684
}
16781685

1686+
/// Representation of a single `Protocol` class definition.
1687+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
1688+
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteralType<'db>);
1689+
1690+
impl<'db> ProtocolClassLiteral<'db> {
1691+
pub(super) fn members(self, db: &'db dyn Db) -> &'db ordermap::set::Slice<Name> {
1692+
/// The list of excluded members is subject to change between Python versions,
1693+
/// especially for dunders, but it probably doesn't matter *too* much if this
1694+
/// list goes out of date. It's up-to-date as of Python commit 87b1ea016b1454b1e83b9113fa9435849b7743aa
1695+
/// (<https://github.com/python/cpython/blob/87b1ea016b1454b1e83b9113fa9435849b7743aa/Lib/typing.py#L1776-L1791>)
1696+
fn excluded_from_proto_members(member: &str) -> bool {
1697+
matches!(
1698+
member,
1699+
"_is_protocol"
1700+
| "__non_callable_proto_members__"
1701+
| "__static_attributes__"
1702+
| "__orig_class__"
1703+
| "__match_args__"
1704+
| "__weakref__"
1705+
| "__doc__"
1706+
| "__parameters__"
1707+
| "__module__"
1708+
| "_MutableMapping__marker"
1709+
| "__slots__"
1710+
| "__dict__"
1711+
| "__new__"
1712+
| "__protocol_attrs__"
1713+
| "__init__"
1714+
| "__class_getitem__"
1715+
| "__firstlineno__"
1716+
| "__abstractmethods__"
1717+
| "__orig_bases__"
1718+
| "_is_runtime_protocol"
1719+
| "__subclasshook__"
1720+
| "__type_params__"
1721+
| "__annotations__"
1722+
| "__annotate__"
1723+
| "__annotate_func__"
1724+
| "__annotations_cache__"
1725+
)
1726+
}
1727+
1728+
#[salsa::tracked(return_ref)]
1729+
fn cached_members<'db>(
1730+
db: &'db dyn Db,
1731+
class: ClassLiteralType<'db>,
1732+
) -> Box<ordermap::set::Slice<Name>> {
1733+
let mut members = FxOrderSet::default();
1734+
1735+
for parent_protocol in class
1736+
.iter_mro(db, None)
1737+
.filter_map(ClassBase::into_class)
1738+
.filter_map(|class| class.class_literal(db).0.into_protocol_class(db))
1739+
{
1740+
members.extend(
1741+
symbol_table(db, parent_protocol.body_scope(db))
1742+
.symbols()
1743+
.filter(|symbol| symbol.is_bound() || symbol.is_declared())
1744+
.map(crate::semantic_index::symbol::Symbol::name)
1745+
.filter(|name| !excluded_from_proto_members(name))
1746+
.cloned(),
1747+
);
1748+
}
1749+
1750+
members.sort();
1751+
members.into_boxed_slice()
1752+
}
1753+
1754+
cached_members(db, *self)
1755+
}
1756+
}
1757+
1758+
impl<'db> Deref for ProtocolClassLiteral<'db> {
1759+
type Target = ClassLiteralType<'db>;
1760+
1761+
fn deref(&self) -> &Self::Target {
1762+
&self.0
1763+
}
1764+
}
1765+
16791766
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
16801767
pub(super) enum InheritanceCycle {
16811768
/// The class is cyclically defined and is a participant in the cycle.

0 commit comments

Comments
 (0)