Skip to content

Commit b83ac5e

Browse files
[ty] Clean up inherited generic contexts (#20647)
We add an `inherited_generic_context` to the constructors of a generic class. That lets us infer specializations of the class when invoking the constructor. The constructor might itself be generic, in which case we have to merge the list of typevars that we are willing to infer in the constructor call. Before we did that by tracking the two (and their specializations) separately, with distinct `Option` fields/parameters. This PR updates our call binding logic such that any given function call has _one_ optional generic context that we're willing to infer a specialization for. If needed, we use the existing `GenericContext::merge` method to create a new combined generic context for when the class and constructor are both generic. This simplifies the call binding code considerably, and is no more complex in the constructor call logic. We also have a heuristic that we will promote any literals in the specialized types of a generic class, but we don't promote literals in the specialized types of the function itself. To handle this, we now track this `should_promote_literals` property within `GenericContext`. And moreover, we track this separately for each typevar, instead of a single property for the generic context as a whole, so that we can correctly merge the generic context of a constructor method (where the option should be `false`) with the inherited generic context of its containing class (where the option should be `true`). --------- Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
1 parent c91b457 commit b83ac5e

File tree

9 files changed

+231
-290
lines changed

9 files changed

+231
-290
lines changed

crates/ruff_memory_usage/src/lib.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::sync::{LazyLock, Mutex};
22

33
use get_size2::{GetSize, StandardTracker};
4-
use ordermap::OrderSet;
4+
use ordermap::{OrderMap, OrderSet};
55

66
/// Returns the memory usage of the provided object, using a global tracker to avoid
77
/// double-counting shared objects.
@@ -18,3 +18,11 @@ pub fn heap_size<T: GetSize>(value: &T) -> usize {
1818
pub fn order_set_heap_size<T: GetSize, S>(set: &OrderSet<T, S>) -> usize {
1919
(set.capacity() * T::get_stack_size()) + set.iter().map(heap_size).sum::<usize>()
2020
}
21+
22+
/// An implementation of [`GetSize::get_heap_size`] for [`OrderMap`].
23+
pub fn order_map_heap_size<K: GetSize, V: GetSize, S>(map: &OrderMap<K, V, S>) -> usize {
24+
(map.capacity() * (K::get_stack_size() + V::get_stack_size()))
25+
+ (map.iter())
26+
.map(|(k, v)| heap_size(k) + heap_size(v))
27+
.sum::<usize>()
28+
}

crates/ty_python_semantic/src/types.rs

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5456,28 +5456,19 @@ impl<'db> Type<'db> {
54565456
}
54575457
}
54585458

5459-
let new_specialization = new_call_outcome
5460-
.and_then(Result::ok)
5461-
.as_ref()
5462-
.and_then(Bindings::single_element)
5463-
.into_iter()
5464-
.flat_map(CallableBinding::matching_overloads)
5465-
.next()
5466-
.and_then(|(_, binding)| binding.inherited_specialization())
5467-
.filter(|specialization| {
5468-
Some(specialization.generic_context(db)) == generic_context
5469-
});
5470-
let init_specialization = init_call_outcome
5471-
.and_then(Result::ok)
5472-
.as_ref()
5473-
.and_then(Bindings::single_element)
5474-
.into_iter()
5475-
.flat_map(CallableBinding::matching_overloads)
5476-
.next()
5477-
.and_then(|(_, binding)| binding.inherited_specialization())
5478-
.filter(|specialization| {
5479-
Some(specialization.generic_context(db)) == generic_context
5480-
});
5459+
let specialize_constructor = |outcome: Option<Bindings<'db>>| {
5460+
let (_, binding) = outcome
5461+
.as_ref()?
5462+
.single_element()?
5463+
.matching_overloads()
5464+
.next()?;
5465+
binding.specialization()?.restrict(db, generic_context?)
5466+
};
5467+
5468+
let new_specialization =
5469+
specialize_constructor(new_call_outcome.and_then(Result::ok));
5470+
let init_specialization =
5471+
specialize_constructor(init_call_outcome.and_then(Result::ok));
54815472
let specialization =
54825473
combine_specializations(db, new_specialization, init_specialization);
54835474
let specialized = specialization
@@ -6768,21 +6759,19 @@ impl<'db> TypeMapping<'_, 'db> {
67686759
db,
67696760
context
67706761
.variables(db)
6771-
.iter()
6772-
.filter(|var| !var.typevar(db).is_self(db))
6773-
.copied(),
6762+
.filter(|var| !var.typevar(db).is_self(db)),
67746763
),
67756764
TypeMapping::ReplaceSelf { new_upper_bound } => GenericContext::from_typevar_instances(
67766765
db,
6777-
context.variables(db).iter().map(|typevar| {
6766+
context.variables(db).map(|typevar| {
67786767
if typevar.typevar(db).is_self(db) {
67796768
BoundTypeVarInstance::synthetic_self(
67806769
db,
67816770
*new_upper_bound,
67826771
typevar.binding_context(db),
67836772
)
67846773
} else {
6785-
*typevar
6774+
typevar
67866775
}
67876776
}),
67886777
),

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 9 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use crate::types::tuple::{TupleLength, TupleType};
3232
use crate::types::{
3333
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
3434
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
35-
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionBuilder, UnionType,
35+
TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType,
3636
WrapperDescriptorKind, enums, ide_support, todo_type,
3737
};
3838
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
@@ -1701,10 +1701,6 @@ impl<'db> CallableBinding<'db> {
17011701
parameter_type =
17021702
parameter_type.apply_specialization(db, specialization);
17031703
}
1704-
if let Some(inherited_specialization) = overload.inherited_specialization {
1705-
parameter_type =
1706-
parameter_type.apply_specialization(db, inherited_specialization);
1707-
}
17081704
union_parameter_types[parameter_index.saturating_sub(skipped_parameters)]
17091705
.add_in_place(parameter_type);
17101706
}
@@ -1983,7 +1979,7 @@ impl<'db> CallableBinding<'db> {
19831979
for overload in overloads.iter().take(MAXIMUM_OVERLOADS) {
19841980
diag.info(format_args!(
19851981
" {}",
1986-
overload.signature(context.db(), None).display(context.db())
1982+
overload.signature(context.db()).display(context.db())
19871983
));
19881984
}
19891985
if overloads.len() > MAXIMUM_OVERLOADS {
@@ -2444,7 +2440,6 @@ struct ArgumentTypeChecker<'a, 'db> {
24442440
errors: &'a mut Vec<BindingError<'db>>,
24452441

24462442
specialization: Option<Specialization<'db>>,
2447-
inherited_specialization: Option<Specialization<'db>>,
24482443
}
24492444

24502445
impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
@@ -2466,7 +2461,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
24662461
call_expression_tcx,
24672462
errors,
24682463
specialization: None,
2469-
inherited_specialization: None,
24702464
}
24712465
}
24722466

@@ -2498,9 +2492,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
24982492
}
24992493

