Skip to content

Commit 00e73dc

Browse files
authored
[red-knot] Infer the members of a protocol class (#17556)
1 parent 7b62227 commit 00e73dc

File tree

4 files changed

+205
-28
lines changed

4 files changed

+205
-28
lines changed

crates/red_knot_python_semantic/resources/mdtest/protocols.md

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ reveal_type(Protocol()) # revealed: Unknown
315315
class MyProtocol(Protocol):
316316
x: int
317317

318-
# error
318+
# TODO: should emit error
319319
reveal_type(MyProtocol()) # revealed: MyProtocol
320320
```
321321

@@ -363,16 +363,8 @@ class Foo(Protocol):
363363
def method_member(self) -> bytes:
364364
return b"foo"
365365

366-
# TODO: at runtime, `get_protocol_members` returns a `frozenset`,
367-
# but for now we might pretend it returns a `tuple`, as we support heterogeneous `tuple` types
368-
# but not yet generic `frozenset`s
369-
#
370-
# So this should either be
371-
#
372-
# `tuple[Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]`
373-
#
374-
# `frozenset[Literal["x", "y", "z", "method_member"]]`
375-
reveal_type(get_protocol_members(Foo)) # revealed: @Todo(specialized non-generic class)
366+
# TODO: actually a frozenset (requires support for legacy generics)
367+
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["method_member"], Literal["x"], Literal["y"], Literal["z"]]
376368
```
377369

378370
Certain special attributes and methods are not considered protocol members at runtime, and should
@@ -390,8 +382,8 @@ class Lumberjack(Protocol):
390382
def __init__(self, x: int) -> None:
391383
self.x = x
392384

393-
# TODO: `tuple[Literal["x"]]` or `frozenset[Literal["x"]]`
394-
reveal_type(get_protocol_members(Lumberjack)) # revealed: @Todo(specialized non-generic class)
385+
# TODO: actually a frozenset
386+
reveal_type(get_protocol_members(Lumberjack)) # revealed: tuple[Literal["x"]]
395387
```
396388

397389
A sub-protocol inherits and extends the members of its superclass protocol(s):
@@ -403,15 +395,42 @@ class Bar(Protocol):
403395
class Baz(Bar, Protocol):
404396
ham: memoryview
405397

406-
# TODO: `tuple[Literal["spam", "ham"]]` or `frozenset[Literal["spam", "ham"]]`
407-
reveal_type(get_protocol_members(Baz)) # revealed: @Todo(specialized non-generic class)
398+
# TODO: actually a frozenset
399+
reveal_type(get_protocol_members(Baz)) # revealed: tuple[Literal["ham"], Literal["spam"]]
408400

409401
class Baz2(Bar, Foo, Protocol): ...
410402

411-
# TODO: either
412-
# `tuple[Literal["spam"], Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]`
413-
# or `frozenset[Literal["spam", "x", "y", "z", "method_member"]]`
414-
reveal_type(get_protocol_members(Baz2)) # revealed: @Todo(specialized non-generic class)
403+
# TODO: actually a frozenset
404+
# revealed: tuple[Literal["method_member"], Literal["spam"], Literal["x"], Literal["y"], Literal["z"]]
405+
reveal_type(get_protocol_members(Baz2))
406+
```
407+
408+
## Protocol members in statically known branches
409+
410+
The list of protocol members does not include any members declared in branches that are statically
411+
known to be unreachable:
412+
413+
```toml
414+
[environment]
415+
python-version = "3.9"
416+
```
417+
418+
```py
419+
import sys
420+
from typing_extensions import Protocol, get_protocol_members
421+
422+
class Foo(Protocol):
423+
if sys.version_info >= (3, 10):
424+
a: int
425+
b = 42
426+
def c(self) -> None: ...
427+
else:
428+
d: int
429+
e = 56
430+
def f(self) -> None: ...
431+
432+
# TODO: actually a frozenset
433+
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["d"], Literal["e"], Literal["f"]]
415434
```
416435

417436
## Invalid calls to `get_protocol_members()`
@@ -639,14 +658,14 @@ class LotsOfBindings(Protocol):
639658
case l: # TODO: this should error with `[invalid-protocol]` (`l` is not declared)
640659
...
641660

642-
# TODO: all bindings in the above class should be understood as protocol members,
643-
# even those that we complained about with a diagnostic
644-
reveal_type(get_protocol_members(LotsOfBindings)) # revealed: @Todo(specialized non-generic class)
661+
# TODO: actually a frozenset
662+
# revealed: tuple[Literal["Nested"], Literal["NestedProtocol"], Literal["a"], Literal["b"], Literal["c"], Literal["d"], Literal["e"], Literal["f"], Literal["g"], Literal["h"], Literal["i"], Literal["j"], Literal["k"], Literal["l"]]
663+
reveal_type(get_protocol_members(LotsOfBindings))
645664
```
646665

647666
Attribute members are allowed to have assignments in methods on the protocol class, just like
648-
non-protocol classes. Unlike other classes, however, *implicit* instance attributes -- those that
649-
are not declared in the class body -- are not allowed:
667+
non-protocol classes. Unlike other classes, however, instance attributes that are not declared in
668+
the class body are disallowed:
650669

651670
```py
652671
class Foo(Protocol):
@@ -655,11 +674,18 @@ class Foo(Protocol):
655674

656675
def __init__(self) -> None:
657676
self.x = 42 # fine
658-
self.a = 56 # error
677+
self.a = 56 # TODO: should emit diagnostic
678+
self.b: int = 128 # TODO: should emit diagnostic
659679

660680
def non_init_method(self) -> None:
661681
self.y = 64 # fine
662-
self.b = 72 # error
682+
self.c = 72 # TODO: should emit diagnostic
683+
684+
# Note: the list of members does not include `a`, `b` or `c`,
685+
# as none of these attributes is declared in the class body.
686+
#
687+
# TODO: actually a frozenset
688+
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["non_init_method"], Literal["x"], Literal["y"]]
663689
```
664690

665691
If a protocol has 0 members, then all other types are assignable to it, and all fully static types

crates/red_knot_python_semantic/src/semantic_index/use_def.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,15 @@ impl<'db> UseDefMap<'db> {
437437
.map(|symbol_id| (symbol_id, self.public_declarations(symbol_id)))
438438
}
439439

440+
pub(crate) fn all_public_bindings<'map>(
441+
&'map self,
442+
) -> impl Iterator<Item = (ScopedSymbolId, BindingWithConstraintsIterator<'map, 'db>)> + 'map
443+
{
444+
(0..self.public_symbols.len())
445+
.map(ScopedSymbolId::from_usize)
446+
.map(|symbol_id| (symbol_id, self.public_bindings(symbol_id)))
447+
}
448+
440449
/// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`.
441450
pub(crate) fn can_implicit_return(&self, db: &dyn crate::Db) -> bool {
442451
!self

crates/red_knot_python_semantic/src/types/call/bind.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ use crate::types::generics::{Specialization, SpecializationBuilder};
2020
use crate::types::signatures::{Parameter, ParameterForm};
2121
use crate::types::{
2222
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass,
23-
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType,
24-
WrapperDescriptorKind,
23+
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, TupleType,
24+
UnionType, WrapperDescriptorKind,
2525
};
2626
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
2727
use ruff_python_ast as ast;
@@ -561,6 +561,22 @@ impl<'db> Bindings<'db> {
561561
}
562562
}
563563

