@@ -5535,6 +5535,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
55355535 ast_arguments : & ast:: Arguments ,
55365536 arguments : & mut CallArguments < ' a , ' db > ,
55375537 bindings : & Bindings < ' db > ,
5538+ call_expression_tcx : TypeContext < ' db > ,
55385539 ) {
55395540 debug_assert ! (
55405541 ast_arguments. len( ) == arguments. len( )
@@ -5603,10 +5604,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
56035604 return None ;
56045605 } ;
56055606
5606- let parameter_type =
5607+ let mut parameter_type =
56075608 overload. signature . parameters ( ) [ * parameter_index] . annotated_type ( ) ?;
56085609
5609- // TODO: For now, skip any parameter annotations that mention any typevars. There
5610+ // If this is a generic call, attempt to specialize the parameter type using the
5611+ // declared type context, if provided.
5612+ if let Some ( generic_context) = overload. signature . generic_context
5613+ && let Some ( return_ty) = overload. signature . return_ty
5614+ && let Some ( declared_return_ty) = call_expression_tcx. annotation
5615+ {
5616+ let mut builder =
5617+ SpecializationBuilder :: new ( db, generic_context. inferable_typevars ( db) ) ;
5618+
5619+ let _ = builder. infer ( return_ty, declared_return_ty) ;
5620+ let specialization = builder. build ( generic_context, call_expression_tcx) ;
5621+
5622+ // Note that we are not necessarily "preferring the declared type" here, as the
5623+ // type context will only be preferred during the inference of this expression
5624+ // by the same heuristics we use for the inference of the outer generic call.
5625+ parameter_type = parameter_type. apply_specialization ( db, specialization) ;
5626+ }
5627+
5628+ // TODO: For now, skip any parameter annotations that still mention any typevars. There
56105629 // are two issues:
56115630 //
56125631 // First, if we include those typevars in the type context that we use to infer the
@@ -6820,7 +6839,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
68206839 let infer_call_arguments = |bindings : Option < Bindings < ' db > > | {
68216840 if let Some ( bindings) = bindings {
68226841 let bindings = bindings. match_parameters ( self . db ( ) , & call_arguments) ;
6823- self . infer_all_argument_types ( arguments, & mut call_arguments, & bindings) ;
6842+ self . infer_all_argument_types (
6843+ arguments,
6844+ & mut call_arguments,
6845+ & bindings,
6846+ tcx,
6847+ ) ;
68246848 } else {
68256849 let argument_forms = vec ! [ Some ( ParameterForm :: Value ) ; call_arguments. len( ) ] ;
68266850 self . infer_argument_types ( arguments, & mut call_arguments, & argument_forms) ;
@@ -6841,7 +6865,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
68416865 let bindings = callable_type
68426866 . bindings ( self . db ( ) )
68436867 . match_parameters ( self . db ( ) , & call_arguments) ;
6844- self . infer_all_argument_types ( arguments, & mut call_arguments, & bindings) ;
6868+ self . infer_all_argument_types ( arguments, & mut call_arguments, & bindings, tcx ) ;
68456869
68466870 // Validate `TypedDict` constructor calls after argument type inference
68476871 if let Some ( class_literal) = callable_type. as_class_literal ( ) {
0 commit comments