Skip to content

Commit 9db63fc

Browse files
authored
[red-knot] Handle generic constructors of generic classes (#17552)
We now handle generic constructor methods on generic classes correctly: ```py class C[T]: def __init__[S](self, t: T, s: S): ... x = C(1, "str") ``` Here, constructing `C` requires us to infer a specialization for the generic contexts of `C` and `__init__` at the same time. At first I thought I would need to track the full stack of nested generic contexts here (since the `[S]` context is nested within the `[T]` context). But I think this is the only way that we might need to specialize more than one generic context at once — in all other cases, a containing generic context must be specialized before we get to a nested one, and so we can just special-case this. While we're here, we also construct the generic context for a generic function lazily, when its signature is accessed, instead of eagerly when inferring the function body.
1 parent 61e7348 commit 9db63fc

File tree

7 files changed

+102
-60
lines changed

7 files changed

+102
-60
lines changed

crates/red_knot_python_semantic/resources/mdtest/generics/classes.md

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -232,21 +232,11 @@ TODO: These do not currently work yet, because we don't correctly model the nest
232232
class C[T]:
233233
def __init__[S](self, x: T, y: S) -> None: ...
234234

235-
# TODO: no error
236-
# TODO: revealed: C[Literal[1]]
237-
# error: [invalid-argument-type]
238-
reveal_type(C(1, 1)) # revealed: C[Unknown]
239-
# TODO: no error
240-
# TODO: revealed: C[Literal[1]]
241-
# error: [invalid-argument-type]
242-
reveal_type(C(1, "string")) # revealed: C[Unknown]
243-
# TODO: no error
244-
# TODO: revealed: C[Literal[1]]
245-
# error: [invalid-argument-type]
246-
reveal_type(C(1, True)) # revealed: C[Unknown]
247-
248-
# TODO: [invalid-assignment] "Object of type `C[Literal["five"]]` is not assignable to `C[int]`"
249-
# error: [invalid-argument-type] "Argument to this function is incorrect: Expected `S`, found `Literal[1]`"
235+
reveal_type(C(1, 1)) # revealed: C[Literal[1]]
236+
reveal_type(C(1, "string")) # revealed: C[Literal[1]]
237+
reveal_type(C(1, True)) # revealed: C[Literal[1]]
238+
239+
# error: [invalid-assignment] "Object of type `C[Literal["five"]]` is not assignable to `C[int]`"
250240
wrong_innards: C[int] = C("five", 1)
251241
```
252242

crates/red_knot_python_semantic/src/types.rs

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4268,13 +4268,13 @@ impl<'db> Type<'db> {
42684268
.as_ref()
42694269
.and_then(Bindings::single_element)
42704270
.and_then(CallableBinding::matching_overload)
4271-
.and_then(|(_, binding)| binding.specialization());
4271+
.and_then(|(_, binding)| binding.inherited_specialization());
42724272
let init_specialization = init_call_outcome
42734273
.and_then(Result::ok)
42744274
.as_ref()
42754275
.and_then(Bindings::single_element)
42764276
.and_then(CallableBinding::matching_overload)
4277-
.and_then(|(_, binding)| binding.specialization());
4277+
.and_then(|(_, binding)| binding.inherited_specialization());
42784278
let specialization = match (new_specialization, init_specialization) {
42794279
(None, None) => None,
42804280
(Some(specialization), None) | (None, Some(specialization)) => {
@@ -5940,8 +5940,10 @@ pub struct FunctionType<'db> {
59405940
/// with `@dataclass_transformer(...)`.
59415941
dataclass_transformer_params: Option<DataclassTransformerParams>,
59425942

5943-
/// The generic context of a generic function.
5944-
generic_context: Option<GenericContext<'db>>,
5943+
/// The inherited generic context, if this function is a class method being used to infer the
5944+
/// specialization of its generic class. If the method is itself generic, this is in addition
5945+
/// to its own generic context.
5946+
inherited_generic_context: Option<GenericContext<'db>>,
59455947

59465948
/// A specialization that should be applied to the function's parameter and return types,
59475949
/// either because the function is itself generic, or because it appears in the body of a
@@ -6007,11 +6009,7 @@ impl<'db> FunctionType<'db> {
60076009
/// would depend on the function's AST and rerun for every change in that file.
60086010
#[salsa::tracked(return_ref)]
60096011
pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
6010-
let mut internal_signature = self.internal_signature(db);
6011-
6012-
if let Some(specialization) = self.specialization(db) {
6013-
internal_signature = internal_signature.apply_specialization(db, specialization);
6014-
}
6012+
let internal_signature = self.internal_signature(db);
60156013

60166014
// The semantic model records a use for each function on the name node. This is used here
60176015
// to get the previous function definition with the same name.
@@ -6071,22 +6069,59 @@ impl<'db> FunctionType<'db> {
60716069
let scope = self.body_scope(db);
60726070
let function_stmt_node = scope.node(db).expect_function();
60736071
let definition = self.definition(db);
6074-
Signature::from_function(db, self.generic_context(db), definition, function_stmt_node)
6072+
let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| {
6073+
let index = semantic_index(db, scope.file(db));
6074+
GenericContext::from_type_params(db, index, type_params)
6075+
});
6076+
let mut signature = Signature::from_function(
6077+
db,
6078+
generic_context,
6079+
self.inherited_generic_context(db),
6080+
definition,
6081+
function_stmt_node,
6082+
);
6083+
if let Some(specialization) = self.specialization(db) {
6084+
signature = signature.apply_specialization(db, specialization);
6085+
}
6086+
signature
60756087
}
60766088

60776089
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
60786090
self.known(db) == Some(known_function)
60796091
}
60806092

6081-
fn with_generic_context(self, db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self {
6093+
fn with_dataclass_transformer_params(
6094+
self,
6095+
db: &'db dyn Db,
6096+
params: DataclassTransformerParams,
6097+
) -> Self {
6098+
Self::new(
6099+
db,
6100+
self.name(db).clone(),
6101+
self.known(db),
6102+
self.body_scope(db),
6103+
self.decorators(db),
6104+
Some(params),
6105+
self.inherited_generic_context(db),
6106+
self.specialization(db),
6107+
)
6108+
}
6109+
6110+
fn with_inherited_generic_context(
6111+
self,
6112+
db: &'db dyn Db,
6113+
inherited_generic_context: GenericContext<'db>,
6114+
) -> Self {
6115+
// A function cannot inherit more than one generic context from its containing class.
6116+
debug_assert!(self.inherited_generic_context(db).is_none());
60826117
Self::new(
60836118
db,
60846119
self.name(db).clone(),
60856120
self.known(db),
60866121
self.body_scope(db),
60876122
self.decorators(db),
60886123
self.dataclass_transformer_params(db),
6089-
Some(generic_context),
6124+
Some(inherited_generic_context),
60906125
self.specialization(db),
60916126
)
60926127
}
@@ -6103,7 +6138,7 @@ impl<'db> FunctionType<'db> {
61036138
self.body_scope(db),
61046139
self.decorators(db),
61056140
self.dataclass_transformer_params(db),
6106-
self.generic_context(db),
6141+
self.inherited_generic_context(db),
61076142
Some(specialization),
61086143
)
61096144
}

crates/red_knot_python_semantic/src/types/call/bind.rs

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ use crate::types::diagnostic::{
1919
use crate::types::generics::{Specialization, SpecializationBuilder};
2020
use crate::types::signatures::{Parameter, ParameterForm};
2121
use crate::types::{
22-
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, FunctionType,
23-
KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType,
24-
UnionType, WrapperDescriptorKind,
22+
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, KnownClass,
23+
KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType,
24+
WrapperDescriptorKind,
2525
};
2626
use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic};
2727
use ruff_python_ast as ast;
@@ -424,16 +424,9 @@ impl<'db> Bindings<'db> {
424424

425425
Type::DataclassTransformer(params) => {
426426
if let [Some(Type::FunctionLiteral(function))] = overload.parameter_types() {
427-
overload.set_return_type(Type::FunctionLiteral(FunctionType::new(
428-
db,
429-
function.name(db),
430-
function.known(db),
431-
function.body_scope(db),
432-
function.decorators(db),
433-
Some(params),
434-
function.generic_context(db),
435-
function.specialization(db),
436-
)));
427+
overload.set_return_type(Type::FunctionLiteral(
428+
function.with_dataclass_transformer_params(db, params),
429+
));
437430
}
438431
}
439432

@@ -961,6 +954,10 @@ pub(crate) struct Binding<'db> {
961954
/// The specialization that was inferred from the argument types, if the callable is generic.
962955
specialization: Option<Specialization<'db>>,
963956

957+
/// The specialization that was inferred for a class method's containing generic class, if it
958+
/// is being used to infer a specialization for the class.
959+
inherited_specialization: Option<Specialization<'db>>,
960+
964961
/// The formal parameter that each argument is matched with, in argument source order, or
965962
/// `None` if the argument was not matched to any parameter.
966963
argument_parameters: Box<[Option<usize>]>,
@@ -1097,6 +1094,7 @@ impl<'db> Binding<'db> {
10971094
Self {
10981095
return_ty: signature.return_ty.unwrap_or(Type::unknown()),
10991096
specialization: None,
1097+
inherited_specialization: None,
11001098
argument_parameters: argument_parameters.into_boxed_slice(),
11011099
parameter_tys: vec![None; parameters.len()].into_boxed_slice(),
11021100
errors,
@@ -1112,8 +1110,8 @@ impl<'db> Binding<'db> {
11121110
// If this overload is generic, first see if we can infer a specialization of the function
11131111
// from the arguments that were passed in.
11141112
let parameters = signature.parameters();
1115-
self.specialization = signature.generic_context.map(|generic_context| {
1116-
let mut builder = SpecializationBuilder::new(db, generic_context);
1113+
if signature.generic_context.is_some() || signature.inherited_generic_context.is_some() {
1114+
let mut builder = SpecializationBuilder::new(db);
11171115
for (argument_index, (_, argument_type)) in argument_types.iter().enumerate() {
11181116
let Some(parameter_index) = self.argument_parameters[argument_index] else {
11191117
// There was an error with argument when matching parameters, so don't bother
@@ -1126,8 +1124,11 @@ impl<'db> Binding<'db> {
11261124
};
11271125
builder.infer(expected_type, argument_type);
11281126
}
1129-
builder.build()
1130-
});
1127+
self.specialization = signature.generic_context.map(|gc| builder.build(gc));
1128+
self.inherited_specialization = signature
1129+
.inherited_generic_context
1130+
.map(|gc| builder.build(gc));
1131+
}
11311132

11321133
let mut num_synthetic_args = 0;
11331134
let get_argument_index = |argument_index: usize, num_synthetic_args: usize| {
@@ -1155,6 +1156,9 @@ impl<'db> Binding<'db> {
11551156
if let Some(specialization) = self.specialization {
11561157
expected_ty = expected_ty.apply_specialization(db, specialization);
11571158
}
1159+
if let Some(inherited_specialization) = self.inherited_specialization {
1160+
expected_ty = expected_ty.apply_specialization(db, inherited_specialization);
1161+
}
11581162
if !argument_type.is_assignable_to(db, expected_ty) {
11591163
let positional = matches!(argument, Argument::Positional | Argument::Synthetic)
11601164
&& !parameter.is_variadic();
@@ -1180,6 +1184,11 @@ impl<'db> Binding<'db> {
11801184
if let Some(specialization) = self.specialization {
11811185
self.return_ty = self.return_ty.apply_specialization(db, specialization);
11821186
}
1187+
if let Some(inherited_specialization) = self.inherited_specialization {
1188+
self.return_ty = self
1189+
.return_ty
1190+
.apply_specialization(db, inherited_specialization);
1191+
}
11831192
}
11841193

11851194
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {
@@ -1190,8 +1199,8 @@ impl<'db> Binding<'db> {
11901199
self.return_ty
11911200
}
11921201

1193-
pub(crate) fn specialization(&self) -> Option<Specialization<'db>> {
1194-
self.specialization
1202+
pub(crate) fn inherited_specialization(&self) -> Option<Specialization<'db>> {
1203+
self.inherited_specialization
11951204
}
11961205

11971206
pub(crate) fn parameter_types(&self) -> &[Option<Type<'db>>] {

crates/red_knot_python_semantic/src/types/class.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,7 @@ impl<'db> ClassLiteralType<'db> {
10171017
Some(_),
10181018
"__new__" | "__init__",
10191019
) => Type::FunctionLiteral(
1020-
function.with_generic_context(db, origin.generic_context(db)),
1020+
function.with_inherited_generic_context(db, origin.generic_context(db)),
10211021
),
10221022
_ => ty,
10231023
}

crates/red_knot_python_semantic/src/types/generics.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -299,22 +299,19 @@ impl<'db> Specialization<'db> {
299299
/// specialization of a generic function.
300300
pub(crate) struct SpecializationBuilder<'db> {
301301
db: &'db dyn Db,
302-
generic_context: GenericContext<'db>,
303302
types: FxHashMap<TypeVarInstance<'db>, UnionBuilder<'db>>,
304303
}
305304

306305
impl<'db> SpecializationBuilder<'db> {
307-
pub(crate) fn new(db: &'db dyn Db, generic_context: GenericContext<'db>) -> Self {
306+
pub(crate) fn new(db: &'db dyn Db) -> Self {
308307
Self {
309308
db,
310-
generic_context,
311309
types: FxHashMap::default(),
312310
}
313311
}
314312

315-
pub(crate) fn build(mut self) -> Specialization<'db> {
316-
let types: Box<[_]> = self
317-
.generic_context
313+
pub(crate) fn build(&mut self, generic_context: GenericContext<'db>) -> Specialization<'db> {
314+
let types: Box<[_]> = generic_context
318315
.variables(self.db)
319316
.iter()
320317
.map(|variable| {
@@ -324,7 +321,7 @@ impl<'db> SpecializationBuilder<'db> {
324321
.unwrap_or(variable.default_ty(self.db).unwrap_or(Type::unknown()))
325322
})
326323
.collect();
327-
Specialization::new(self.db, self.generic_context, types)
324+
Specialization::new(self.db, generic_context, types)
328325
}
329326

330327
fn add_type_mapping(&mut self, typevar: TypeVarInstance<'db>, ty: Type<'db>) {

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,10 +1525,6 @@ impl<'db> TypeInferenceBuilder<'db> {
15251525
}
15261526
}
15271527

1528-
let generic_context = type_params.as_ref().map(|type_params| {
1529-
GenericContext::from_type_params(self.db(), self.index, type_params)
1530-
});
1531-
15321528
let function_kind =
15331529
KnownFunction::try_from_definition_and_name(self.db(), definition, name);
15341530

@@ -1537,6 +1533,7 @@ impl<'db> TypeInferenceBuilder<'db> {
15371533
.node_scope(NodeWithScopeRef::Function(function))
15381534
.to_scope_id(self.db(), self.file());
15391535

1536+
let inherited_generic_context = None;
15401537
let specialization = None;
15411538

15421539
let mut inferred_ty = Type::FunctionLiteral(FunctionType::new(
@@ -1546,7 +1543,7 @@ impl<'db> TypeInferenceBuilder<'db> {
15461543
body_scope,
15471544
function_decorators,
15481545
dataclass_transformer_params,
1549-
generic_context,
1546+
inherited_generic_context,
15501547
specialization,
15511548
));
15521549

0 commit comments

Comments
 (0)