Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 52 additions & 26 deletions crates/red_knot_python_semantic/resources/mdtest/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ reveal_type(Protocol()) # revealed: Unknown
class MyProtocol(Protocol):
x: int

# error
# TODO: should emit error
reveal_type(MyProtocol()) # revealed: MyProtocol
```

Expand Down Expand Up @@ -363,16 +363,8 @@ class Foo(Protocol):
def method_member(self) -> bytes:
return b"foo"

# TODO: at runtime, `get_protocol_members` returns a `frozenset`,
# but for now we might pretend it returns a `tuple`, as we support heterogeneous `tuple` types
# but not yet generic `frozenset`s
#
# So this should either be
#
# `tuple[Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]`
#
# `frozenset[Literal["x", "y", "z", "method_member"]]`
reveal_type(get_protocol_members(Foo)) # revealed: @Todo(specialized non-generic class)
# TODO: actually a frozenset (requires support for legacy generics)
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["method_member"], Literal["x"], Literal["y"], Literal["z"]]
```

Certain special attributes and methods are not considered protocol members at runtime, and should
Expand All @@ -390,8 +382,8 @@ class Lumberjack(Protocol):
def __init__(self, x: int) -> None:
self.x = x

# TODO: `tuple[Literal["x"]]` or `frozenset[Literal["x"]]`
reveal_type(get_protocol_members(Lumberjack)) # revealed: @Todo(specialized non-generic class)
# TODO: actually a frozenset
reveal_type(get_protocol_members(Lumberjack)) # revealed: tuple[Literal["x"]]
```

A sub-protocol inherits and extends the members of its superclass protocol(s):
Expand All @@ -403,15 +395,42 @@ class Bar(Protocol):
class Baz(Bar, Protocol):
ham: memoryview

# TODO: `tuple[Literal["spam", "ham"]]` or `frozenset[Literal["spam", "ham"]]`
reveal_type(get_protocol_members(Baz)) # revealed: @Todo(specialized non-generic class)
# TODO: actually a frozenset
reveal_type(get_protocol_members(Baz)) # revealed: tuple[Literal["ham"], Literal["spam"]]

class Baz2(Bar, Foo, Protocol): ...

# TODO: either
# `tuple[Literal["spam"], Literal["x"], Literal["y"], Literal["z"], Literal["method_member"]]`
# or `frozenset[Literal["spam", "x", "y", "z", "method_member"]]`
reveal_type(get_protocol_members(Baz2)) # revealed: @Todo(specialized non-generic class)
# TODO: actually a frozenset
# revealed: tuple[Literal["method_member"], Literal["spam"], Literal["x"], Literal["y"], Literal["z"]]
reveal_type(get_protocol_members(Baz2))
```

## Protocol members in statically known branches

The list of protocol members does not include any members declared in branches that are statically
known to be unreachable:

```toml
[environment]
python-version = "3.9"
```

```py
import sys
from typing_extensions import Protocol, get_protocol_members

class Foo(Protocol):
if sys.version_info >= (3, 10):
a: int
b = 42
def c(self) -> None: ...
else:
d: int
e = 56
def f(self) -> None: ...

# TODO: actually a frozenset
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["d"], Literal["e"], Literal["f"]]
```

## Invalid calls to `get_protocol_members()`
Expand Down Expand Up @@ -639,14 +658,14 @@ class LotsOfBindings(Protocol):
case l: # TODO: this should error with `[invalid-protocol]` (`l` is not declared)
...

# TODO: all bindings in the above class should be understood as protocol members,
# even those that we complained about with a diagnostic
reveal_type(get_protocol_members(LotsOfBindings)) # revealed: @Todo(specialized non-generic class)
# TODO: actually a frozenset
# revealed: tuple[Literal["Nested"], Literal["NestedProtocol"], Literal["a"], Literal["b"], Literal["c"], Literal["d"], Literal["e"], Literal["f"], Literal["g"], Literal["h"], Literal["i"], Literal["j"], Literal["k"], Literal["l"]]
reveal_type(get_protocol_members(LotsOfBindings))
```

Attribute members are allowed to have assignments in methods on the protocol class, just like
non-protocol classes. Unlike other classes, however, *implicit* instance attributes -- those that
are not declared in the class body -- are not allowed:
non-protocol classes. Unlike other classes, however, instance attributes that are not declared in
the class body are disallowed:

```py
class Foo(Protocol):
Expand All @@ -655,11 +674,18 @@ class Foo(Protocol):