564+
Some(KnownFunction::GetProtocolMembers) => {
565+
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() {
566+
if let Some(protocol_class) = class.into_protocol_class(db) {
567+
// TODO: actually a frozenset at runtime (requires support for legacy generic classes)
568+
overload.set_return_type(Type::Tuple(TupleType::new(
569+
db,
570+
protocol_class
571+
.protocol_members(db)
572+
.iter()
573+
.map(|member| Type::string_literal(db, member))
574+
.collect::<Box<[Type<'db>]>>(),
575+
)));
576+
}
577+
}
578+
}
579+
564580
Some(KnownFunction::Overload) => {
565581
// TODO: This can be removed once we understand legacy generics because the
566582
// typeshed definition for `typing.overload` is an identity function.

crates/red_knot_python_semantic/src/types/class.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::hash::BuildHasherDefault;
2+
use std::ops::Deref;
23
use std::sync::{LazyLock, Mutex};
34

45
use super::{
@@ -13,6 +14,7 @@ use crate::types::signatures::{Parameter, Parameters};
1314
use crate::types::{
1415
CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, Signature,
1516
};
17+
use crate::FxOrderSet;
1618
use crate::{
1719
module_resolver::file_to_module,
1820
semantic_index::{
@@ -1710,6 +1712,11 @@ impl<'db> ClassLiteralType<'db> {
17101712
Some(InheritanceCycle::Inherited)
17111713
}
17121714
}
1715+
1716+
/// Returns `Some` if this is a protocol class, `None` otherwise.
1717+
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
1718+
self.is_protocol(db).then_some(ProtocolClassLiteral(self))
1719+
}
17131720
}
17141721

17151722
impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
@@ -1721,6 +1728,125 @@ impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
17211728
}
17221729
}
17231730

