Skip to content

Commit 0837d74

Browse files
committed
Infer specialization of protocol parameter
1 parent 022b85a commit 0837d74

File tree

3 files changed

+98
-19
lines changed

3 files changed

+98
-19
lines changed

crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,21 +66,42 @@ reveal_type(f("string")) # revealed: Literal["string"]
6666
## Inferring “deep” generic parameter types
6767

6868
The matching up of call arguments and discovery of constraints on typevars can be a recursive
69-
process for arbitrarily-nested generic types in parameters.
69+
process for arbitrarily-nested generic classes and protocols in parameters.
70+
71+
TODO: Note that we can currently only infer a specialization for a generic protocol when the
72+
argument _explicitly_ implements the protocol by listing it as a base class.
7073

7174
```py
72-
from typing import TypeVar
75+
from typing import Protocol, TypeVar
7376

7477
T = TypeVar("T")
7578

76-
def f(x: list[T]) -> T:
79+
class CanIndex(Protocol[T]):
80+
def __getitem__(self, index: int) -> T: ...
81+
82+
class ExplicitlyImplements(CanIndex[T]): ...
83+
84+
def takes_in_list(x: list[T]) -> T:
85+
return x[0]
86+
87+
def takes_in_protocol(x: CanIndex[T]) -> T:
7788
return x[0]
7889

79-
def deep(x: list[str]) -> None:
80-
reveal_type(f(x)) # revealed: str
90+
def deep_list(x: list[str]) -> None:
91+
reveal_type(takes_in_list(x)) # revealed: str
92+
# TODO: revealed: str
93+
reveal_type(takes_in_protocol(x)) # revealed: Unknown
8194

82-
def deeper(x: list[set[str]]) -> None:
83-
reveal_type(f(x)) # revealed: set[str]
95+
def deeper_list(x: list[set[str]]) -> None:
96+
reveal_type(takes_in_list(x)) # revealed: set[str]
97+
# TODO: revealed: set[str]
98+
reveal_type(takes_in_protocol(x)) # revealed: Unknown
99+
100+
def deep_explicit(x: ExplicitlyImplements[str]) -> None:
101+
reveal_type(takes_in_protocol(x)) # revealed: str
102+
103+
def deeper_explicit(x: ExplicitlyImplements[set[str]]) -> None:
104+
reveal_type(takes_in_protocol(x)) # revealed: set[str]
84105
```
85106

86107
This also works when passing in arguments that are subclasses of the parameter type.
@@ -89,8 +110,19 @@ This also works when passing in arguments that are subclasses of the parameter t
89110
class Sub(list[int]): ...
90111
class GenericSub(list[T]): ...
91112

92-
reveal_type(f(Sub())) # revealed: int
93-
reveal_type(f(GenericSub[str]())) # revealed: str
113+
reveal_type(takes_in_list(Sub())) # revealed: int
114+
# TODO: revealed: int
115+
reveal_type(takes_in_protocol(Sub())) # revealed: Unknown
116+
117+
reveal_type(takes_in_list(GenericSub[str]())) # revealed: str
118+
# TODO: revealed: str
119+
reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown
120+
121+
class ExplicitSub(ExplicitlyImplements[int]): ...
122+
class ExplicitGenericSub(ExplicitlyImplements[T]): ...
123+
124+
reveal_type(takes_in_protocol(ExplicitSub())) # revealed: int
125+
reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: str
94126
```
95127

96128
## Inferring a bound typevar

crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,42 @@ reveal_type(f("string")) # revealed: Literal["string"]
6161
## Inferring “deep” generic parameter types
6262

6363
The matching up of call arguments and discovery of constraints on typevars can be a recursive
64-
process for arbitrarily-nested generic types in parameters.
64+
process for arbitrarily-nested generic classes and protocols in parameters.
65+
66+
TODO: Note that we can currently only infer a specialization for a generic protocol when the
67+
argument _explicitly_ implements the protocol by listing it as a base class.
6568

6669
```py
67-
def f[T](x: list[T]) -> T:
70+
from typing import Protocol, TypeVar
71+
72+
S = TypeVar("S")
73+
74+
class CanIndex(Protocol[S]):
75+
def __getitem__(self, index: int) -> S: ...
76+
77+
class ExplicitlyImplements[T](CanIndex[T]): ...
78+
79+
def takes_in_list[T](x: list[T]) -> T:
80+
return x[0]
81+
82+
def takes_in_protocol[T](x: CanIndex[T]) -> T:
6883
return x[0]
6984

