Skip to content

Commit b3a1538

Browse files
committed
[red-knot] Emit a diagnostic if a non-protocol is passed to get_protocol_members
1 parent abc9ff6 commit b3a1538

File tree

4 files changed

+137
-12
lines changed

4 files changed

+137
-12
lines changed

crates/red_knot_python_semantic/resources/mdtest/protocols.md

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -375,15 +375,6 @@ class Foo(Protocol):
375375
reveal_type(get_protocol_members(Foo)) # revealed: @Todo(specialized non-generic class)
376376
```
377377

378-
Calling `get_protocol_members` on a non-protocol class raises an error at runtime:
379-
380-
```py
381-
class NotAProtocol: ...
382-
383-
# TODO: should emit `[invalid-protocol]` error, should reveal `Unknown`
384-
reveal_type(get_protocol_members(NotAProtocol)) # revealed: @Todo(specialized non-generic class)
385-
```
386-
387378
Certain special attributes and methods are not considered protocol members at runtime, and should
388379
not be considered protocol members by type checkers either:
389380

@@ -423,6 +414,34 @@ class Baz2(Bar, Foo, Protocol): ...
423414
reveal_type(get_protocol_members(Baz2)) # revealed: @Todo(specialized non-generic class)
424415
```
425416

417+
## Invalid calls to `get_protocol_members()`
418+
419+
<!-- snapshot-diagnostics -->
420+
421+
Calling `get_protocol_members` on a non-protocol class raises an error at runtime:
422+
423+
```toml
424+
[environment]
425+
python-version = "3.12"
426+
```
427+
428+
```py
429+
from typing_extensions import Protocol, get_protocol_members
430+
431+
class NotAProtocol: ...
432+
433+
get_protocol_members(NotAProtocol) # error: [invalid-argument-type]
434+
```
435+
436+
The original class object must be passed to the function; a specialised version of a generic version
437+
does not suffice:
438+
439+
```py
440+
class GenericProtocol[T](Protocol): ...
441+
442+
get_protocol_members(GenericProtocol[int]) # TODO: should emit a diagnostic here (https://github.com/astral-sh/ruff/issues/17549)
443+
```
444+
426445
## Subtyping of protocols with attribute members
427446

428447
In the following example, the protocol class `HasX` defines an interface such that any other fully
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
---
2+
source: crates/red_knot_test/src/lib.rs
3+
expression: snapshot
4+
---
5+
---
6+
mdtest name: protocols.md - Protocols - Invalid calls to `get_protocol_members()`
7+
mdtest path: crates/red_knot_python_semantic/resources/mdtest/protocols.md
8+
---
9+
10+
# Python source files
11+
12+
## mdtest_snippet.py
13+
14+
```
15+
1 | from typing_extensions import Protocol, get_protocol_members
16+
2 |
17+
3 | class NotAProtocol: ...
18+
4 |
19+
5 | get_protocol_members(NotAProtocol) # error: [invalid-argument-type]
20+
6 | class GenericProtocol[T](Protocol): ...
21+
7 |
22+
8 | get_protocol_members(GenericProtocol[int]) # TODO: should emit a diagnostic here (https://github.com/astral-sh/ruff/issues/17549)
23+
```
24+
25+
# Diagnostics
26+
27+
```
28+
error: lint:invalid-argument-type: Invalid argument to `get_protocol_members`
29+
--> /src/mdtest_snippet.py:5:1
30+
|
31+
3 | class NotAProtocol: ...
32+
4 |
33+
5 | get_protocol_members(NotAProtocol) # error: [invalid-argument-type]
34+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ This call will raise `TypeError` at runtime
35+
6 | class GenericProtocol[T](Protocol): ...
36+
|
37+
info: Only protocol classes can be passed to `get_protocol_members`
38+
info: `NotAProtocol` is declared here, but it is not a protocol class:
39+
--> /src/mdtest_snippet.py:3:7
40+
|
41+
1 | from typing_extensions import Protocol, get_protocol_members
42+
2 |
43+
3 | class NotAProtocol: ...
44+
| ^^^^^^^^^^^^
45+
4 |
46+
5 | get_protocol_members(NotAProtocol) # error: [invalid-argument-type]
47+
|
48+
info: A class is only a protocol class if it explicitly inherits from `typing.Protocol` or `typing_extensions.Protocol`
49+
info: See https://typing.python.org/en/latest/spec/protocol.html#
50+
51+
```

