Skip to content

Commit 2380ade

Browse files
committed
Add CallableType::Specialized
1 parent 4346e7d commit 2380ade

File tree

4 files changed

+73
-1
lines changed

4 files changed

+73
-1
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ use crate::symbol::{imported_symbol, Boundness, Symbol, SymbolAndQualifiers};
3535
use crate::types::call::{Bindings, CallArgumentTypes};
3636
use crate::types::class_base::ClassBase;
3737
use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION};
38+
use crate::types::generics::Specialization;
3839
use crate::types::infer::infer_unpack_types;
3940
use crate::types::mro::{Mro, MroError, MroIterator};
4041
pub(crate) use crate::types::narrow::infer_narrowing_constraint;
@@ -641,6 +642,10 @@ impl<'db> Type<'db> {
641642
.to_instance(db)
642643
.is_subtype_of(db, target),
643644

645+
(Type::Callable(CallableType::Specialized(specialized)), _) => {
646+
specialized.callable_type(db).is_subtype_of(db, target)
647+
}
648+
644649
// The same reasoning applies for these special callable types:
645650
(Type::Callable(CallableType::BoundMethod(_)), _) => KnownClass::MethodType
646651
.to_instance(db)
@@ -1048,6 +1053,7 @@ impl<'db> Type<'db> {
10481053
| Type::FunctionLiteral(..)
10491054
| Type::Callable(
10501055
CallableType::BoundMethod(..)
1056+
| CallableType::Specialized(..)
10511057
| CallableType::MethodWrapperDunderGet(..)
10521058
| CallableType::WrapperDescriptorDunderGet,
10531059
)
@@ -1062,6 +1068,7 @@ impl<'db> Type<'db> {
10621068
| Type::FunctionLiteral(..)
10631069
| Type::Callable(
10641070
CallableType::BoundMethod(..)
1071+
| CallableType::Specialized(..)
10651072
| CallableType::MethodWrapperDunderGet(..)
10661073
| CallableType::WrapperDescriptorDunderGet,
10671074
)
@@ -1236,6 +1243,11 @@ impl<'db> Type<'db> {
12361243
!KnownClass::FunctionType.is_subclass_of(db, class)
12371244
}
12381245

1246+
(Type::Callable(CallableType::Specialized(specialized)), Type::Instance(_))
1247+
| (Type::Instance(_), Type::Callable(CallableType::Specialized(specialized))) => {
1248+
specialized.callable_type(db).is_disjoint_from(db, other)
1249+
}
1250+
12391251
(
12401252
Type::Callable(CallableType::BoundMethod(_)),
12411253
Type::Instance(InstanceType { class }),
@@ -1359,6 +1371,9 @@ impl<'db> Type<'db> {
13591371
.iter()
13601372
.all(|elem| elem.is_fully_static(db)),
13611373
Type::Callable(CallableType::General(callable)) => callable.is_fully_static(db),
1374+
Type::Callable(CallableType::Specialized(specialized)) => {
1375+
specialized.callable_type(db).is_fully_static(db)
1376+
}
13621377
}
13631378
}
13641379

@@ -1398,6 +1413,9 @@ impl<'db> Type<'db> {
13981413
// signature.
13991414
false
14001415
}
1416+
Type::Callable(CallableType::Specialized(specialized)) => {
1417+
specialized.callable_type(db).is_singleton(db)
1418+
}
14011419
Type::Instance(InstanceType { class }) => {
14021420
class.known(db).is_some_and(KnownClass::is_singleton)
14031421
}
@@ -1447,6 +1465,10 @@ impl<'db> Type<'db> {
14471465
| Type::SliceLiteral(..)
14481466
| Type::KnownInstance(..) => true,
14491467

1468+
Type::Callable(CallableType::Specialized(specialized)) => {
1469+
specialized.callable_type(db).is_single_valued(db)
1470+
}
1471+
14501472
Type::SubclassOf(..) => {
14511473
// TODO: Same comment as above for `is_singleton`
14521474
false
@@ -1563,6 +1585,11 @@ impl<'db> Type<'db> {
15631585
.find_name_in_mro(db, name)
15641586
}
15651587

1588+
Type::Callable(CallableType::Specialized(specialized)) => {
1589+
// XXX: specialize the result
1590+
specialized.callable_type(db).find_name_in_mro(db, name)
1591+
}
1592+
15661593
Type::FunctionLiteral(_)
15671594
| Type::Callable(_)
15681595
| Type::ModuleLiteral(_)
@@ -1638,6 +1665,10 @@ impl<'db> Type<'db> {
16381665
Type::Callable(CallableType::BoundMethod(_)) => KnownClass::MethodType
16391666
.to_instance(db)
16401667
.instance_member(db, name),
1668+
Type::Callable(CallableType::Specialized(specialized)) => {
1669+
// XXX: specialize the result
1670+
specialized.callable_type(db).instance_member(db, name)
1671+
}
16411672
Type::Callable(CallableType::MethodWrapperDunderGet(_)) => {
16421673
KnownClass::MethodWrapperType
16431674
.to_instance(db)
@@ -1992,6 +2023,12 @@ impl<'db> Type<'db> {
19922023
})
19932024
}
19942025
},
2026+
Type::Callable(CallableType::Specialized(specialized)) => {
2027+
// XXX: specialize the result
2028+
specialized
2029+
.callable_type(db)
2030+
.member_lookup_with_policy(db, name, policy)
2031+
}
19952032
Type::Callable(CallableType::MethodWrapperDunderGet(_)) => {
19962033
KnownClass::MethodWrapperType
19972034
.to_instance(db)
@@ -3168,6 +3205,10 @@ impl<'db> Type<'db> {
31683205
Type::Callable(CallableType::BoundMethod(_)) => {
31693206
KnownClass::MethodType.to_class_literal(db)
31703207
}
3208+
Type::Callable(CallableType::Specialized(specialized)) => {
3209+
// XXX: specialize the result
3210+
specialized.callable_type(db).to_meta_type(db)
3211+
}
31713212
Type::Callable(CallableType::MethodWrapperDunderGet(_)) => {
31723213
KnownClass::MethodWrapperType.to_class_literal(db)
31733214
}
@@ -4313,6 +4354,18 @@ pub struct BoundMethodType<'db> {
43134354
self_instance: Type<'db>,
43144355
}
43154356

4357+
/// Represents the specialization of a callable that has access to generic typevars, either because
4358+
/// it is itself a generic function, or because it appears in the body of a generic class.
4359+
#[salsa::tracked(debug)]
4360+
pub struct SpecializedCallable<'db> {
4361+
/// The callable that has been specialized. (Note that this is not [`CallableType`] since there
4362+
/// are other types that are callable.)
4363+
pub(crate) callable_type: Type<'db>,
4364+
4365+
/// The specialization of any generic typevars that are visible to the callable.
4366+
pub(crate) specialization: Specialization<'db>,
4367+
}
4368+
43164369
/// This type represents a general callable type that are used to represent `typing.Callable`
43174370
/// and `lambda` expressions.
43184371
#[salsa::interned(debug)]
@@ -4843,6 +4896,11 @@ pub enum CallableType<'db> {
48434896
/// the bound instance when that type is displayed.
48444897
BoundMethod(BoundMethodType<'db>),
48454898

4899+
/// Represents the specialization of a callable that has access to generic typevars, either
4900+
/// because it is itself a generic function, or because it appears in the body of a generic
4901+
/// class.
4902+
Specialized(SpecializedCallable<'db>),
4903+
48464904
/// Represents the callable `f.__get__` where `f` is a function.
48474905
///
48484906
/// TODO: This could eventually be replaced by a more general `Callable` type that is

crates/red_knot_python_semantic/src/types/display.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@ impl Display for DisplayRepresentation<'_> {
100100
instance = bound_method.self_instance(self.db).display(self.db)
101101
)
102102
}
103+
Type::Callable(CallableType::Specialized(specialized)) => {
104+
write!(
105+
f,
106+
"<specialization of {callable}>",
107+
callable = specialized.callable_type(self.db).display(self.db),
108+
)
109+
}
103110
Type::Callable(CallableType::MethodWrapperDunderGet(function)) => {
104111
write!(
105112
f,

crates/red_knot_python_semantic/src/types/generics.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ impl<'db> GenericContext<'db> {
8686

8787
/// An assignment of a specific type to each type variable in a generic scope.
8888
#[salsa::tracked(debug)]
89-
pub(crate) struct Specialization<'db> {
89+
pub struct Specialization<'db> {
9090
generic_context: GenericContext<'db>,
9191
types: Box<[Type<'db>]>,
9292
}

crates/red_knot_python_semantic/src/types/type_ordering.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
6969
(Type::Callable(CallableType::BoundMethod(_)), _) => Ordering::Less,
7070
(_, Type::Callable(CallableType::BoundMethod(_))) => Ordering::Greater,
7171

72+
(
73+
Type::Callable(CallableType::Specialized(left)),
74+
Type::Callable(CallableType::Specialized(right)),
75+
) => left.cmp(right),
76+
(Type::Callable(CallableType::Specialized(_)), _) => Ordering::Less,
77+
(_, Type::Callable(CallableType::Specialized(_))) => Ordering::Greater,
78+
7279
(
7380
Type::Callable(CallableType::MethodWrapperDunderGet(left)),
7481
Type::Callable(CallableType::MethodWrapperDunderGet(right)),

0 commit comments

Comments
 (0)