1731+
/// Representation of a single `Protocol` class definition.
1732+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
1733+
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteralType<'db>);
1734+
1735+
impl<'db> ProtocolClassLiteral<'db> {
1736+
/// Returns the protocol members of this class.
1737+
///
1738+
/// A protocol's members define the interface declared by the protocol.
1739+
/// They therefore determine how the protocol should behave with regards to
1740+
/// assignability and subtyping.
1741+
///
1742+
/// The list of members consists of all bindings and declarations that take place
1743+
/// in the protocol's class body, except for a list of excluded attributes which should
1744+
/// not be taken into account. (This list includes `__init__` and `__new__`, which can
1745+
/// legally be defined on protocol classes but do not constitute protocol members.)
1746+
///
1747+
/// It is illegal for a protocol class to have any instance attributes that are not declared
1748+
/// in the protocol's class body. If any are assigned to, they are not taken into account in
1749+
/// the protocol's list of members.
1750+
pub(super) fn protocol_members(self, db: &'db dyn Db) -> &'db ordermap::set::Slice<Name> {
1751+
/// The list of excluded members is subject to change between Python versions,
1752+
/// especially for dunders, but it probably doesn't matter *too* much if this
1753+
/// list goes out of date. It's up to date as of Python commit 87b1ea016b1454b1e83b9113fa9435849b7743aa
1754+
/// (<https://github.com/python/cpython/blob/87b1ea016b1454b1e83b9113fa9435849b7743aa/Lib/typing.py#L1776-L1791>)
1755+
fn excluded_from_proto_members(member: &str) -> bool {
1756+
matches!(
1757+
member,
1758+
"_is_protocol"
1759+
| "__non_callable_proto_members__"
1760+
| "__static_attributes__"
1761+
| "__orig_class__"
1762+
| "__match_args__"
1763+
| "__weakref__"
1764+
| "__doc__"
1765+
| "__parameters__"
1766+
| "__module__"
1767+
| "_MutableMapping__marker"
1768+
| "__slots__"
1769+
| "__dict__"
1770+
| "__new__"
1771+
| "__protocol_attrs__"
1772+
| "__init__"
1773+
| "__class_getitem__"
1774+
| "__firstlineno__"
1775+
| "__abstractmethods__"
1776+
| "__orig_bases__"
1777+
| "_is_runtime_protocol"
1778+
| "__subclasshook__"
1779+
| "__type_params__"
1780+
| "__annotations__"
1781+
| "__annotate__"
1782+
| "__annotate_func__"
1783+
| "__annotations_cache__"
1784+
)
1785+
}
1786+
1787+
#[salsa::tracked(return_ref)]
1788+
fn cached_protocol_members<'db>(
1789+
db: &'db dyn Db,
1790+
class: ClassLiteralType<'db>,
1791+
) -> Box<ordermap::set::Slice<Name>> {
1792+
let mut members = FxOrderSet::default();
1793+
1794+
for parent_protocol in class
1795+
.iter_mro(db, None)
1796+
.filter_map(ClassBase::into_class)
1797+
.filter_map(|class| class.class_literal(db).0.into_protocol_class(db))
1798+
{
1799+
let parent_scope = parent_protocol.body_scope(db);
1800+
let use_def_map = use_def_map(db, parent_scope);
1801+
let symbol_table = symbol_table(db, parent_scope);
1802+
1803+
members.extend(
1804+
use_def_map
1805+
.all_public_declarations()
1806+
.flat_map(|(symbol_id, declarations)| {
1807+
symbol_from_declarations(db, declarations)
1808+
.map(|symbol| (symbol_id, symbol))
1809+
})
1810+
.filter_map(|(symbol_id, symbol)| {
1811+
symbol.symbol.ignore_possibly_unbound().map(|_| symbol_id)
1812+
})
1813+
// Bindings in the class body that are not declared in the class body
1814+
// are not valid protocol members, and we plan to emit diagnostics for them
1815+
// elsewhere. Invalid or not, however, it's important that we still consider
1816+
// them to be protocol members. The implementation of `issubclass()` and
1817+
// `isinstance()` for runtime-checkable protocols considers them to be protocol
1818+
// members at runtime, and it's important that we accurately understand
1819+
// type narrowing that uses `isinstance()` or `issubclass()` with
1820+
// runtime-checkable protocols.
1821+
.chain(use_def_map.all_public_bindings().filter_map(
1822+
|(symbol_id, bindings)| {
1823+
symbol_from_bindings(db, bindings)
1824+
.ignore_possibly_unbound()
1825+
.map(|_| symbol_id)
1826+
},
1827+
))
1828+
.map(|symbol_id| symbol_table.symbol(symbol_id).name())
1829+
.filter(|name| !excluded_from_proto_members(name))
1830+
.cloned(),
1831+
);
1832+
}
1833+
1834+
members.sort();
1835+
members.into_boxed_slice()
1836+
}
1837+
1838+
cached_protocol_members(db, *self)
1839+
}
1840+
}
1841+
1842+
impl<'db> Deref for ProtocolClassLiteral<'db> {
1843+
type Target = ClassLiteralType<'db>;
1844+
1845+
fn deref(&self) -> &Self::Target {
1846+
&self.0
1847+
}
1848+
}
1849+
17241850
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
17251851
pub(super) enum InheritanceCycle {
17261852
/// The class is cyclically defined and is a participant in the cycle.

0 commit comments

Comments
 (0)