Skip to content

Commit 02bb567

Browse files
committed
use type context for inference of generic constructors
1 parent a0144fd commit 02bb567

File tree

5 files changed

+71
-7
lines changed

5 files changed

+71
-7
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,42 @@ reveal_type(s) # revealed: list[Literal[1]]
310310
reveal_type(s) # revealed: list[Literal[1]]
311311
```
312312

313+
## Generic constructor annotations are understood
314+
315+
```toml
316+
[environment]
317+
python-version = "3.12"
318+
```
319+
320+
```py
321+
from typing import Any
322+
323+
class X[T]:
324+
def __init__(self, value: T):
325+
self.value = value
326+
327+
a: X[int] = X(1)
328+
reveal_type(a) # revealed: X[int]
329+
330+
b: X[int | None] = X(1)
331+
reveal_type(b) # revealed: X[int | None]
332+
333+
c: X[int | None] | None = X(1)
334+
reveal_type(c) # revealed: X[int | None]
335+
336+
def _[T](d: X[T]):
337+
_: X[T | int] = X(d.value)
338+
339+
e: X[Any] = X(1)
340+
reveal_type(e) # revealed: X[Any]
341+
342+
def _(flag: bool):
343+
# TODO: Handle unions correctly.
344+
# error: [invalid-assignment] "Object of type `X[int]` is not assignable to `X[int | None]`"
345+
a: X[int | None] = X(1) if flag else X(2)
346+
reveal_type(a) # revealed: X[int | None]
347+
```
348+
313349
## PEP-604 annotations are supported
314350

315351
```py

crates/ty_python_semantic/resources/mdtest/narrow/assignment.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ dd: defaultdict[int, int] = defaultdict(int)
206206
dd[0] = 0
207207
cm: ChainMap[int, int] = ChainMap({1: 1}, {0: 0})
208208
cm[0] = 0
209-
reveal_type(cm) # revealed: ChainMap[Unknown | int, Unknown | int]
209+
reveal_type(cm) # revealed: ChainMap[int | Unknown, int | Unknown]
210210

211211
reveal_type(l[0]) # revealed: Literal[0]
212212
reveal_type(d[0]) # revealed: Literal[0]

crates/ty_python_semantic/src/types.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,23 @@ impl<'db> Type<'db> {
11101110
}
11111111
}
11121112

1113+
/// If the type is a generic class constructor, returns the class instance type.
1114+
pub(crate) fn synthesized_constructor_return_ty(self, db: &'db dyn Db) -> Option<Type<'db>> {
1115+
// TODO: This does not correctly handle unions or intersections.
1116+
if let Type::BoundMethod(method) = self
1117+
&& let Type::NominalInstance(instance) = method.self_instance(db)
1118+
&& matches!(
1119+
method.function(db).name(db).as_str(),
1120+
"__init__" | "__new__"
1121+
)
1122+
{
1123+
let class_ty = instance.class_literal(db).identity_specialization(db);
1124+
Some(Type::instance(db, class_ty))
1125+
} else {
1126+
None
1127+
}
1128+
}
1129+
11131130
pub const fn is_property_instance(&self) -> bool {
11141131
matches!(self, Type::PropertyInstance(..))
11151132
}

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2667,6 +2667,7 @@ struct ArgumentTypeChecker<'a, 'db> {
26672667
arguments: &'a CallArguments<'a, 'db>,
26682668
argument_matches: &'a [MatchedArgument<'db>],
26692669
parameter_tys: &'a mut [Option<Type<'db>>],
2670+
callable_type: Type<'db>,
26702671
call_expression_tcx: TypeContext<'db>,
26712672
return_ty: Type<'db>,
26722673
errors: &'a mut Vec<BindingError<'db>>,
@@ -2683,6 +2684,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26832684
arguments: &'a CallArguments<'a, 'db>,
26842685
argument_matches: &'a [MatchedArgument<'db>],
26852686
parameter_tys: &'a mut [Option<Type<'db>>],
2687+
callable_type: Type<'db>,
26862688
call_expression_tcx: TypeContext<'db>,
26872689
return_ty: Type<'db>,
26882690
errors: &'a mut Vec<BindingError<'db>>,
@@ -2693,6 +2695,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26932695
arguments,
26942696
argument_matches,
26952697
parameter_tys,
2698+
callable_type,
26962699
call_expression_tcx,
26972700
return_ty,
26982701
errors,
@@ -2734,16 +2737,19 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27342737
};
27352738

27362739
let return_with_tcx = self
2737-
.signature
2738-
.return_ty
2740+
.callable_type
2741+
.synthesized_constructor_return_ty(self.db)
2742+
.or(self.signature.return_ty)
27392743
.zip(self.call_expression_tcx.annotation);
27402744

27412745
self.inferable_typevars = generic_context.inferable_typevars(self.db);
27422746
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
27432747

27442748
// Prefer the declared type of generic classes.
27452749
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
2746-
tcx.class_specialization(self.db)?;
2750+
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some())
2751+
.class_specialization(self.db)?;
2752+
27472753
builder.infer(return_ty, tcx).ok()?;
27482754
Some(builder.type_mappings().clone())
27492755
});
@@ -3176,6 +3182,7 @@ impl<'db> Binding<'db> {
31763182
arguments,
31773183
&self.argument_matches,
31783184
&mut self.parameter_tys,
3185+
self.callable_type,
31793186
call_expression_tcx,
31803187
self.return_ty,
31813188
&mut self.errors,

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5749,9 +5749,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
57495749
// TODO: Checking assignability against the full declared type could help avoid
57505750
// cases where the constraint solver is not smart enough to solve complex unions.
57515751
// We should see revisit this after the new constraint solver is implemented.
5752-
if !speculated_bindings
5753-
.return_type(db)
5754-
.is_assignable_to(db, narrowed_ty)
5752+
if speculated_bindings
5753+
.callable_type()
5754+
.synthesized_constructor_return_ty(db)
5755+
.is_none()
5756+
&& !speculated_bindings
5757+
.return_type(db)
5758+
.is_assignable_to(db, narrowed_ty)
57555759
{
57565760
return None;
57575761
}

0 commit comments

Comments
 (0)