def __init__(self) -> None:
self.x = 42 # fine
self.a = 56 # error
self.a = 56 # TODO: should emit diagnostic
self.b: int = 128 # TODO: should emit diagnostic

def non_init_method(self) -> None:
self.y = 64 # fine
self.b = 72 # error
self.c = 72 # TODO: should emit diagnostic

# Note: the list of members does not include `a`, `b` or `c`,
# as none of these attributes is declared in the class body.
#
# TODO: actually a frozenset
reveal_type(get_protocol_members(Foo)) # revealed: tuple[Literal["non_init_method"], Literal["x"], Literal["y"]]
```

If a protocol has 0 members, then all other types are assignable to it, and all fully static types
Expand Down
9 changes: 9 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/use_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,15 @@ impl<'db> UseDefMap<'db> {
.map(|symbol_id| (symbol_id, self.public_declarations(symbol_id)))
}

/// Returns an iterator that visits every index in `0..self.public_symbols.len()`,
/// yielding the corresponding [`ScopedSymbolId`] together with the result of
/// [`Self::public_bindings`] for that symbol.
pub(crate) fn all_public_bindings<'map>(
    &'map self,
) -> impl Iterator<Item = (ScopedSymbolId, BindingWithConstraintsIterator<'map, 'db>)> + 'map
{
    let symbol_count = self.public_symbols.len();
    (0..symbol_count).map(move |index| {
        let symbol_id = ScopedSymbolId::from_usize(index);
        (symbol_id, self.public_bindings(symbol_id))
    })
}

/// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`.
pub(crate) fn can_implicit_return(&self, db: &dyn crate::Db) -> bool {
!self
Expand Down
20 changes: 18 additions & 2 deletions crates/red_knot_python_semantic/src/types/call/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use crate::types::generics::{Specialization, SpecializationBuilder};
use crate::types::signatures::{Parameter, ParameterForm};
use crate::types::{
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass,
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType,
WrapperDescriptorKind,
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, TupleType,
UnionType, WrapperDescriptorKind,
};
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
use ruff_python_ast as ast;
Expand Down Expand Up @@ -561,6 +561,22 @@ impl<'db> Bindings<'db> {
}
}