25002494
fn infer_specialization(&mut self) {
2501-
if self.signature.generic_context.is_none()
2502-
&& self.signature.inherited_generic_context.is_none()
2503-
{
2495+
if self.signature.generic_context.is_none() {
25042496
return;
25052497
}
25062498

@@ -2542,14 +2534,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25422534
}
25432535

25442536
self.specialization = self.signature.generic_context.map(|gc| builder.build(gc));
2545-
self.inherited_specialization = self.signature.inherited_generic_context.map(|gc| {
2546-
// The inherited generic context is used when inferring the specialization of a generic
2547-
// class from a constructor call. In this case (only), we promote any typevars that are
2548-
// inferred as a literal to the corresponding instance type.
2549-
builder
2550-
.build(gc)
2551-
.apply_type_mapping(self.db, &TypeMapping::PromoteLiterals)
2552-
});
25532537
}
25542538

25552539
fn check_argument_type(
@@ -2566,11 +2550,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25662550
argument_type = argument_type.apply_specialization(self.db, specialization);
25672551
expected_ty = expected_ty.apply_specialization(self.db, specialization);
25682552
}
2569-
if let Some(inherited_specialization) = self.inherited_specialization {
2570-
argument_type =
2571-
argument_type.apply_specialization(self.db, inherited_specialization);
2572-
expected_ty = expected_ty.apply_specialization(self.db, inherited_specialization);
2573-
}
25742553
// This is one of the few places where we want to check if there's _any_ specialization
25752554
// where assignability holds; normally we want to check that assignability holds for
25762555
// _all_ specializations.
@@ -2742,8 +2721,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27422721
}
27432722
}
27442723

2745-
fn finish(self) -> (Option<Specialization<'db>>, Option<Specialization<'db>>) {
2746-
(self.specialization, self.inherited_specialization)
2724+
fn finish(self) -> Option<Specialization<'db>> {
2725+
self.specialization
27472726
}
27482727
}
27492728

