Skip to content

Commit dffee4f

Browse files
committed
[red-knot] Check gradual equivalence between callable types
1 parent 6a8289a commit dffee4f

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_gradual_equivalent_to.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,38 @@ static_assert(not is_gradual_equivalent_to(tuple[str, int], tuple[str, int, byte
6262
static_assert(not is_gradual_equivalent_to(tuple[str, int], tuple[int, str]))
6363
```
6464

65+
## Callable
66+
67+
```py
68+
from knot_extensions import Unknown, CallableTypeFromFunction, is_gradual_equivalent_to, static_assert
69+
from typing import Any, Callable
70+
71+
static_assert(is_gradual_equivalent_to(Callable[..., int], Callable[..., int]))
72+
static_assert(is_gradual_equivalent_to(Callable[..., Any], Callable[..., Unknown]))
73+
static_assert(is_gradual_equivalent_to(Callable[[int, Any], None], Callable[[int, Unknown], None]))
74+
75+
static_assert(not is_gradual_equivalent_to(Callable[[int, Any], None], Callable[[Any, int], None]))
76+
static_assert(not is_gradual_equivalent_to(Callable[[int, str], None], Callable[[int, str, bytes], None]))
77+
static_assert(not is_gradual_equivalent_to(Callable[..., None], Callable[[], None]))
78+
```
79+
80+
A function with no explicit return type should be gradual equivalent to a callable with a return
81+
type of `Any`.
82+
83+
```py
84+
def f1():
85+
return
86+
87+
static_assert(is_gradual_equivalent_to(CallableTypeFromFunction[f1], Callable[[], Any]))
88+
```
89+
90+
And, similarly for parameters with no annotations.
91+
92+
```py
93+
def f2(a, b) -> None:
94+
return
95+
96+
static_assert(is_gradual_equivalent_to(CallableTypeFromFunction[f2], Callable[[Any, Any], None]))
97+
```
98+
6599
[materializations]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-materialize

crates/red_knot_python_semantic/src/types.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,11 @@ impl<'db> Type<'db> {
976976
first.is_gradual_equivalent_to(db, second)
977977
}
978978

979+
(
980+
Type::Callable(CallableType::General(first)),
981+
Type::Callable(CallableType::General(second)),
982+
) => first.is_gradual_equivalent_to(db, second),
983+
979984
_ => false,
980985
}
981986
}
@@ -4548,6 +4553,49 @@ impl<'db> GeneralCallableType<'db> {
45484553
.return_ty
45494554
.is_some_and(|return_type| return_type.is_fully_static(db))
45504555
}
4556+
4557+
/// Return `true` if `self` has exactly the same set of possible static materializations as
4558+
/// `other` (if `self` represents the same set of possible sets of possible runtime objects as
4559+
/// `other`).
4560+
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Self) -> bool {
4561+
let self_signature = self.signature(db);
4562+
let other_signature = other.signature(db);
4563+
4564+
if self_signature.parameters().is_gradual() != other_signature.parameters().is_gradual() {
4565+
return false;
4566+
}
4567+
4568+
if self_signature.parameters().len() != other_signature.parameters().len() {
4569+
return false;
4570+
}
4571+
4572+
// Check gradual equivalence between the two optional types. In the context of a callable
4573+
// type, the `None` type represents an `Unknown` type.
4574+
let are_optional_types_gradually_equivalent =
4575+
|self_type: Option<Type<'db>>, other_type: Option<Type<'db>>| {
4576+
self_type
4577+
.unwrap_or(Type::unknown())
4578+
.is_gradual_equivalent_to(db, other_type.unwrap_or(Type::unknown()))
4579+
};
4580+
4581+
if !are_optional_types_gradually_equivalent(
4582+
self_signature.return_ty,
4583+
other_signature.return_ty,
4584+
) {
4585+
return false;
4586+
}
4587+
4588+
self_signature
4589+
.parameters()
4590+
.iter()
4591+
.zip(other_signature.parameters().iter())
4592+
.all(|(self_param, other_param)| {
4593+
are_optional_types_gradually_equivalent(
4594+
self_param.annotated_type(),
4595+
other_param.annotated_type(),
4596+
)
4597+
})
4598+
}
45514599
}
45524600

45534601
/// A type that represents callable objects.

0 commit comments

Comments
 (0)