Skip to content

Commit a5c0c88

Browse files
committed
promote literals in invariant return position
1 parent 039a69f commit a5c0c88

File tree

11 files changed

+132
-75
lines changed

11 files changed

+132
-75
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@ f: dict[list[Literal[1]], list[Literal[Color.RED]]] = {[1]: [Color.RED, Color.RE
262262
reveal_type(f) # revealed: dict[list[Literal[1]], list[Color]]
263263

264264
class X[T]:
265+
value: T
266+
265267
def __init__(self, value: T): ...
266268

267269
g: X[Literal[1]] = X(1)
@@ -485,7 +487,7 @@ def f[T](x: T) -> list[T]:
485487
return [x]
486488

487489
a = f("a")
488-
reveal_type(a) # revealed: list[Literal["a"]]
490+
reveal_type(a) # revealed: list[str]
489491

490492
b: list[int | Literal["a"]] = f("a")
491493
reveal_type(b) # revealed: list[int | Literal["a"]]
@@ -499,10 +501,10 @@ reveal_type(d) # revealed: list[int | tuple[int, int]]
499501
e: list[int] = f(True)
500502
reveal_type(e) # revealed: list[int]
501503

502-
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `list[int]`"
504+
# error: [invalid-assignment] "Object of type `list[str]` is not assignable to `list[int]`"
503505
g: list[int] = f("a")
504506

505-
# error: [invalid-assignment] "Object of type `list[Literal["a"]]` is not assignable to `tuple[int]`"
507+
# error: [invalid-assignment] "Object of type `list[str]` is not assignable to `tuple[int]`"
506508
h: tuple[int] = f("a")
507509

508510
def f2[T: int](x: T) -> T:
@@ -607,7 +609,7 @@ def f3[T](x: T) -> list[T] | dict[T, T]:
607609
return [x]
608610

609611
a = f(1)
610-
reveal_type(a) # revealed: list[Literal[1]]
612+
reveal_type(a) # revealed: list[int]
611613

612614
b: list[Any] = f(1)
613615
reveal_type(b) # revealed: list[Any]

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,19 @@ python-version = "3.12"
1616
```
1717

1818
```py
19+
from typing import Literal
20+
1921
def list1[T](x: T) -> list[T]:
2022
return [x]
2123

22-
l1 = list1(1)
24+
l1: list[Literal[1]] = list1(1)
2325
reveal_type(l1) # revealed: list[Literal[1]]
24-
l2: list[int] = list1(1)
25-
reveal_type(l2) # revealed: list[int]
2626

27-
# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`.
28-
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
29-
l2 = l1
30-
31-
intermediate = list1(1)
32-
# TODO: the error will not occur if we can infer the type of `intermediate` to be `list[int]`
33-
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
34-
l3: list[int] = intermediate
35-
# TODO: it would be nice if this were `list[int]`
36-
reveal_type(intermediate) # revealed: list[Literal[1]]
37-
reveal_type(l3) # revealed: list[int]
27+
l2 = list1(1)
28+
reveal_type(l2) # revealed: list[int]
3829

39-
l4: list[int | str] | None = list1(1)
40-
reveal_type(l4) # revealed: list[int | str]
30+
l3: list[int | str] | None = list1(1)
31+
reveal_type(l3) # revealed: list[int | str]
4132

4233
def _(l: list[int] | None = None):
4334
l1 = l or list()
@@ -233,6 +224,9 @@ def _(flag: bool):
233224

234225
def _(c: C):
235226
c.x = lst(1)
227+
228+
# TODO: Use the parameter type of `__set__` as type context to avoid this error.
229+
# error: [invalid-assignment]
236230
C.x = lst(1)
237231
```
238232

crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -866,10 +866,10 @@ reveal_type(ParentDataclass.__init__)
866866
reveal_type(ChildOfParentDataclass.__init__)
867867

868868
result_int = uses_dataclass(42)
869-
reveal_type(result_int) # revealed: ChildOfParentDataclass[Literal[42]]
869+
reveal_type(result_int) # revealed: ChildOfParentDataclass[int]
870870

871871
result_str = uses_dataclass("hello")
872-
reveal_type(result_str) # revealed: ChildOfParentDataclass[Literal["hello"]]
872+
reveal_type(result_str) # revealed: ChildOfParentDataclass[str]
873873
```
874874

875875
## Descriptor-typed fields

crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,27 +315,31 @@ If either method comes from a generic base class, we don't currently use its inf
315315
to specialize the class.
316316

317317
```py
318+
from typing import Literal
319+
318320
class C[T, U]:
319321
def __new__(cls, *args, **kwargs) -> "C[T, U]":
320322
return object.__new__(cls)
321323

322324
class D[V](C[V, int]):
323325
def __init__(self, x: V) -> None: ...
324326

325-
reveal_type(D(1)) # revealed: D[int]
327+
reveal_type(D(1)) # revealed: D[Literal[1]]
326328
```
327329

328330
### Generic class inherits `__init__` from generic base class
329331

330332
```py
333+
from typing import Literal
334+
331335
class C[T, U]:
332336
def __init__(self, t: T, u: U) -> None: ...
333337

334338
class D[T, U](C[T, U]):
335339
pass
336340

337-
reveal_type(C(1, "str")) # revealed: C[int, str]
338-
reveal_type(D(1, "str")) # revealed: D[int, str]
341+
reveal_type(C(1, "str")) # revealed: C[Literal[1], Literal["str"]]
342+
reveal_type(D(1, "str")) # revealed: D[Literal[1], Literal["str"]]
339343
```
340344

341345
### Generic class inherits `__init__` from `dict`
@@ -358,7 +362,7 @@ context. But from the user's point of view, this is another example of the above
358362
```py
359363
class C[T, U](tuple[T, U]): ...
360364

361-
reveal_type(C((1, 2))) # revealed: C[int, int]
365+
reveal_type(C((1, 2))) # revealed: C[Literal[1], Literal[2]]
362366
```
363367

364368
### Upcasting a `tuple` to its `Sequence` supertype

crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -514,10 +514,8 @@ For covariant types, such as `frozenset`, the ideal behaviour would be to not pr
514514
types to their instance supertypes: doing so causes more false positives than it fixes:
515515

516516
```py
517-
# TODO: better here would be `frozenset[Literal[1, 2, 3]]`
518-
reveal_type(frozenset((1, 2, 3))) # revealed: frozenset[int]
519-
# TODO: better here would be `frozenset[tuple[Literal[1], Literal[2], Literal[3]]]`
520-
reveal_type(frozenset(((1, 2, 3),))) # revealed: frozenset[tuple[int, int, int]]
517+
reveal_type(frozenset((1, 2, 3))) # revealed: frozenset[Literal[1, 2, 3]]
518+
reveal_type(frozenset(((1, 2, 3),))) # revealed: frozenset[tuple[Literal[1], Literal[2], Literal[3]]]
521519
```
522520

523521
Literals are always promoted for invariant containers such as `list`, however, even though this can

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,21 +126,21 @@ Also, the value types ​​declared in a `TypedDict` affect generic call infere
126126

127127
```py
128128
class Plot(TypedDict):
129-
y: list[int]
130-
x: list[int] | None
129+
y: list[int | None]
130+
x: list[int | None] | None
131131

132132
plot1: Plot = {"y": [1, 2, 3], "x": None}
133133

134134
def homogeneous_list[T](*args: T) -> list[T]:
135135
return list(args)
136136

137-
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]]
137+
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[int]
138138
plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None}
139-
reveal_type(plot2["y"]) # revealed: list[int]
139+
reveal_type(plot2["y"]) # revealed: list[int | None]
140140

141141
plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)}
142-
reveal_type(plot3["y"]) # revealed: list[int]
143-
reveal_type(plot3["x"]) # revealed: list[int] | None
142+
reveal_type(plot3["y"]) # revealed: list[int | None]
143+
reveal_type(plot3["x"]) # revealed: list[int | None] | None
144144

145145
Y = "y"
146146
X = "x"

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ use crate::types::function::{
3030
OverloadLiteral,
3131
};
3232
use crate::types::generics::{
33-
InferableTypeVars, Specialization, SpecializationBuilder, SpecializationError,
33+
GenericContextTypeVar, InferableTypeVars, Specialization, SpecializationBuilder,
34+
SpecializationError,
3435
};
3536
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
3637
use crate::types::tuple::{TupleLength, TupleType};
@@ -2762,6 +2763,51 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27622763
.or(self.signature.return_ty)
27632764
.zip(self.call_expression_tcx.annotation);
27642765

2766+
let tcx_specialization = self
2767+
.call_expression_tcx
2768+
.annotation
2769+
.and_then(|annotation| annotation.class_specialization(self.db));
2770+
2771+
let promote_literals = |typevar: GenericContextTypeVar<'db>, ty: Type<'db>| -> Type<'db> {
2772+
let bound_typevar = typevar.bound_typevar();
2773+
2774+
if typevar.is_inherited() && bound_typevar.variance(self.db).is_invariant() {
2775+
return ty.promote_literals(
2776+
self.db,
2777+
TypeContext::new(
2778+
tcx_specialization
2779+
.and_then(|specialization| specialization.get(self.db, bound_typevar)),
2780+
),
2781+
);
2782+
}
2783+
2784+
let Some(return_specialization) = self
2785+
.signature
2786+
.return_ty
2787+
.and_then(|return_ty| return_ty.class_specialization(self.db))
2788+
else {
2789+
return ty;
2790+
};
2791+
2792+
if let Some((typevar, _)) = return_specialization
2793+
.generic_context(self.db)
2794+
.variables(self.db)
2795+
.zip(return_specialization.types(self.db))
2796+
.find(|(_, ty)| **ty == Type::TypeVar(bound_typevar))
2797+
.filter(|(typevar, _)| typevar.variance(self.db).is_invariant())
2798+
{
2799+
return ty.promote_literals(
2800+
self.db,
2801+
TypeContext::new(
2802+
tcx_specialization
2803+
.and_then(|specialization| specialization.get(self.db, typevar)),
2804+
),
2805+
);
2806+
}
2807+
2808+
ty
2809+
};
2810+
27652811
self.inferable_typevars = generic_context.inferable_typevars(self.db);
27662812
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
27672813

@@ -2811,7 +2857,9 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
28112857
}
28122858

28132859
// Build the specialization first without inferring the complete type context.
2814-
let isolated_specialization = builder.build(generic_context, self.call_expression_tcx);
2860+
let isolated_specialization = builder
2861+
.mapped(generic_context, promote_literals)
2862+
.build(generic_context);
28152863
let isolated_return_ty = self
28162864
.return_ty
28172865
.apply_specialization(self.db, isolated_specialization);
@@ -2836,7 +2884,9 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
28362884
builder.infer(return_ty, call_expression_tcx).ok()?;
28372885

28382886
// Otherwise, build the specialization again after inferring the complete type context.
2839-
let specialization = builder.build(generic_context, self.call_expression_tcx);
2887+
let specialization = builder
2888+
.mapped(generic_context, promote_literals)
2889+
.build(generic_context);
28402890
let return_ty = return_ty.apply_specialization(self.db, specialization);
28412891

28422892
Some((Some(specialization), return_ty))

crates/ty_python_semantic/src/types/class.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1475,7 +1475,7 @@ impl<'db> ClassLiteral<'db> {
14751475
/// promote any typevars that are inferred as a literal to the corresponding instance type.
14761476
fn inherited_generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
14771477
self.generic_context(db)
1478-
.map(|generic_context| generic_context.promote_literals(db))
1478+
.map(|generic_context| generic_context.set_inherited(db))
14791479
}
14801480

14811481
pub(super) fn file(self, db: &dyn Db) -> File {

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -184,21 +184,29 @@ impl<'a, 'db> InferableTypeVars<'a, 'db> {
184184
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)]
185185
pub struct GenericContextTypeVar<'db> {
186186
bound_typevar: BoundTypeVarInstance<'db>,
187-
should_promote_literals: bool,
187+
is_inherited: bool,
188188
}
189189

190190
impl<'db> GenericContextTypeVar<'db> {
191191
fn new(bound_typevar: BoundTypeVarInstance<'db>) -> Self {
192192
Self {
193193
bound_typevar,
194-
should_promote_literals: false,
194+
is_inherited: false,
195195
}
196196
}
197197

198-
fn promote_literals(mut self) -> Self {
199-
self.should_promote_literals = true;
198+
pub fn is_inherited(&self) -> bool {
199+
self.is_inherited
200+
}
201+
202+
fn set_inherited(mut self) -> Self {
203+
self.is_inherited = true;
200204
self
201205
}
206+
207+
pub fn bound_typevar(&self) -> BoundTypeVarInstance<'db> {
208+
self.bound_typevar
209+
}
202210
}
203211

204212
/// A list of formal type variables for a generic function, class, or type alias.
@@ -262,14 +270,13 @@ impl<'db> GenericContext<'db> {
262270
Self::from_variables(db, type_params.into_iter().map(GenericContextTypeVar::new))
263271
}
264272

265-
/// Returns a copy of this generic context where we will promote literal types in any inferred
266-
/// specializations.
267-
pub(crate) fn promote_literals(self, db: &'db dyn Db) -> Self {
273+
/// Mark the variables in this generic context as inherited from an outer class definition.
274+
pub(crate) fn set_inherited(self, db: &'db dyn Db) -> Self {
268275
Self::from_variables(
269276
db,
270277
self.variables_inner(db)
271278
.values()
272-
.map(|variable| variable.promote_literals()),
279+
.map(|variable| variable.set_inherited()),
273280
)
274281
}
275282

@@ -1321,31 +1328,30 @@ impl<'db> SpecializationBuilder<'db> {
13211328
&self.types
13221329
}
13231330

1324-
pub(crate) fn build(
1325-
&mut self,
1331+
pub(crate) fn mapped(
1332+
&self,
13261333
generic_context: GenericContext<'db>,
1327-
tcx: TypeContext<'db>,
1328-
) -> Specialization<'db> {
1329-
let tcx_specialization = tcx
1330-
.annotation
1331-
.and_then(|annotation| annotation.class_specialization(self.db));
1332-
1333-
let types =
1334-
(generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| {
1335-
let mut ty = self.types.get(identity).copied();
1336-
1337-
// When inferring a specialization for a generic class typevar from a constructor call,
1338-
// promote any typevars that are inferred as a literal to the corresponding instance type.
1339-
if variable.should_promote_literals {
1340-
let tcx = tcx_specialization.and_then(|specialization| {
1341-
specialization.get(self.db, variable.bound_typevar)
1342-
});
1343-
1344-
ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
1345-
}
1334+
f: impl Fn(GenericContextTypeVar<'db>, Type<'db>) -> Type<'db>,
1335+
) -> Self {
1336+
let mut types = self.types.clone();
1337+
for (identity, variable) in generic_context.variables_inner(self.db) {
1338+
if let Some(ty) = types.get_mut(identity) {
1339+
*ty = f(*variable, *ty);
1340+
}
1341+
}
1342+
1343+
Self {
1344+
db: self.db,
1345+
inferable: self.inferable,
1346+
types,
1347+
}
1348+
}
13461349

1347-
ty
1348-
});
1350+
pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> {
1351+
let types = generic_context
1352+
.variables_inner(self.db)
1353+
.iter()
1354+
.map(|(identity, _)| self.types.get(identity).copied());
13491355

13501356
// TODO Infer the tuple spec for a tuple type
13511357
generic_context.specialize_partial(self.db, types)

0 commit comments

Comments
 (0)