@@ -2523,6 +2523,7 @@ struct ArgumentTypeChecker<'a, 'db> {
25232523 arguments : & ' a CallArguments < ' a , ' db > ,
25242524 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
25252525 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
2526+ callable_type : Type < ' db > ,
25262527 call_expression_tcx : & ' a TypeContext < ' db > ,
25272528 return_ty : Type < ' db > ,
25282529 errors : & ' a mut Vec < BindingError < ' db > > ,
@@ -2539,6 +2540,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25392540 arguments : & ' a CallArguments < ' a , ' db > ,
25402541 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
25412542 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
2543+ callable_type : Type < ' db > ,
25422544 call_expression_tcx : & ' a TypeContext < ' db > ,
25432545 return_ty : Type < ' db > ,
25442546 errors : & ' a mut Vec < BindingError < ' db > > ,
@@ -2549,6 +2551,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25492551 arguments,
25502552 argument_matches,
25512553 parameter_tys,
2554+ callable_type,
25522555 call_expression_tcx,
25532556 return_ty,
25542557 errors,
@@ -2623,7 +2626,22 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26232626 . apply_specialization ( self . db , isolated_specialization) ;
26242627
26252628 let mut try_infer_tcx = || {
2626- let return_ty = self . signature . return_ty ?;
2629+ // For generic constructors, we use the type context to infer the specialization of the class
2630+ // instance instead of the method's return type.
2631+ let ( inference_context, generic_constructor) = if let Type :: BoundMethod ( method) =
2632+ self . callable_type
2633+ && let Type :: NominalInstance ( instance) = method. self_instance ( self . db )
2634+ && method. function ( self . db ) . name ( self . db ) == "__init__"
2635+ {
2636+ let class_ty = instance
2637+ . class_literal ( self . db )
2638+ . identity_specialization ( self . db ) ;
2639+
2640+ ( Type :: instance ( self . db , class_ty) , true )
2641+ } else {
2642+ ( self . signature . return_ty ?, false )
2643+ } ;
2644+
26272645 let call_expression_tcx = self . call_expression_tcx . annotation ?;
26282646
26292647 // A type variable is not a useful type-context for expression inference, and applying it
@@ -2634,17 +2652,19 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26342652
26352653 // If the return type is already assignable to the annotated type, we can ignore the
26362654 // type context and prefer the narrower inferred type.
2637- if isolated_return_ty. is_assignable_to ( self . db , call_expression_tcx) {
2655+ if !generic_constructor
2656+ && isolated_return_ty. is_assignable_to ( self . db , call_expression_tcx)
2657+ {
26382658 return None ;
26392659 }
26402660
26412661 // TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
26422662 // annotated assignment, to closer match the order of any unions written in the type annotation.
2643- builder. infer ( return_ty , call_expression_tcx) . ok ( ) ?;
2663+ builder. infer ( inference_context , call_expression_tcx) . ok ( ) ?;
26442664
26452665 // Otherwise, build the specialization again after inferring the type context.
26462666 let specialization = builder. build ( generic_context, * self . call_expression_tcx ) ;
2647- let return_ty = return_ty. apply_specialization ( self . db , specialization) ;
2667+ let return_ty = self . return_ty . apply_specialization ( self . db , specialization) ;
26482668
26492669 Some ( ( Some ( specialization) , return_ty) )
26502670 } ;
@@ -3009,6 +3029,7 @@ impl<'db> Binding<'db> {
30093029 arguments,
30103030 & self . argument_matches ,
30113031 & mut self . parameter_tys ,
3032+ self . callable_type ,
30123033 call_expression_tcx,
30133034 self . return_ty ,
30143035 & mut self . errors ,
0 commit comments