@@ -2807,10 +2786,6 @@ pub(crate) struct Binding<'db> {
28072786
/// The specialization that was inferred from the argument types, if the callable is generic.
28082787
specialization: Option<Specialization<'db>>,
28092788

2810-
/// The specialization that was inferred for a class method's containing generic class, if it
2811-
/// is being used to infer a specialization for the class.
2812-
inherited_specialization: Option<Specialization<'db>>,
2813-
28142789
/// Information about which parameter(s) each argument was matched with, in argument source
28152790
/// order.
28162791
argument_matches: Box<[MatchedArgument<'db>]>,
@@ -2835,7 +2810,6 @@ impl<'db> Binding<'db> {
28352810
signature_type,
28362811
return_ty: Type::unknown(),
28372812
specialization: None,
2838-
inherited_specialization: None,
28392813
argument_matches: Box::from([]),
28402814
variadic_argument_matched_to_variadic_parameter: false,
28412815
parameter_tys: Box::from([]),
@@ -2906,15 +2880,10 @@ impl<'db> Binding<'db> {
29062880
checker.infer_specialization();
29072881

29082882
checker.check_argument_types();
2909-
(self.specialization, self.inherited_specialization) = checker.finish();
2883+
self.specialization = checker.finish();
29102884
if let Some(specialization) = self.specialization {
29112885
self.return_ty = self.return_ty.apply_specialization(db, specialization);
29122886
}
2913-
if let Some(inherited_specialization) = self.inherited_specialization {
2914-
self.return_ty = self
2915-
.return_ty
2916-
.apply_specialization(db, inherited_specialization);
2917-
}
29182887
}
29192888

29202889
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {
@@ -2925,8 +2894,8 @@ impl<'db> Binding<'db> {
29252894
self.return_ty
29262895
}
29272896

2928-
pub(crate) fn inherited_specialization(&self) -> Option<Specialization<'db>> {
2929-
self.inherited_specialization
2897+
pub(crate) fn specialization(&self) -> Option<Specialization<'db>> {
2898+
self.specialization
29302899
}
29312900

29322901
/// Returns the bound types for each parameter, in parameter source order, or `None` if no
@@ -2988,7 +2957,6 @@ impl<'db> Binding<'db> {
29882957
BindingSnapshot {
29892958
return_ty: self.return_ty,
29902959
specialization: self.specialization,
2991-
inherited_specialization: self.inherited_specialization,
29922960
argument_matches: self.argument_matches.clone(),
29932961
parameter_tys: self.parameter_tys.clone(),
29942962
errors: self.errors.clone(),
@@ -2999,15 +2967,13 @@ impl<'db> Binding<'db> {
29992967
let BindingSnapshot {
30002968
return_ty,
30012969
specialization,
3002-
inherited_specialization,
30032970
argument_matches,
30042971
parameter_tys,
30052972
errors,
30062973
} = snapshot;
30072974

30082975
self.return_ty = return_ty;
30092976
self.specialization = specialization;
3010-
self.inherited_specialization = inherited_specialization;
30112977
self.argument_matches = argument_matches;
30122978
self.parameter_tys = parameter_tys;
30132979
self.errors = errors;
@@ -3027,7 +2993,6 @@ impl<'db> Binding<'db> {
30272993
fn reset(&mut self) {
30282994
self.return_ty = Type::unknown();
30292995
self.specialization = None;
3030-
self.inherited_specialization = None;
30312996
self.argument_matches = Box::from([]);
30322997
self.parameter_tys = Box::from([]);
30332998
self.errors.clear();
@@ -3038,7 +3003,6 @@ impl<'db> Binding<'db> {
30383003
struct BindingSnapshot<'db> {
30393004
return_ty: Type<'db>,
30403005
specialization: Option<Specialization<'db>>,
3041-
inherited_specialization: Option<Specialization<'db>>,
30423006
argument_matches: Box<[MatchedArgument<'db>]>,
30433007
parameter_tys: Box<[Option<Type<'db>>]>,
30443008
errors: Vec<BindingError<'db>>,
@@ -3078,7 +3042,6 @@ impl<'db> CallableBindingSnapshot<'db> {
30783042
// ... and update the snapshot with the current state of the binding.
30793043
snapshot.return_ty = binding.return_ty;
30803044
snapshot.specialization = binding.specialization;
3081-
snapshot.inherited_specialization = binding.inherited_specialization;
30823045
snapshot
30833046
.argument_matches
30843047
.clone_from(&binding.argument_matches);
@@ -3373,7 +3336,7 @@ impl<'db> BindingError<'db> {
33733336
}
33743337
diag.info(format_args!(
33753338
" {}",
3376-
overload.signature(context.db(), None).display(context.db())
3339+
overload.signature(context.db()).display(context.db())
33773340
));
33783341
}
33793342
if overloads.len() > MAXIMUM_OVERLOADS {

crates/ty_python_semantic/src/types/class.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,6 @@ impl<'db> VarianceInferable<'db> for GenericAlias<'db> {
324324
specialization
325325
.generic_context(db)
326326
.variables(db)
327-
.iter()
328327
.zip(specialization.types(db))
329328
.map(|(generic_typevar, ty)| {
330329
if let Some(explicit_variance) =
@@ -346,7 +345,7 @@ impl<'db> VarianceInferable<'db> for GenericAlias<'db> {
346345
let typevar_variance_in_substituted_type = ty.variance_of(db, typevar);
347346
origin
348347
.with_polarity(typevar_variance_in_substituted_type)
349-
.variance_of(db, *generic_typevar)
348+
.variance_of(db, generic_typevar)
350349
}
351350
}),
352351
)
@@ -1013,8 +1012,7 @@ impl<'db> ClassType<'db> {
10131012

10141013
let synthesized_dunder = CallableType::function_like(
10151014
db,
1016-
Signature::new(parameters, None)
1017-
.with_inherited_generic_context(inherited_generic_context),
1015+
Signature::new_generic(inherited_generic_context, parameters, None),
10181016
);
10191017

10201018
Place::bound(synthesized_dunder).into()
@@ -1454,6 +1452,16 @@ impl<'db> ClassLiteral<'db> {
14541452
)
14551453
}
14561454

1455+
/// Returns the generic context that should be inherited by any constructor methods of this
1456+
/// class.
1457+
///
1458+
/// When inferring a specialization of the class's generic context from a constructor call, we
1459+
/// promote any typevars that are inferred as a literal to the corresponding instance type.
1460+
fn inherited_generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
1461+
self.generic_context(db)
1462+
.map(|generic_context| generic_context.promote_literals(db))
1463+
}
1464+
14571465
fn file(self, db: &dyn Db) -> File {
14581466
self.body_scope(db).file(db)
14591467
}
@@ -1996,7 +2004,7 @@ impl<'db> ClassLiteral<'db> {
19962004
lookup_result = lookup_result.or_else(|lookup_error| {
19972005
lookup_error.or_fall_back_to(
19982006
db,
1999-
class.own_class_member(db, self.generic_context(db), name),
2007+
class.own_class_member(db, self.inherited_generic_context(db), name),
20002008
)
20012009
});
20022010
}
@@ -2246,8 +2254,14 @@ impl<'db> ClassLiteral<'db> {
22462254
// so that the keyword-only parameters appear after positional parameters.
22472255
parameters.sort_by_key(Parameter::is_keyword_only);
22482256

2249-
let mut signature = Signature::new(Parameters::new(parameters), return_ty);
2250-
signature.inherited_generic_context = self.generic_context(db);
2257+
let signature = match name {
2258+
"__new__" | "__init__" => Signature::new_generic(
2259+
self.inherited_generic_context(db),
2260+
Parameters::new(parameters),
2261+
return_ty,
2262+
),
2263+
_ => Signature::new(Parameters::new(parameters), return_ty),
2264+
};
22512265
Some(CallableType::function_like(db, signature))
22522266
};
22532267

@@ -2295,7 +2309,7 @@ impl<'db> ClassLiteral<'db> {
22952309
KnownClass::NamedTupleFallback
22962310
.to_class_literal(db)
22972311
.into_class_literal()?
2298-
.own_class_member(db, self.generic_context(db), None, name)
2312+
.own_class_member(db, self.inherited_generic_context(db), None, name)
22992313
.place
23002314
.ignore_possibly_unbound()
23012315
.map(|ty| {
@@ -5421,7 +5435,7 @@ enum SlotsKind {
54215435
impl SlotsKind {
54225436
fn from(db: &dyn Db, base: ClassLiteral) -> Self {
54235437
let Place::Type(slots_ty, bound) = base
5424-
.own_class_member(db, base.generic_context(db), None, "__slots__")
5438+
.own_class_member(db, base.inherited_generic_context(db), None, "__slots__")
54255439
.place
54265440
else {
54275441
return Self::NotSpecified;

0 commit comments

Comments
 (0)