Skip to content

Commit 37ac0b5

Browse files
committed
use type context for inference of generic constructors
1 parent 966b438 commit 37ac0b5

File tree

5 files changed

+69
-7
lines changed

5 files changed

+69
-7
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,43 @@ reveal_type(s) # revealed: list[Literal[1]]
310310
reveal_type(s) # revealed: list[Literal[1]]
311311
```
312312

313+
## Generic constructor annotations are understood
314+
315+
```toml
316+
[environment]
317+
python-version = "3.12"
318+
```
319+
320+
```py
321+
from typing import Any
322+
323+
class X[T]:
324+
def __init__(self, value: T):
325+
self.value = value
326+
327+
a: X[int] = X(1)
328+
reveal_type(a) # revealed: X[int]
329+
330+
b: X[int | None] = X(1)
331+
reveal_type(b) # revealed: X[int | None]
332+
333+
c: X[int | None] | None = X(1)
334+
reveal_type(c) # revealed: X[int | None]
335+
336+
def _[T](a: X[T]):
337+
b: X[T | int] = X(a.value)
338+
reveal_type(b) # revealed: X[T@_ | int]
339+
340+
d: X[Any] = X(1)
341+
reveal_type(d) # revealed: X[Any]
342+
343+
def _(flag: bool):
344+
# TODO: Handle unions correctly.
345+
# error: [invalid-assignment] "Object of type `X[int]` is not assignable to `X[int | None]`"
346+
a: X[int | None] = X(1) if flag else X(2)
347+
reveal_type(a) # revealed: X[int | None]
348+
```
349+
313350
## PEP-604 annotations are supported
314351

315352
```py

crates/ty_python_semantic/resources/mdtest/narrow/assignment.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ dd: defaultdict[int, int] = defaultdict(int)
206206
dd[0] = 0
207207
cm: ChainMap[int, int] = ChainMap({1: 1}, {0: 0})
208208
cm[0] = 0
209-
reveal_type(cm) # revealed: ChainMap[Unknown | int, Unknown | int]
209+
reveal_type(cm) # revealed: ChainMap[int | Unknown, int | Unknown]
210210

211211
reveal_type(l[0]) # revealed: Literal[0]
212212
reveal_type(d[0]) # revealed: Literal[0]

crates/ty_python_semantic/src/types.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,20 @@ impl<'db> Type<'db> {
11101110
}
11111111
}
11121112

1113+
/// If the type is a generic class constructor, returns the class instance type.
1114+
pub(crate) fn synthesized_constructor_return_ty(self, db: &'db dyn Db) -> Option<Type<'db>> {
1115+
// TODO: This does not correctly handle unions or intersections.
1116+
if let Type::BoundMethod(method) = self
1117+
&& let Type::NominalInstance(instance) = method.self_instance(db)
1118+
&& method.function(db).name(db).as_str() == "__init__"
1119+
{
1120+
let class_ty = instance.class_literal(db).identity_specialization(db);
1121+
Some(Type::instance(db, class_ty))
1122+
} else {
1123+
None
1124+
}
1125+
}
1126+
11131127
pub const fn is_property_instance(&self) -> bool {
11141128
matches!(self, Type::PropertyInstance(..))
11151129
}

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2687,6 +2687,7 @@ struct ArgumentTypeChecker<'a, 'db> {
26872687
arguments: &'a CallArguments<'a, 'db>,
26882688
argument_matches: &'a [MatchedArgument<'db>],
26892689
parameter_tys: &'a mut [Option<Type<'db>>],
2690+
callable_type: Type<'db>,
26902691
call_expression_tcx: TypeContext<'db>,
26912692
return_ty: Type<'db>,
26922693
errors: &'a mut Vec<BindingError<'db>>,
@@ -2703,6 +2704,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27032704
arguments: &'a CallArguments<'a, 'db>,
27042705
argument_matches: &'a [MatchedArgument<'db>],
27052706
parameter_tys: &'a mut [Option<Type<'db>>],
2707+
callable_type: Type<'db>,
27062708
call_expression_tcx: TypeContext<'db>,
27072709
return_ty: Type<'db>,
27082710
errors: &'a mut Vec<BindingError<'db>>,
@@ -2713,6 +2715,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27132715
arguments,
27142716
argument_matches,
27152717
parameter_tys,
2718+
callable_type,
27162719
call_expression_tcx,
27172720
return_ty,
27182721
errors,
@@ -2754,16 +2757,19 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27542757
};
27552758

27562759
let return_with_tcx = self
2757-
.signature
2758-
.return_ty
2760+
.callable_type
2761+
.synthesized_constructor_return_ty(self.db)
2762+
.or(self.signature.return_ty)
27592763
.zip(self.call_expression_tcx.annotation);
27602764

27612765
self.inferable_typevars = generic_context.inferable_typevars(self.db);
27622766
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
27632767

27642768
// Prefer the declared type of generic classes.
27652769
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
2766-
tcx.class_specialization(self.db)?;
2770+
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some())
2771+
.class_specialization(self.db)?;
2772+
27672773
builder.infer(return_ty, tcx).ok()?;
27682774
Some(builder.type_mappings().clone())
27692775
});
@@ -3196,6 +3202,7 @@ impl<'db> Binding<'db> {
31963202
arguments,
31973203
&self.argument_matches,
31983204
&mut self.parameter_tys,
3205+
self.callable_type,
31993206
call_expression_tcx,
32003207
self.return_ty,
32013208
&mut self.errors,

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6025,9 +6025,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
60256025
// TODO: Checking assignability against the full declared type could help avoid
60266026
// cases where the constraint solver is not smart enough to solve complex unions.
60276027
// We should see revisit this after the new constraint solver is implemented.
6028-
if !speculated_bindings
6029-
.return_type(db)
6030-
.is_assignable_to(db, narrowed_ty)
6028+
if speculated_bindings
6029+
.callable_type()
6030+
.synthesized_constructor_return_ty(db)
6031+
.is_none()
6032+
&& !speculated_bindings
6033+
.return_type(db)
6034+
.is_assignable_to(db, narrowed_ty)
60316035
{
60326036
return None;
60336037
}

0 commit comments

Comments
 (0)