Skip to content

Commit 22bbbe3

Browse files
committed
use type context for inference of generic constructors
1 parent 96b1563 commit 22bbbe3

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,37 @@ reveal_type(s) # revealed: list[Literal[1]]
310310
reveal_type(s) # revealed: list[Literal[1]]
311311
```
312312

313+
## Generic constructor annotations are understood
314+
315+
```toml
316+
[environment]
317+
python-version = "3.12"
318+
```
319+
320+
```py
321+
from typing import Any
322+
323+
class X[T]:
324+
def __init__(self, value: T):
325+
self.value = value
326+
327+
a: X[int] = X(1)
328+
reveal_type(a) # revealed: X[int]
329+
330+
b: X[int | None] = X(1)
331+
reveal_type(b) # revealed: X[int | None]
332+
333+
c: X[int | None] | None = X(1)
334+
reveal_type(c) # revealed: X[int | None]
335+
336+
def _[T](d: X[T]):
337+
_: X[T | int] = X(d.value)
338+
339+
e: X[Any] = X(1)
340+
# TODO: Prefer the declared type here.
341+
reveal_type(e) # revealed: X[int | Any]
342+
```
343+
313344
## PEP-604 annotations are supported
314345

315346
```py

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,6 +2523,7 @@ struct ArgumentTypeChecker<'a, 'db> {
25232523
arguments: &'a CallArguments<'a, 'db>,
25242524
argument_matches: &'a [MatchedArgument<'db>],
25252525
parameter_tys: &'a mut [Option<Type<'db>>],
2526+
callable_type: Type<'db>,
25262527
call_expression_tcx: &'a TypeContext<'db>,
25272528
return_ty: Type<'db>,
25282529
errors: &'a mut Vec<BindingError<'db>>,
@@ -2539,6 +2540,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25392540
arguments: &'a CallArguments<'a, 'db>,
25402541
argument_matches: &'a [MatchedArgument<'db>],
25412542
parameter_tys: &'a mut [Option<Type<'db>>],
2543+
callable_type: Type<'db>,
25422544
call_expression_tcx: &'a TypeContext<'db>,
25432545
return_ty: Type<'db>,
25442546
errors: &'a mut Vec<BindingError<'db>>,
@@ -2549,6 +2551,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25492551
arguments,
25502552
argument_matches,
25512553
parameter_tys,
2554+
callable_type,
25522555
call_expression_tcx,
25532556
return_ty,
25542557
errors,
@@ -2623,7 +2626,22 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26232626
.apply_specialization(self.db, isolated_specialization);
26242627

26252628
let mut try_infer_tcx = || {
2626-
let return_ty = self.signature.return_ty?;
2629+
// For generic constructors, we use the type context to infer the specialization of the class
2630+
// instance instead of the method's return type.
2631+
let (inference_context, generic_constructor) = if let Type::BoundMethod(method) =
2632+
self.callable_type
2633+
&& let Type::NominalInstance(instance) = method.self_instance(self.db)
2634+
&& method.function(self.db).name(self.db) == "__init__"
2635+
{
2636+
let class_ty = instance
2637+
.class_literal(self.db)
2638+
.identity_specialization(self.db);
2639+
2640+
(Type::instance(self.db, class_ty), true)
2641+
} else {
2642+
(self.signature.return_ty?, false)
2643+
};
2644+
26272645
let call_expression_tcx = self.call_expression_tcx.annotation?;
26282646

26292647
// A type variable is not a useful type-context for expression inference, and applying it
@@ -2634,17 +2652,19 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26342652

26352653
// If the return type is already assignable to the annotated type, we can ignore the
26362654
// type context and prefer the narrower inferred type.
2637-
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
2655+
if !generic_constructor
2656+
&& isolated_return_ty.is_assignable_to(self.db, call_expression_tcx)
2657+
{
26382658
return None;
26392659
}
26402660

26412661
// TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
26422662
// annotated assignment, to closer match the order of any unions written in the type annotation.
2643-
builder.infer(return_ty, call_expression_tcx).ok()?;
2663+
builder.infer(inference_context, call_expression_tcx).ok()?;
26442664

26452665
// Otherwise, build the specialization again after inferring the type context.
26462666
let specialization = builder.build(generic_context, *self.call_expression_tcx);
2647-
let return_ty = return_ty.apply_specialization(self.db, specialization);
2667+
let return_ty = self.return_ty.apply_specialization(self.db, specialization);
26482668

26492669
Some((Some(specialization), return_ty))
26502670
};
@@ -3009,6 +3029,7 @@ impl<'db> Binding<'db> {
30093029
arguments,
30103030
&self.argument_matches,
30113031
&mut self.parameter_tys,
3032+
self.callable_type,
30123033
call_expression_tcx,
30133034
self.return_ty,
30143035
&mut self.errors,

0 commit comments

Comments
 (0)