Skip to content

Commit 0a033e2

Browse files
committed
infer more precise types for annotated collection literals
1 parent 1cd8ab3 commit 0a033e2

File tree

3 files changed

+102
-7
lines changed

3 files changed

+102
-7
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,52 @@ b: tuple[int] = ("foo",)
7979
c: tuple[str | int, str] = ([], "foo")
8080
```
8181

82+
## Collection literal annotations are understood
83+
84+
`module.py`:
85+
86+
```py
87+
import typing
88+
89+
a: list[int] = [1, 2, 3]
90+
b: list[int | str] = [1, 2, 3]
91+
c: typing.List[int] = [1, 2, 3]
92+
d: list[typing.Any] = []
93+
94+
e: set[int] = {1, 2, 3}
95+
f: set[int | str] = {1, 2, 3}
96+
g: typing.Set[int] = {1, 2, 3}
97+
```
98+
99+
`script.py`:
100+
101+
```py
102+
import typing
103+
from module import a, b, c, d, e, f, g
104+
105+
reveal_type(a) # revealed: list[int]
106+
reveal_type(b) # revealed: list[int | str]
107+
reveal_type(c) # revealed: list[int]
108+
reveal_type(d) # revealed: list[Any]
109+
110+
reveal_type(e) # revealed: set[int]
111+
reveal_type(f) # revealed: set[int | str]
112+
reveal_type(g) # revealed: set[int]
113+
```
114+
115+
## Incorrect collection literal assignments are complained aobut
116+
117+
```py
118+
# error: [invalid-assignment] "Object of type `list[str | Literal[1, 2, 3]]` is not assignable to `list[str]`"
119+
a: list[str] = [1, 2,3]
120+
121+
# error: [invalid-assignment] "Object of type `set[int | Literal["3"]]` is not assignable to `set[int]`"
122+
b: set[int] = {1, 2, "3"}
123+
124+
# error: [invalid-assignment] "Object of type `list[@Todo(list literal element type)]` is not assignable to `set[int]`"
125+
h: set[int] = [1, 2, 3]
126+
```
127+
82128
## PEP-604 annotations are supported
83129

84130
```py

crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,7 @@ def f(l: list[str | None]):
328328
def _():
329329
l: list[str | None] = [None]
330330
def _():
331-
# TODO: should be `str | None`
332-
reveal_type(l[0]) # revealed: @Todo(list literal element type)
331+
reveal_type(l[0]) # revealed: str | None
333332

334333
def _():
335334
def _():

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ use crate::types::diagnostic::{
7272
use crate::types::function::{
7373
FunctionDecorators, FunctionLiteral, FunctionType, KnownFunction, OverloadLiteral,
7474
};
75-
use crate::types::generics::LegacyGenericBase;
7675
use crate::types::generics::{GenericContext, bind_typevar};
76+
use crate::types::generics::{LegacyGenericBase, SpecializationBuilder};
7777
use crate::types::instance::SliceLiteral;
7878
use crate::types::mro::MroErrorKind;
7979
use crate::types::signatures::Signature;
@@ -5253,15 +5253,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
52535253
Type::heterogeneous_tuple(db, element_types)
52545254
}
52555255

5256-
fn infer_list_expression(&mut self, list: &ast::ExprList, _tcx: TypeContext<'db>) -> Type<'db> {
5256+
fn infer_list_expression(&mut self, list: &ast::ExprList, tcx: TypeContext<'db>) -> Type<'db> {
52575257
let ast::ExprList {
52585258
range: _,
52595259
node_index: _,
52605260
elts,
52615261
ctx: _,
52625262
} = list;
52635263

5264-
// TODO: Use the type context for more precise inference.
5264+
if let Some(inferred) = self.infer_annotated_collection_literal(elts, tcx, KnownClass::List)
5265+
{
5266+
return inferred;
5267+
}
5268+
52655269
for elt in elts {
52665270
self.infer_expression(elt, TypeContext::default());
52675271
}
@@ -5270,21 +5274,66 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
52705274
.to_specialized_instance(self.db(), [todo_type!("list literal element type")])
52715275
}
52725276

5273-
fn infer_set_expression(&mut self, set: &ast::ExprSet, _tcx: TypeContext<'db>) -> Type<'db> {
5277+
fn infer_set_expression(&mut self, set: &ast::ExprSet, tcx: TypeContext<'db>) -> Type<'db> {
52745278
let ast::ExprSet {
52755279
range: _,
52765280
node_index: _,
52775281
elts,
52785282
} = set;
52795283

5280-
// TODO: Use the type context for more precise inference.
5284+
if let Some(inferred) = self.infer_annotated_collection_literal(elts, tcx, KnownClass::Set)
5285+
{
5286+
return inferred;
5287+
}
5288+
52815289
for elt in elts {
52825290
self.infer_expression(elt, TypeContext::default());
52835291
}
52845292

52855293
KnownClass::Set.to_specialized_instance(self.db(), [todo_type!("set literal element type")])
52865294
}
52875295

5296+
// Infer the type of a collection literal in an annotated assignment, e.g. `_: list[_] = [..]`.
5297+
fn infer_annotated_collection_literal(
5298+
&mut self,
5299+
elts: &[ast::Expr],
5300+
tcx: TypeContext<'db>,
5301+
collection_class: KnownClass,
5302+
) -> Option<Type<'db>> {
5303+
// Extract the annotated type of `list[T]` in the type annotation.
5304+
let class_type = tcx.annotation?.into_nominal_instance()?.class(self.db());
5305+
if !class_type.is_known(self.db(), collection_class) {
5306+
return None;
5307+
}
5308+
let specialization = class_type.into_generic_alias()?.specialization(self.db());
5309+
let [annotated_elts_ty] = specialization.types(self.db()) else {
5310+
return None;
5311+
};
5312+
5313+
// Extract the type variable `T` from the definition of `list[T]`.
5314+
let class_literal = collection_class.try_to_class_literal(self.db())?;
5315+
let generic_context = class_literal.generic_context(self.db())?;
5316+
let variables = generic_context.variables(self.db());
5317+
let elts_ty = Type::TypeVar(*variables.iter().exactly_one().ok()?);
5318+
5319+
// Infer a precise type for `T`, based on,
5320+
let mut builder = SpecializationBuilder::new(self.db());
5321+
5322+
// the annotated type,
5323+
builder.infer(elts_ty, *annotated_elts_ty).ok()?;
5324+
5325+
// as well as the type of each element in the collection literal.
5326+
for elt in elts {
5327+
let inferred_elt_ty = self.infer_expression(elt, TypeContext::default());
5328+
builder.infer(elts_ty, inferred_elt_ty).ok()?;
5329+
}
5330+
5331+
let class_type = class_literal
5332+
.apply_specialization(self.db(), |generic_context| builder.build(generic_context));
5333+
5334+
Type::from(class_type).to_instance(self.db())
5335+
}
5336+
52885337
fn infer_dict_expression(&mut self, dict: &ast::ExprDict, _tcx: TypeContext<'db>) -> Type<'db> {
52895338
let ast::ExprDict {
52905339
range: _,
@@ -5306,6 +5355,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
53065355
],
53075356
)
53085357
}
5358+
53095359
/// Infer the type of the `iter` expression of the first comprehension.
53105360
fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) {
53115361
let mut comprehensions_iter = comprehensions.iter();

0 commit comments

Comments
 (0)