From c3f9c4f4d4bbc83c7de79a09c7ec0e7fda8efc5e Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Tue, 20 Aug 2024 11:35:05 -0400
Subject: [PATCH 1/2] Use equality when relating formal and expected type in
 arg checking

---
 .../rustc_hir_typeck/src/fn_ctxt/checks.rs    |  9 ++++-----
 .../coercion/constrain-expectation-in-arg.rs  | 19 +++++++++++++++++++
 2 files changed, 23 insertions(+), 5 deletions(-)
 create mode 100644 tests/ui/coercion/constrain-expectation-in-arg.rs

diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
index eebb0217990df..16d65726128c3 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
@@ -292,21 +292,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
 
             let coerce_error =
                 self.coerce(provided_arg, checked_ty, coerced_ty, AllowTwoPhase::Yes, None).err();
-
             if coerce_error.is_some() {
                 return Compatibility::Incompatible(coerce_error);
             }
 
-            // 3. Check if the formal type is a supertype of the checked one
-            //    and register any such obligations for future type checks
-            let supertype_error = self.at(&self.misc(provided_arg.span), self.param_env).sup(
+            // 3. Check if the formal type is actually equal to the checked one
+            //    and register any such obligations for future type checks.
+            let formal_ty_error = self.at(&self.misc(provided_arg.span), self.param_env).eq(
                 DefineOpaqueTypes::Yes,
                 formal_input_ty,
                 coerced_ty,
             );
 
             // If neither check failed, the types are compatible
-            match supertype_error {
+            match formal_ty_error {
                 Ok(InferOk { obligations, value: () }) => {
                     self.register_predicates(obligations);
                     Compatibility::Compatible
diff --git a/tests/ui/coercion/constrain-expectation-in-arg.rs b/tests/ui/coercion/constrain-expectation-in-arg.rs
new file mode 100644
index 0000000000000..858c3a0bdb572
--- /dev/null
+++ b/tests/ui/coercion/constrain-expectation-in-arg.rs
@@ -0,0 +1,19 @@
+//@ check-pass
+
+trait Trait {
+    type Item;
+}
+
+struct Struct<A: Trait<Item = B>, B> {
+    pub field: A,
+}
+
+fn identity<T>(x: T) -> T {
+    x
+}
+
+fn test<A: Trait<Item = B>, B>(x: &Struct<A, B>) {
+    let x: &Struct<_, _> = identity(x);
+}
+
+fn main() {}

From 95b9ecd6d671637e9e3db55ed31d06882d3cad4d Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Sun, 25 Aug 2024 12:45:58 -0400
Subject: [PATCH 2/2] Inline expected_inputs_for_expected_output into
 check_argument_types/check_expr_struct_fields

---
 compiler/rustc_hir_typeck/src/callee.rs       | 21 ++-----
 compiler/rustc_hir_typeck/src/expr.rs         | 25 +++++---
 .../rustc_hir_typeck/src/fn_ctxt/_impl.rs     | 39 +------------
 .../rustc_hir_typeck/src/fn_ctxt/checks.rs    | 58 ++++++++++++++-----
 .../coercion/constrain-expectation-in-arg.rs  |  5 ++
 5 files changed, 71 insertions(+), 77 deletions(-)

diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs
index a4eec5f05a8ff..9863d0364498e 100644
--- a/compiler/rustc_hir_typeck/src/callee.rs
+++ b/compiler/rustc_hir_typeck/src/callee.rs
@@ -503,18 +503,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         let fn_sig = self.instantiate_binder_with_fresh_vars(call_expr.span, infer::FnCall, fn_sig);
         let fn_sig = self.normalize(call_expr.span, fn_sig);
 
-        // Call the generic checker.
-        let expected_arg_tys = self.expected_inputs_for_expected_output(
-            call_expr.span,
-            expected,
-            fn_sig.output(),
-            fn_sig.inputs(),
-        );
         self.check_argument_types(
             call_expr.span,
             call_expr,
             fn_sig.inputs(),
-            expected_arg_tys,
+            fn_sig.output(),
+            expected,
             arg_exprs,
             fn_sig.c_variadic,
             TupleArgumentsFlag::DontTupleArguments,
@@ -866,19 +860,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         // don't know the full details yet (`Fn` vs `FnMut` etc), but we
         // do know the types expected for each argument and the return
         // type.
-
-        let expected_arg_tys = self.expected_inputs_for_expected_output(
-            call_expr.span,
-            expected,
-            fn_sig.output(),
-            fn_sig.inputs(),
-        );
-
         self.check_argument_types(
             call_expr.span,
             call_expr,
             fn_sig.inputs(),
-            expected_arg_tys,
+            fn_sig.output(),
+            expected,
             arg_exprs,
             fn_sig.c_variadic,
             TupleArgumentsFlag::TupleArguments,
diff --git a/compiler/rustc_hir_typeck/src/expr.rs b/compiler/rustc_hir_typeck/src/expr.rs
index 1362d3626efd4..f0d47e584ac28 100644
--- a/compiler/rustc_hir_typeck/src/expr.rs
+++ b/compiler/rustc_hir_typeck/src/expr.rs
@@ -1673,15 +1673,22 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
     ) {
         let tcx = self.tcx;
 
-        let expected_inputs =
-            self.expected_inputs_for_expected_output(span, expected, adt_ty, &[adt_ty]);
-        let adt_ty_hint = if let Some(expected_inputs) = expected_inputs {
-            expected_inputs.get(0).cloned().unwrap_or(adt_ty)
-        } else {
-            adt_ty
-        };
-        // re-link the regions that EIfEO can erase.
-        self.demand_eqtype(span, adt_ty_hint, adt_ty);
+        let adt_ty = self.resolve_vars_with_obligations(adt_ty);
+        let adt_ty_hint = expected.only_has_type(self).and_then(|expected| {
+            self.fudge_inference_if_ok(|| {
+                let ocx = ObligationCtxt::new(self);
+                ocx.sup(&self.misc(span), self.param_env, expected, adt_ty)?;
+                if !ocx.select_where_possible().is_empty() {
+                    return Err(TypeError::Mismatch);
+                }
+                Ok(self.resolve_vars_if_possible(adt_ty))
+            })
+            .ok()
+        });
+        if let Some(adt_ty_hint) = adt_ty_hint {
+            // re-link the variables that the fudging above can create.
+            self.demand_eqtype(span, adt_ty_hint, adt_ty);
+        }
 
         let ty::Adt(adt, args) = adt_ty.kind() else {
             span_bug!(span, "non-ADT passed to check_expr_struct_fields");
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
index 97c27680959f0..19f7950287f93 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
@@ -20,7 +20,6 @@ use rustc_infer::infer::canonical::{Canonical, OriginalQueryValues, QueryRespons
 use rustc_infer::infer::{DefineOpaqueTypes, InferResult};
 use rustc_lint::builtin::SELF_CONSTRUCTOR_FROM_OUTER_ITEM;
 use rustc_middle::ty::adjustment::{Adjust, Adjustment, AutoBorrow, AutoBorrowMutability};
-use rustc_middle::ty::error::TypeError;
 use rustc_middle::ty::fold::TypeFoldable;
 use rustc_middle::ty::visit::{TypeVisitable, TypeVisitableExt};
 use rustc_middle::ty::{
@@ -36,7 +35,7 @@ use rustc_span::Span;
 use rustc_target::abi::FieldIdx;
 use rustc_trait_selection::error_reporting::infer::need_type_info::TypeAnnotationNeeded;
 use rustc_trait_selection::traits::{
-    self, NormalizeExt, ObligationCauseCode, ObligationCtxt, StructurallyNormalizeExt,
+    self, NormalizeExt, ObligationCauseCode, StructurallyNormalizeExt,
 };
 use tracing::{debug, instrument};
 
@@ -689,42 +688,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         vec![ty_error; len]
     }
 
-    /// Unifies the output type with the expected type early, for more coercions
-    /// and forward type information on the input expressions.
-    #[instrument(skip(self, call_span), level = "debug")]
-    pub(crate) fn expected_inputs_for_expected_output(
-        &self,
-        call_span: Span,
-        expected_ret: Expectation<'tcx>,
-        formal_ret: Ty<'tcx>,
-        formal_args: &[Ty<'tcx>],
-    ) -> Option<Vec<Ty<'tcx>>> {
-        let formal_ret = self.resolve_vars_with_obligations(formal_ret);
-        let ret_ty = expected_ret.only_has_type(self)?;
-
-        let expect_args = self
-            .fudge_inference_if_ok(|| {
-                let ocx = ObligationCtxt::new(self);
-
-                // Attempt to apply a subtyping relationship between the formal
-                // return type (likely containing type variables if the function
-                // is polymorphic) and the expected return type.
-                // No argument expectations are produced if unification fails.
-                let origin = self.misc(call_span);
-                ocx.sup(&origin, self.param_env, ret_ty, formal_ret)?;
-                if !ocx.select_where_possible().is_empty() {
-                    return Err(TypeError::Mismatch);
-                }
-
-                // Record all the argument types, with the args
-                // produced from the above subtyping unification.
-                Ok(Some(formal_args.iter().map(|&ty| self.resolve_vars_if_possible(ty)).collect()))
-            })
-            .unwrap_or_default();
-        debug!(?formal_args, ?formal_ret, ?expect_args, ?expected_ret);
-        expect_args
-    }
-
     pub(crate) fn resolve_lang_item_path(
         &self,
         lang_item: hir::LangItem,
diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
index 16d65726128c3..bdf84f332166d 100644
--- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
+++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
@@ -17,6 +17,7 @@ use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
 use rustc_index::IndexVec;
 use rustc_infer::infer::{DefineOpaqueTypes, InferOk, TypeTrace};
 use rustc_middle::ty::adjustment::AllowTwoPhase;
+use rustc_middle::ty::error::TypeError;
 use rustc_middle::ty::visit::TypeVisitableExt;
 use rustc_middle::ty::{self, IsSuggestable, Ty, TyCtxt};
 use rustc_middle::{bug, span_bug};
@@ -25,7 +26,7 @@ use rustc_span::symbol::{kw, Ident};
 use rustc_span::{sym, Span, DUMMY_SP};
 use rustc_trait_selection::error_reporting::infer::{FailureCode, ObligationCauseExt};
 use rustc_trait_selection::infer::InferCtxtExt;
-use rustc_trait_selection::traits::{self, ObligationCauseCode, SelectionContext};
+use rustc_trait_selection::traits::{self, ObligationCauseCode, ObligationCtxt, SelectionContext};
 use tracing::debug;
 use {rustc_ast as ast, rustc_hir as hir};
 
@@ -124,6 +125,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         };
         if let Err(guar) = has_error {
             let err_inputs = self.err_args(args_no_rcvr.len(), guar);
+            let err_output = Ty::new_error(self.tcx, guar);
 
             let err_inputs = match tuple_arguments {
                 DontTupleArguments => err_inputs,
@@ -134,28 +136,23 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
                 sp,
                 expr,
                 &err_inputs,
-                None,
+                err_output,
+                NoExpectation,
                 args_no_rcvr,
                 false,
                 tuple_arguments,
                 method.ok().map(|method| method.def_id),
             );
-            return Ty::new_error(self.tcx, guar);
+            return err_output;
         }
 
         let method = method.unwrap();
-        // HACK(eddyb) ignore self in the definition (see above).
-        let expected_input_tys = self.expected_inputs_for_expected_output(
-            sp,
-            expected,
-            method.sig.output(),
-            &method.sig.inputs()[1..],
-        );
         self.check_argument_types(
             sp,
             expr,
             &method.sig.inputs()[1..],
-            expected_input_tys,
+            method.sig.output(),
+            expected,
             args_no_rcvr,
             method.sig.c_variadic,
             tuple_arguments,
@@ -175,8 +172,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
         call_expr: &'tcx hir::Expr<'tcx>,
         // Types (as defined in the *signature* of the target function)
         formal_input_tys: &[Ty<'tcx>],
-        // More specific expected types, after unifying with caller output types
-        expected_input_tys: Option<Vec<Ty<'tcx>>>,
+        formal_output: Ty<'tcx>,
+        // Expected output from the parent expression or statement
+        expectation: Expectation<'tcx>,
         // The expressions for each provided argument
         provided_args: &'tcx [hir::Expr<'tcx>],
         // Whether the function is variadic, for example when imported from C
@@ -210,6 +208,40 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
             );
         }
 
+        // First, let's unify the formal method signature with the expectation eagerly.
+        // We use this to guide coercion inference; it's output is "fudged" which means
+        // any remaining type variables are assigned to new, unrelated variables. This
+        // is because the inference guidance here is only speculative.
+        let formal_output = self.resolve_vars_with_obligations(formal_output);
+        let expected_input_tys: Option<Vec<_>> = expectation
+            .only_has_type(self)
+            .and_then(|expected_output| {
+                self.fudge_inference_if_ok(|| {
+                    let ocx = ObligationCtxt::new(self);
+
+                    // Attempt to apply a subtyping relationship between the formal
+                    // return type (likely containing type variables if the function
+                    // is polymorphic) and the expected return type.
+                    // No argument expectations are produced if unification fails.
+                    let origin = self.misc(call_span);
+                    ocx.sup(&origin, self.param_env, expected_output, formal_output)?;
+                    if !ocx.select_where_possible().is_empty() {
+                        return Err(TypeError::Mismatch);
+                    }
+
+                    // Record all the argument types, with the args
+                    // produced from the above subtyping unification.
+                    Ok(Some(
+                        formal_input_tys
+                            .iter()
+                            .map(|&ty| self.resolve_vars_if_possible(ty))
+                            .collect(),
+                    ))
+                })
+                .ok()
+            })
+            .unwrap_or_default();
+
         let mut err_code = E0061;
 
         // If the arguments should be wrapped in a tuple (ex: closures), unwrap them here
diff --git a/tests/ui/coercion/constrain-expectation-in-arg.rs b/tests/ui/coercion/constrain-expectation-in-arg.rs
index 858c3a0bdb572..c515dedc4bb4d 100644
--- a/tests/ui/coercion/constrain-expectation-in-arg.rs
+++ b/tests/ui/coercion/constrain-expectation-in-arg.rs
@@ -1,5 +1,10 @@
 //@ check-pass
 
+// Regression test for for #129286.
+// Makes sure that we don't have unconstrained type variables that come from
+// bivariant type parameters due to the way that we construct expectation types
+// when checking call expressions in HIR typeck.
+
 trait Trait {
     type Item;
 }