Some(KnownFunction::GetProtocolMembers) => {
if let [Some(Type::ClassLiteral(class))] = overload.parameter_types() {
if let Some(protocol_class) = class.into_protocol_class(db) {
// TODO: actually a frozenset at runtime (requires support for legacy generic classes)
overload.set_return_type(Type::Tuple(TupleType::new(
db,
protocol_class
.protocol_members(db)
.iter()
.map(|member| Type::string_literal(db, member))
.collect::<Box<[Type<'db>]>>(),
)));
}
}
}

Some(KnownFunction::Overload) => {
// TODO: This can be removed once we understand legacy generics because the
// typeshed definition for `typing.overload` is an identity function.
Expand Down
126 changes: 126 additions & 0 deletions crates/red_knot_python_semantic/src/types/class.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::hash::BuildHasherDefault;
use std::ops::Deref;
use std::sync::{LazyLock, Mutex};

use super::{
Expand All @@ -13,6 +14,7 @@ use crate::types::signatures::{Parameter, Parameters};
use crate::types::{
CallableType, DataclassParams, DataclassTransformerParams, KnownInstanceType, Signature,
};
use crate::FxOrderSet;
use crate::{
module_resolver::file_to_module,
semantic_index::{
Expand Down Expand Up @@ -1710,6 +1712,11 @@ impl<'db> ClassLiteralType<'db> {
Some(InheritanceCycle::Inherited)
}
}

/// Wraps this class literal in a [`ProtocolClassLiteral`] when `is_protocol` reports
/// that it is a protocol class; every other class yields `None`.
pub(super) fn into_protocol_class(self, db: &'db dyn Db) -> Option<ProtocolClassLiteral<'db>> {
    if self.is_protocol(db) {
        Some(ProtocolClassLiteral(self))
    } else {
        None
    }
}
}

impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
Expand All @@ -1721,6 +1728,125 @@ impl<'db> From<ClassLiteralType<'db>> for Type<'db> {
}
}

/// Representation of a single `Protocol` class definition.
///
/// This is a thin newtype around a [`ClassLiteralType`]. Its field is private, and in the
/// code shown here values are produced by `ClassLiteralType::into_protocol_class`, which
/// only returns `Some` when `is_protocol` is `true` for the wrapped class.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) struct ProtocolClassLiteral<'db>(ClassLiteralType<'db>);

impl<'db> ProtocolClassLiteral<'db> {
    /// Returns the sorted set of member names of this protocol class.
    ///
    /// The members of a protocol make up the interface the protocol declares, so they
    /// determine how the protocol behaves with regard to assignability and subtyping.
    ///
    /// Members are collected from the class body of every protocol class encountered while
    /// iterating this class's MRO. Both declarations and bindings in those class bodies
    /// count, minus a fixed list of excluded attribute names. (That list contains, among
    /// others, `__init__` and `__new__`: they may legally be defined on a protocol class
    /// but are not protocol members.)
    ///
    /// A protocol class may not have instance attributes that are not declared in its class
    /// body. Should such attributes be assigned to anyway, they do not show up in the
    /// returned set.
    pub(super) fn protocol_members(self, db: &'db dyn Db) -> &'db ordermap::set::Slice<Name> {
        /// Reports whether `member` is one of the names that never count as protocol members.
        ///
        /// The set of excluded names can change between Python versions (dunders in
        /// particular), but it is probably not a big problem if this list becomes slightly
        /// stale. It matches Python commit 87b1ea016b1454b1e83b9113fa9435849b7743aa
        /// (<https://github.com/python/cpython/blob/87b1ea016b1454b1e83b9113fa9435849b7743aa/Lib/typing.py#L1776-L1791>)
        fn excluded_from_proto_members(member: &str) -> bool {
            const EXCLUDED_NAMES: &[&str] = &[
                "_is_protocol",
                "__non_callable_proto_members__",
                "__static_attributes__",
                "__orig_class__",
                "__match_args__",
                "__weakref__",
                "__doc__",
                "__parameters__",
                "__module__",
                "_MutableMapping__marker",
                "__slots__",
                "__dict__",
                "__new__",
                "__protocol_attrs__",
                "__init__",
                "__class_getitem__",
                "__firstlineno__",
                "__abstractmethods__",
                "__orig_bases__",
                "_is_runtime_protocol",
                "__subclasshook__",
                "__type_params__",
                "__annotations__",
                "__annotate__",
                "__annotate_func__",
                "__annotations_cache__",
            ];

            EXCLUDED_NAMES.contains(&member)
        }

        /// Salsa-tracked worker that builds the member set for `class`; the result is
        /// handed back by reference (`return_ref`).
        #[salsa::tracked(return_ref)]
        fn cached_protocol_members<'db>(
            db: &'db dyn Db,
            class: ClassLiteralType<'db>,
        ) -> Box<ordermap::set::Slice<Name>> {
            let mut member_names = FxOrderSet::default();

            // Keep only those MRO entries that are classes *and* protocol classes.
            let protocol_classes = class.iter_mro(db, None).filter_map(|base| {
                base.into_class()?
                    .class_literal(db)
                    .0
                    .into_protocol_class(db)
            });

            for protocol in protocol_classes {
                let body_scope = protocol.body_scope(db);
                let use_defs = use_def_map(db, body_scope);
                let symbols = symbol_table(db, body_scope);

                // Symbols that have a declaration in the class body.
                let declared_ids = use_defs
                    .all_public_declarations()
                    .flat_map(|(symbol_id, declarations)| {
                        symbol_from_declarations(db, declarations)
                            .map(|declared| (symbol_id, declared))
                    })
                    .filter_map(|(symbol_id, declared)| {
                        declared
                            .symbol
                            .ignore_possibly_unbound()
                            .map(|_| symbol_id)
                    });

                // Symbols that are bound in the class body without being declared there are
                // not valid protocol members, and diagnostics for them are planned elsewhere.
                // They must nevertheless be treated as protocol members: at runtime,
                // `issubclass()` and `isinstance()` on runtime-checkable protocols do treat
                // them as members, and type narrowing through those calls has to be
                // modelled accurately.
                let bound_ids =
                    use_defs
                        .all_public_bindings()
                        .filter_map(|(symbol_id, bindings)| {
                            symbol_from_bindings(db, bindings)
                                .ignore_possibly_unbound()
                                .map(|_| symbol_id)
                        });

                member_names.extend(
                    declared_ids
                        .chain(bound_ids)
                        .map(|symbol_id| symbols.symbol(symbol_id).name())
                        .filter(|name| !excluded_from_proto_members(name))
                        .cloned(),
                );
            }

            member_names.sort();
            member_names.into_boxed_slice()
        }

        cached_protocol_members(db, self.0)
    }
}

impl<'db> Deref for ProtocolClassLiteral<'db> {
type Target = ClassLiteralType<'db>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub(super) enum InheritanceCycle {
/// The class is cyclically defined and is a participant in the cycle.
Expand Down
Loading