Skip to content

Commit 2a2b719

Browse files
authored
[ty] Improve generic class constructor inference (#21442)
## Summary We currently fail to account for the type context when inferring generic classes constructed with `__new__`, or synthesized `__init__` for dataclasses.
1 parent ffb7bdd commit 2a2b719

File tree

4 files changed

+93
-46
lines changed

4 files changed

+93
-46
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -324,25 +324,25 @@ class X[T]:
324324
def __init__(self, value: T):
325325
self.value = value
326326

327-
a: X[int] = X(1)
328-
reveal_type(a) # revealed: X[int]
327+
x1: X[int] = X(1)
328+
reveal_type(x1) # revealed: X[int]
329329

330-
b: X[int | None] = X(1)
331-
reveal_type(b) # revealed: X[int | None]
330+
x2: X[int | None] = X(1)
331+
reveal_type(x2) # revealed: X[int | None]
332332

333-
c: X[int | None] | None = X(1)
334-
reveal_type(c) # revealed: X[int | None]
333+
x3: X[int | None] | None = X(1)
334+
reveal_type(x3) # revealed: X[int | None]
335335

336-
def _[T](a: X[T]):
337-
b: X[T | int] = X(a.value)
338-
reveal_type(b) # revealed: X[T@_ | int]
336+
def _[T](x1: X[T]):
337+
x2: X[T | int] = X(x1.value)
338+
reveal_type(x2) # revealed: X[T@_ | int]
339339

340-
d: X[Any] = X(1)
341-
reveal_type(d) # revealed: X[Any]
340+
x4: X[Any] = X(1)
341+
reveal_type(x4) # revealed: X[Any]
342342

343343
def _(flag: bool):
344-
a: X[int | None] = X(1) if flag else X(2)
345-
reveal_type(a) # revealed: X[int | None]
344+
x5: X[int | None] = X(1) if flag else X(2)
345+
reveal_type(x5) # revealed: X[int | None]
346346
```
347347

348348
```py
@@ -353,8 +353,7 @@ class Y[T]:
353353
value: T
354354

355355
y1: Y[Any] = Y(value=1)
356-
# TODO: This should reveal `Y[Any]`.
357-
reveal_type(y1) # revealed: Y[int]
356+
reveal_type(y1) # revealed: Y[Any]
358357
```
359358

360359
```py
@@ -363,8 +362,7 @@ class Z[T]:
363362
return super().__new__(cls)
364363

365364
z1: Z[Any] = Z(1)
366-
# TODO: This should reveal `Z[Any]`.
367-
reveal_type(z1) # revealed: Z[int]
365+
reveal_type(z1) # revealed: Z[Any]
368366
```
369367

370368
## PEP-604 annotations are supported

crates/ty_python_semantic/src/types.rs

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,22 +1132,6 @@ impl<'db> Type<'db> {
11321132
}
11331133
}
11341134

1135-
/// If the type is a generic class constructor, returns the class instance type.
1136-
pub(crate) fn synthesized_constructor_return_ty(self, db: &'db dyn Db) -> Option<Type<'db>> {
1137-
// TODO: This does not correctly handle unions or intersections. It also does not handle
1138-
// constructors that are not represented as bound methods, e.g. `__new__`, or synthesized
1139-
// dataclass initializers.
1140-
if let Type::BoundMethod(method) = self
1141-
&& let Type::NominalInstance(instance) = method.self_instance(db)
1142-
&& method.function(db).name(db).as_str() == "__init__"
1143-
{
1144-
let class_ty = instance.class_literal(db).identity_specialization(db);
1145-
Some(Type::instance(db, class_ty))
1146-
} else {
1147-
None
1148-
}
1149-
}
1150-
11511135
pub const fn is_property_instance(&self) -> bool {
11521136
matches!(self, Type::PropertyInstance(..))
11531137
}
@@ -6340,8 +6324,12 @@ impl<'db> Type<'db> {
63406324
let new_call_outcome = new_method.and_then(|new_method| {
63416325
match new_method.place.try_call_dunder_get(db, self_type) {
63426326
Place::Defined(new_method, _, boundness) => {
6343-
let result =
6344-
new_method.try_call(db, argument_types.with_self(Some(self_type)).as_ref());
6327+
let argument_types = argument_types.with_self(Some(self_type));
6328+
let result = new_method
6329+
.bindings(db)
6330+
.with_constructor_instance_type(init_ty)
6331+
.match_parameters(db, &argument_types)
6332+
.check_types(db, &argument_types, tcx, &[]);
63456333

63466334
if boundness == Definedness::PossiblyUndefined {
63476335
Some(Err(DunderNewCallError::PossiblyUnbound(result.err())))
@@ -6354,7 +6342,35 @@ impl<'db> Type<'db> {
63546342
});
63556343

63566344
let init_call_outcome = if new_call_outcome.is_none() || !init_method.is_undefined() {
6357-
Some(init_ty.try_call_dunder(db, "__init__", argument_types, tcx))
6345+
let call_result = match init_ty
6346+
.member_lookup_with_policy(
6347+
db,
6348+
"__init__".into(),
6349+
MemberLookupPolicy::NO_INSTANCE_FALLBACK,
6350+
)
6351+
.place
6352+
{
6353+
Place::Undefined => Err(CallDunderError::MethodNotAvailable),
6354+
Place::Defined(dunder_callable, _, boundness) => {
6355+
let bindings = dunder_callable
6356+
.bindings(db)
6357+
.with_constructor_instance_type(init_ty);
6358+
6359+
bindings
6360+
.match_parameters(db, &argument_types)
6361+
.check_types(db, &argument_types, tcx, &[])
6362+
.map_err(CallDunderError::from)
6363+
.and_then(|bindings| {
6364+
if boundness == Definedness::PossiblyUndefined {
6365+
Err(CallDunderError::PossiblyUnbound(Box::new(bindings)))
6366+
} else {
6367+
Ok(bindings)
6368+
}
6369+
})
6370+
}
6371+
};
6372+
6373+
Some(call_result)
63586374
} else {
63596375
None
63606376
};

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ pub(crate) struct Bindings<'db> {
5353
/// The type that is (hopefully) callable.
5454
callable_type: Type<'db>,
5555

56+
/// The type of the instance being constructed, if this signature is for a constructor.
57+
constructor_instance_type: Option<Type<'db>>,
58+
5659
/// By using `SmallVec`, we avoid an extra heap allocation for the common case of a non-union
5760
/// type.
5861
elements: SmallVec<[CallableBinding<'db>; 1]>,
@@ -77,6 +80,7 @@ impl<'db> Bindings<'db> {
7780
callable_type,
7881
elements,
7982
argument_forms: ArgumentForms::new(0),
83+
constructor_instance_type: None,
8084
}
8185
}
8286

@@ -89,6 +93,22 @@ impl<'db> Bindings<'db> {
8993
}
9094
}
9195

96+
pub(crate) fn with_constructor_instance_type(
97+
mut self,
98+
constructor_instance_type: Type<'db>,
99+
) -> Self {
100+
self.constructor_instance_type = Some(constructor_instance_type);
101+
102+
for binding in &mut self.elements {
103+
binding.constructor_instance_type = Some(constructor_instance_type);
104+
for binding in &mut binding.overloads {
105+
binding.constructor_instance_type = Some(constructor_instance_type);
106+
}
107+
}
108+
109+
self
110+
}
111+
92112
pub(crate) fn set_dunder_call_is_possibly_unbound(&mut self) {
93113
for binding in &mut self.elements {
94114
binding.dunder_call_is_possibly_unbound = true;
@@ -107,6 +127,7 @@ impl<'db> Bindings<'db> {
107127
Self {
108128
callable_type: self.callable_type,
109129
argument_forms: self.argument_forms,
130+
constructor_instance_type: self.constructor_instance_type,
110131
elements: self.elements.into_iter().map(f).collect(),
111132
}
112133
}
@@ -240,6 +261,10 @@ impl<'db> Bindings<'db> {
240261
self.callable_type
241262
}
242263

264+
pub(crate) fn constructor_instance_type(&self) -> Option<Type<'db>> {
265+
self.constructor_instance_type
266+
}
267+
243268
/// Returns the return type of the call. For successful calls, this is the actual return type.
244269
/// For calls with binding errors, this is a type that best approximates the return type. For
245270
/// types that are not callable, returns `Type::Unknown`.
@@ -1357,6 +1382,7 @@ impl<'db> From<CallableBinding<'db>> for Bindings<'db> {
13571382
callable_type: from.callable_type,
13581383
elements: smallvec_inline![from],
13591384
argument_forms: ArgumentForms::new(0),
1385+
constructor_instance_type: None,
13601386
}
13611387
}
13621388
}
@@ -1370,6 +1396,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
13701396
signature_type,
13711397
dunder_call_is_possibly_unbound: false,
13721398
bound_type: None,
1399+
constructor_instance_type: None,
13731400
overload_call_return_type: None,
13741401
matching_overload_before_type_checking: None,
13751402
overloads: smallvec_inline![from],
@@ -1378,6 +1405,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
13781405
callable_type,
13791406
elements: smallvec_inline![callable_binding],
13801407
argument_forms: ArgumentForms::new(0),
1408+
constructor_instance_type: None,
13811409
}
13821410
}
13831411
}
@@ -1409,6 +1437,9 @@ pub(crate) struct CallableBinding<'db> {
14091437
/// The type of the bound `self` or `cls` parameter if this signature is for a bound method.
14101438
pub(crate) bound_type: Option<Type<'db>>,
14111439

1440+
/// The type of the instance being constructed, if this signature is for a constructor.
1441+
pub(crate) constructor_instance_type: Option<Type<'db>>,
1442+
14121443
/// The return type of this overloaded callable.
14131444
///
14141445
/// This is [`Some`] only in the following cases:
@@ -1457,6 +1488,7 @@ impl<'db> CallableBinding<'db> {
14571488
signature_type,
14581489
dunder_call_is_possibly_unbound: false,
14591490
bound_type: None,
1491+
constructor_instance_type: None,
14601492
overload_call_return_type: None,
14611493
matching_overload_before_type_checking: None,
14621494
overloads,
@@ -1469,6 +1501,7 @@ impl<'db> CallableBinding<'db> {
14691501
signature_type,
14701502
dunder_call_is_possibly_unbound: false,
14711503
bound_type: None,
1504+
constructor_instance_type: None,
14721505
overload_call_return_type: None,
14731506
matching_overload_before_type_checking: None,
14741507
overloads: smallvec![],
@@ -2689,7 +2722,7 @@ struct ArgumentTypeChecker<'a, 'db> {
26892722
arguments: &'a CallArguments<'a, 'db>,
26902723
argument_matches: &'a [MatchedArgument<'db>],
26912724
parameter_tys: &'a mut [Option<Type<'db>>],
2692-
callable_type: Type<'db>,
2725+
constructor_instance_type: Option<Type<'db>>,
26932726
call_expression_tcx: TypeContext<'db>,
26942727
return_ty: Type<'db>,
26952728
errors: &'a mut Vec<BindingError<'db>>,
@@ -2706,7 +2739,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27062739
arguments: &'a CallArguments<'a, 'db>,
27072740
argument_matches: &'a [MatchedArgument<'db>],
27082741
parameter_tys: &'a mut [Option<Type<'db>>],
2709-
callable_type: Type<'db>,
2742+
constructor_instance_type: Option<Type<'db>>,
27102743
call_expression_tcx: TypeContext<'db>,
27112744
return_ty: Type<'db>,
27122745
errors: &'a mut Vec<BindingError<'db>>,
@@ -2717,7 +2750,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27172750
arguments,
27182751
argument_matches,
27192752
parameter_tys,
2720-
callable_type,
2753+
constructor_instance_type,
27212754
call_expression_tcx,
27222755
return_ty,
27232756
errors,
@@ -2759,8 +2792,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27592792
};
27602793

27612794
let return_with_tcx = self
2762-
.callable_type
2763-
.synthesized_constructor_return_ty(self.db)
2795+
.constructor_instance_type
27642796
.or(self.signature.return_ty)
27652797
.zip(self.call_expression_tcx.annotation);
27662798

@@ -3109,6 +3141,9 @@ pub(crate) struct Binding<'db> {
31093141
/// it may be a `__call__` method.
31103142
pub(crate) signature_type: Type<'db>,
31113143

3144+
/// The type of the instance being constructed, if this signature is for a constructor.
3145+
pub(crate) constructor_instance_type: Option<Type<'db>>,
3146+
31123147
/// Return type of the call.
31133148
return_ty: Type<'db>,
31143149

@@ -3140,6 +3175,7 @@ impl<'db> Binding<'db> {
31403175
signature,
31413176
callable_type: signature_type,
31423177
signature_type,
3178+
constructor_instance_type: None,
31433179
return_ty: Type::unknown(),
31443180
inferable_typevars: InferableTypeVars::None,
31453181
specialization: None,
@@ -3204,7 +3240,7 @@ impl<'db> Binding<'db> {
32043240
arguments,
32053241
&self.argument_matches,
32063242
&mut self.parameter_tys,
3207-
self.callable_type,
3243+
self.constructor_instance_type,
32083244
call_expression_tcx,
32093245
self.return_ty,
32103246
&mut self.errors,

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6527,10 +6527,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
65276527
// TODO: Checking assignability against the full declared type could help avoid
65286528
// cases where the constraint solver is not smart enough to solve complex unions.
65296529
// We should see revisit this after the new constraint solver is implemented.
6530-
if speculated_bindings
6531-
.callable_type()
6532-
.synthesized_constructor_return_ty(db)
6533-
.is_none()
6530+
if speculated_bindings.constructor_instance_type().is_none()
65346531
&& !speculated_bindings
65356532
.return_type(db)
65366533
.is_assignable_to(db, narrowed_ty)

0 commit comments

Comments
 (0)