Skip to content

Commit 2a897c5

Browse files
dcreagerAlexWaygood
authored andcommitted
[ty] Infer parameter specializations of explicitly implemented generic protocols (astral-sh#18054)
Follows on from (and depends on) astral-sh#18021. This updates our function specialization inference to infer type mappings from parameters that are generic protocols. For now, this only works when the argument _explicitly_ implements the protocol by listing it as a base class. (We end up using exactly the same logic as for generic classes in astral-sh#18021.) For this to work with classes that _implicitly_ implement the protocol, we will have to check the types of the protocol members (which we are not currently doing), so that we can infer the specialization of the protocol that the class implements. --------- Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
1 parent c6d0cdf commit 2a897c5

File tree

6 files changed

+80
-58
lines changed

6 files changed

+80
-58
lines changed

crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,10 @@ def deeper_list(x: list[set[str]]) -> None:
9898
reveal_type(takes_in_protocol(x)) # revealed: Unknown
9999

100100
def deep_explicit(x: ExplicitlyImplements[str]) -> None:
101-
# TODO: revealed: str
102-
reveal_type(takes_in_protocol(x)) # revealed: Unknown
101+
reveal_type(takes_in_protocol(x)) # revealed: str
103102

104103
def deeper_explicit(x: ExplicitlyImplements[set[str]]) -> None:
105-
# TODO: revealed: set[str]
106-
reveal_type(takes_in_protocol(x)) # revealed: Unknown
104+
reveal_type(takes_in_protocol(x)) # revealed: set[str]
107105

108106
def takes_in_type(x: type[T]) -> type[T]:
109107
return x
@@ -128,10 +126,8 @@ reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown
128126
class ExplicitSub(ExplicitlyImplements[int]): ...
129127
class ExplicitGenericSub(ExplicitlyImplements[T]): ...
130128

131-
# TODO: revealed: int
132-
reveal_type(takes_in_protocol(ExplicitSub())) # revealed: Unknown
133-
# TODO: revealed: str
134-
reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: Unknown
129+
reveal_type(takes_in_protocol(ExplicitSub())) # revealed: int
130+
reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: str
135131
```
136132

137133
## Inferring a bound typevar

crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,10 @@ def deeper_list(x: list[set[str]]) -> None:
9393
reveal_type(takes_in_protocol(x)) # revealed: Unknown
9494

9595
def deep_explicit(x: ExplicitlyImplements[str]) -> None:
96-
# TODO: revealed: str
97-
reveal_type(takes_in_protocol(x)) # revealed: Unknown
96+
reveal_type(takes_in_protocol(x)) # revealed: str
9897

9998
def deeper_explicit(x: ExplicitlyImplements[set[str]]) -> None:
100-
# TODO: revealed: set[str]
101-
reveal_type(takes_in_protocol(x)) # revealed: Unknown
99+
reveal_type(takes_in_protocol(x)) # revealed: set[str]
102100

103101
def takes_in_type[T](x: type[T]) -> type[T]:
104102
return x
@@ -123,10 +121,8 @@ reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown
123121
class ExplicitSub(ExplicitlyImplements[int]): ...
124122
class ExplicitGenericSub[T](ExplicitlyImplements[T]): ...
125123

126-
# TODO: revealed: int
127-
reveal_type(takes_in_protocol(ExplicitSub())) # revealed: Unknown
128-
# TODO: revealed: str
129-
reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: Unknown
124+
reveal_type(takes_in_protocol(ExplicitSub())) # revealed: int
125+
reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: str
130126
```
131127

132128
## Inferring a bound typevar

crates/ty_python_semantic/src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5329,7 +5329,7 @@ impl<'db> Type<'db> {
53295329

53305330
Self::TypeVar(var) => Some(TypeDefinition::TypeVar(var.definition(db))),
53315331

5332-
Self::ProtocolInstance(protocol) => match protocol.inner() {
5332+
Self::ProtocolInstance(protocol) => match protocol.inner {
53335333
Protocol::FromClass(class) => Some(TypeDefinition::Class(class.definition(db))),
53345334
Protocol::Synthesized(_) => None,
53355335
},

crates/ty_python_semantic/src/types/display.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ impl Display for DisplayRepresentation<'_> {
7676
(ClassType::Generic(alias), _) => alias.display(self.db).fmt(f),
7777
}
7878
}
79-
Type::ProtocolInstance(protocol) => match protocol.inner() {
79+
Type::ProtocolInstance(protocol) => match protocol.inner {
8080
Protocol::FromClass(ClassType::NonGeneric(class)) => {
8181
f.write_str(class.name(self.db))
8282
}

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_hash::FxHashMap;
44
use crate::semantic_index::SemanticIndex;
55
use crate::types::class::ClassType;
66
use crate::types::class_base::ClassBase;
7-
use crate::types::instance::NominalInstanceType;
7+
use crate::types::instance::{NominalInstanceType, Protocol, ProtocolInstanceType};
88
use crate::types::signatures::{Parameter, Parameters, Signature};
99
use crate::types::{
1010
declaration_type, todo_type, KnownInstanceType, Type, TypeVarBoundOrConstraints,
@@ -630,7 +630,10 @@ impl<'db> SpecializationBuilder<'db> {
630630
// ```
631631
//
632632
// without specializing `T` to `None`.
633-
if !actual.is_never() && actual.is_subtype_of(self.db, formal) {
633+
if !matches!(formal, Type::ProtocolInstance(_))
634+
&& !actual.is_never()
635+
&& actual.is_subtype_of(self.db, formal)
636+
{
634637
return Ok(());
635638
}
636639

@@ -678,6 +681,14 @@ impl<'db> SpecializationBuilder<'db> {
678681
Type::NominalInstance(NominalInstanceType {
679682
class: ClassType::Generic(formal_alias),
680683
..
684+
})
685+
// TODO: This will only handle classes that explicit implement a generic protocol
686+
// by listing it as a base class. To handle classes that implicitly implement a
687+
// generic protocol, we will need to check the types of the protocol members to be
688+
// able to infer the specialization of the protocol that the class implements.
689+
| Type::ProtocolInstance(ProtocolInstanceType {
690+
inner: Protocol::FromClass(ClassType::Generic(formal_alias)),
691+
..
681692
}),
682693
Type::NominalInstance(NominalInstanceType {
683694
class: actual_class,

crates/ty_python_semantic/src/types/instance.rs

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,9 @@ pub(super) use synthesized_protocol::SynthesizedProtocolType;
1414
impl<'db> Type<'db> {
1515
pub(crate) fn instance(db: &'db dyn Db, class: ClassType<'db>) -> Self {
1616
if class.class_literal(db).0.is_protocol(db) {
17-
Self::ProtocolInstance(ProtocolInstanceType(Protocol::FromClass(class)))
17+
Self::ProtocolInstance(ProtocolInstanceType::from_class(class))
1818
} else {
19-
Self::NominalInstance(NominalInstanceType {
20-
class,
21-
_phantom: PhantomData,
22-
})
19+
Self::NominalInstance(NominalInstanceType::from_class(class))
2320
}
2421
}
2522

@@ -34,9 +31,9 @@ impl<'db> Type<'db> {
3431
where
3532
M: IntoIterator<Item = (&'a str, Type<'db>)>,
3633
{
37-
Self::ProtocolInstance(ProtocolInstanceType(Protocol::Synthesized(
34+
Self::ProtocolInstance(ProtocolInstanceType::synthesized(
3835
SynthesizedProtocolType::new(db, ProtocolInterface::with_members(db, members)),
39-
)))
36+
))
4037
}
4138

4239
/// Return `true` if `self` conforms to the interface described by `protocol`.
@@ -51,7 +48,7 @@ impl<'db> Type<'db> {
5148
// TODO: this should consider the types of the protocol members
5249
// as well as whether each member *exists* on `self`.
5350
protocol
54-
.0
51+
.inner
5552
.interface(db)
5653
.members(db)
5754
.all(|member| !self.member(db, member.name()).symbol.is_unbound())
@@ -69,6 +66,15 @@ pub struct NominalInstanceType<'db> {
6966
}
7067

7168
impl<'db> NominalInstanceType<'db> {
69+
// Keep this method private, so that the only way of constructing `NominalInstanceType`
70+
// instances is through the `Type::instance` constructor function.
71+
fn from_class(class: ClassType<'db>) -> Self {
72+
Self {
73+
class,
74+
_phantom: PhantomData,
75+
}
76+
}
77+
7278
pub(super) fn is_subtype_of(self, db: &'db dyn Db, other: Self) -> bool {
7379
// N.B. The subclass relation is fully static
7480
self.class.is_subclass_of(db, other.class)
@@ -131,10 +137,7 @@ impl<'db> NominalInstanceType<'db> {
131137
db: &'db dyn Db,
132138
type_mapping: TypeMapping<'a, 'db>,
133139
) -> Self {
134-
Self {
135-
class: self.class.apply_type_mapping(db, type_mapping),
136-
_phantom: PhantomData,
137-
}
140+
Self::from_class(self.class.apply_type_mapping(db, type_mapping))
138141
}
139142

140143
pub(super) fn find_legacy_typevars(
@@ -155,21 +158,37 @@ impl<'db> From<NominalInstanceType<'db>> for Type<'db> {
155158
/// A `ProtocolInstanceType` represents the set of all possible runtime objects
156159
/// that conform to the interface described by a certain protocol.
157160
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, salsa::Update)]
158-
pub struct ProtocolInstanceType<'db>(
161+
pub struct ProtocolInstanceType<'db> {
162+
pub(super) inner: Protocol<'db>,
163+
159164
// Keep the inner field here private,
160165
// so that the only way of constructing `ProtocolInstanceType` instances
161166
// is through the `Type::instance` constructor function.
162-
Protocol<'db>,
163-
);
167+
_phantom: PhantomData<()>,
168+
}
164169

165170
impl<'db> ProtocolInstanceType<'db> {
166-
pub(super) fn inner(self) -> Protocol<'db> {
167-
self.0
171+
// Keep this method private, so that the only way of constructing `ProtocolInstanceType`
172+
// instances is through the `Type::instance` constructor function.
173+
fn from_class(class: ClassType<'db>) -> Self {
174+
Self {
175+
inner: Protocol::FromClass(class),
176+
_phantom: PhantomData,
177+
}
178+
}
179+
180+
// Keep this method private, so that the only way of constructing `ProtocolInstanceType`
181+
// instances is through the `Type::instance` constructor function.
182+
fn synthesized(synthesized: SynthesizedProtocolType<'db>) -> Self {
183+
Self {
184+
inner: Protocol::Synthesized(synthesized),
185+
_phantom: PhantomData,
186+
}
168187
}
169188

170189
/// Return the meta-type of this protocol-instance type.
171190
pub(super) fn to_meta_type(self, db: &'db dyn Db) -> Type<'db> {
172-
match self.0 {
191+
match self.inner {
173192
Protocol::FromClass(class) => SubclassOfType::from(db, class),
174193

175194
// TODO: we can and should do better here.
@@ -197,35 +216,35 @@ impl<'db> ProtocolInstanceType<'db> {
197216
if object.satisfies_protocol(db, self) {
198217
return object;
199218
}
200-
match self.0 {
201-
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
202-
SynthesizedProtocolType::new(db, self.0.interface(db)),
203-
))),
219+
match self.inner {
220+
Protocol::FromClass(_) => Type::ProtocolInstance(Self::synthesized(
221+
SynthesizedProtocolType::new(db, self.inner.interface(db)),
222+
)),
204223
Protocol::Synthesized(_) => Type::ProtocolInstance(self),
205224
}
206225
}
207226

208227
/// Replace references to `class` with a self-reference marker
209228
pub(super) fn replace_self_reference(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> Self {
210-
match self.0 {
229+
match self.inner {
211230
Protocol::FromClass(class_type) if class_type.class_literal(db).0 == class => {
212-
ProtocolInstanceType(Protocol::Synthesized(SynthesizedProtocolType::new(
231+
ProtocolInstanceType::synthesized(SynthesizedProtocolType::new(
213232
db,
214233
ProtocolInterface::SelfReference,
215-
)))
234+
))
216235
}
217236
_ => self,
218237
}
219238
}
220239

221240
/// Return `true` if any of the members of this protocol type contain any `Todo` types.
222241
pub(super) fn contains_todo(self, db: &'db dyn Db) -> bool {
223-
self.0.interface(db).contains_todo(db)
242+
self.inner.interface(db).contains_todo(db)
224243
}
225244

226245
/// Return `true` if this protocol type is fully static.
227246
pub(super) fn is_fully_static(self, db: &'db dyn Db) -> bool {
228-
self.0.interface(db).is_fully_static(db)
247+
self.inner.interface(db).is_fully_static(db)
229248
}
230249

231250
/// Return `true` if this protocol type is a subtype of the protocol `other`.
@@ -238,9 +257,9 @@ impl<'db> ProtocolInstanceType<'db> {
238257
/// TODO: consider the types of the members as well as their existence
239258
pub(super) fn is_assignable_to(self, db: &'db dyn Db, other: Self) -> bool {
240259
other
241-
.0
260+
.inner
242261
.interface(db)
243-
.is_sub_interface_of(db, self.0.interface(db))
262+
.is_sub_interface_of(db, self.inner.interface(db))
244263
}
245264

246265
/// Return `true` if this protocol type is equivalent to the protocol `other`.
@@ -269,7 +288,7 @@ impl<'db> ProtocolInstanceType<'db> {
269288
}
270289

271290
pub(crate) fn instance_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> {
272-
match self.inner() {
291+
match self.inner {
273292
Protocol::FromClass(class) => class.instance_member(db, name),
274293
Protocol::Synthesized(synthesized) => synthesized
275294
.interface()
@@ -287,13 +306,13 @@ impl<'db> ProtocolInstanceType<'db> {
287306
db: &'db dyn Db,
288307
type_mapping: TypeMapping<'a, 'db>,
289308
) -> Self {
290-
match self.0 {
291-
Protocol::FromClass(class) => Self(Protocol::FromClass(
292-
class.apply_type_mapping(db, type_mapping),
293-
)),
294-
Protocol::Synthesized(synthesized) => Self(Protocol::Synthesized(
295-
synthesized.apply_type_mapping(db, type_mapping),
296-
)),
309+
match self.inner {
310+
Protocol::FromClass(class) => {
311+
Self::from_class(class.apply_type_mapping(db, type_mapping))
312+
}
313+
Protocol::Synthesized(synthesized) => {
314+
Self::synthesized(synthesized.apply_type_mapping(db, type_mapping))
315+
}
297316
}
298317
}
299318

@@ -302,7 +321,7 @@ impl<'db> ProtocolInstanceType<'db> {
302321
db: &'db dyn Db,
303322
typevars: &mut FxOrderSet<TypeVarInstance<'db>>,
304323
) {
305-
match self.0 {
324+
match self.inner {
306325
Protocol::FromClass(class) => {
307326
class.find_legacy_typevars(db, typevars);
308327
}

0 commit comments

Comments
 (0)