crates/red_knot_python_semantic/src/types/diagnostic.rs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use super::context::InferContext;
2+
use super::ClassLiteralType;
23
use crate::declare_lint;
34
use crate::lint::{Level, LintRegistryBuilder, LintStatus};
45
use crate::suppression::FileSuppressionId;
@@ -8,9 +9,9 @@ use crate::types::string_annotation::{
89
RAW_STRING_TYPE_ANNOTATION,
910
};
1011
use crate::types::{KnownInstanceType, Type};
11-
use ruff_db::diagnostic::{Annotation, Diagnostic, Span};
12+
use ruff_db::diagnostic::{Annotation, Diagnostic, Severity, Span, SubDiagnostic};
1213
use ruff_python_ast::{self as ast, AnyNodeRef};
13-
use ruff_text_size::Ranged;
14+
use ruff_text_size::{Ranged, TextRange};
1415
use rustc_hash::FxHashSet;
1516
use std::fmt::Formatter;
1617

@@ -1313,6 +1314,46 @@ pub(crate) fn report_invalid_arguments_to_annotated(
13131314
));
13141315
}
13151316

1317+
pub(crate) fn report_bad_argument_to_get_protocol_members(
1318+
context: &InferContext,
1319+
call: &ast::ExprCall,
1320+
class: ClassLiteralType,
1321+
) {
1322+
let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, call) else {
1323+
return;
1324+
};
1325+
let db = context.db();
1326+
let mut diagnostic = builder.into_diagnostic("Invalid argument to `get_protocol_members`");
1327+
diagnostic.set_primary_message("This call will raise `TypeError` at runtime");
1328+
diagnostic.info("Only protocol classes can be passed to `get_protocol_members`");
1329+
1330+
let class_scope = class.body_scope(db);
1331+
let class_node = class_scope.node(db).expect_class();
1332+
let class_name = &class_node.name;
1333+
let class_def_diagnostic_range = TextRange::new(
1334+
class_name.start(),
1335+
class_node
1336+
.arguments
1337+
.as_deref()
1338+
.map(Ranged::end)
1339+
.unwrap_or_else(|| class_name.end()),
1340+
);
1341+
let mut class_def_diagnostic = SubDiagnostic::new(
1342+
Severity::Info,
1343+
format_args!("`{class_name}` is declared here, but it is not a protocol class:"),
1344+
);
1345+
class_def_diagnostic.annotate(Annotation::primary(
1346+
Span::from(class_scope.file(db)).with_range(class_def_diagnostic_range),
1347+
));
1348+
diagnostic.sub(class_def_diagnostic);
1349+
1350+
diagnostic.info(
1351+
"A class is only a protocol class if it explicitly inherits \
1352+
from `typing.Protocol` or `typing_extensions.Protocol`",
1353+
);
1354+
diagnostic.info("See https://typing.python.org/en/latest/spec/protocol.html#");
1355+
}
1356+
13161357
pub(crate) fn report_invalid_arguments_to_callable(
13171358
context: &InferContext,
13181359
subscript: &ast::ExprSubscript,

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ use crate::Db;
9696

9797
use super::context::{InNoTypeCheck, InferContext};
9898
use super::diagnostic::{
99-
report_index_out_of_bounds, report_invalid_exception_caught, report_invalid_exception_cause,
99+
report_bad_argument_to_get_protocol_members, report_index_out_of_bounds,
100+
report_invalid_exception_caught, report_invalid_exception_cause,
100101
report_invalid_exception_raised, report_invalid_type_checking_constant,
101102
report_non_subscriptable, report_possibly_unresolved_reference, report_slice_step_size_zero,
102103
report_unresolved_reference, INVALID_METACLASS, INVALID_PROTOCOL, REDUNDANT_CAST,
@@ -4486,6 +4487,19 @@ impl<'db> TypeInferenceBuilder<'db> {
44864487
}
44874488
}
44884489
}
4490+
KnownFunction::GetProtocolMembers => {
4491+
if let [Some(Type::ClassLiteral(class))] =
4492+
overload.parameter_types()
4493+
{
4494+
if !class.is_protocol(self.db()) {
4495+
report_bad_argument_to_get_protocol_members(
4496+
&self.context,
4497+
call_expression,
4498+
*class,
4499+
);
4500+
}
4501+
}
4502+
}
44894503
_ => {}
44904504
}
44914505
}

0 commit comments

Comments
 (0)