70-
def deep(x: list[str]) -> None:
71-
reveal_type(f(x)) # revealed: str
85+
def deep_list(x: list[str]) -> None:
86+
reveal_type(takes_in_list(x)) # revealed: str
87+
# TODO: revealed: str
88+
reveal_type(takes_in_protocol(x)) # revealed: Unknown
89+
90+
def deeper_list(x: list[set[str]]) -> None:
91+
reveal_type(takes_in_list(x)) # revealed: set[str]
92+
# TODO: revealed: set[str]
93+
reveal_type(takes_in_protocol(x)) # revealed: Unknown
94+
95+
def deep_explicit(x: ExplicitlyImplements[str]) -> None:
96+
reveal_type(takes_in_protocol(x)) # revealed: str
7297

73-
def deeper(x: list[set[str]]) -> None:
74-
reveal_type(f(x)) # revealed: set[str]
98+
def deeper_explicit(x: ExplicitlyImplements[set[str]]) -> None:
99+
reveal_type(takes_in_protocol(x)) # revealed: set[str]
75100
```
76101

77102
This also works when passing in arguments that are subclasses of the parameter type.
@@ -80,8 +105,19 @@ This also works when passing in arguments that are subclasses of the parameter t
80105
class Sub(list[int]): ...
81106
class GenericSub[T](list[T]): ...
82107

83-
reveal_type(f(Sub())) # revealed: int
84-
reveal_type(f(GenericSub[str]())) # revealed: str
108+
reveal_type(takes_in_list(Sub())) # revealed: int
109+
# TODO: revealed: int
110+
reveal_type(takes_in_protocol(Sub())) # revealed: Unknown
111+
112+
reveal_type(takes_in_list(GenericSub[str]())) # revealed: str
113+
# TODO: revealed: str
114+
reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown
115+
116+
class ExplicitSub(ExplicitlyImplements[int]): ...
117+
class ExplicitGenericSub[T](ExplicitlyImplements[T]): ...
118+
119+
reveal_type(takes_in_protocol(ExplicitSub())) # revealed: int
120+
reveal_type(takes_in_protocol(ExplicitGenericSub[str]())) # revealed: str
85121
```
86122

87123
## Inferring a bound typevar

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_hash::FxHashMap;
44
use crate::semantic_index::SemanticIndex;
55
use crate::types::class::ClassType;
66
use crate::types::class_base::ClassBase;
7-
use crate::types::instance::NominalInstanceType;
7+
use crate::types::instance::{NominalInstanceType, Protocol, ProtocolInstanceType};
88
use crate::types::signatures::{Parameter, Parameters, Signature};
99
use crate::types::{
1010
declaration_type, todo_type, KnownInstanceType, Type, TypeVarBoundOrConstraints,
@@ -630,7 +630,10 @@ impl<'db> SpecializationBuilder<'db> {
630630
// ```
631631
//
632632
// without specializing `T` to `None`.
633-
if !actual.is_never() && actual.is_subtype_of(self.db, formal) {
633+
if !matches!(formal, Type::ProtocolInstance(_))
634+
&& !actual.is_never()
635+
&& actual.is_subtype_of(self.db, formal)
636+
{
634637
return Ok(());
635638
}
636639

@@ -678,6 +681,14 @@ impl<'db> SpecializationBuilder<'db> {
678681
Type::NominalInstance(NominalInstanceType {
679682
class: ClassType::Generic(formal_alias),
680683
..
684+
})
685+
// TODO: This will only handle classes that explicit implement a generic protocol
686+
// by listing it as a base class. To handle classes that implicitly implement a
687+
// generic protocol, we will need to check the types of the protocol members to be
688+
// able to infer the specialization of the protocol that the class implements.
689+
| Type::ProtocolInstance(ProtocolInstanceType {
690+
inner: Protocol::FromClass(ClassType::Generic(formal_alias)),
691+
..
681692
}),
682693
Type::NominalInstance(NominalInstanceType {
683694
class: actual_class,

0 commit comments

Comments
 (0)