diff --git a/compiler/rustc_query_impl/src/plumbing.rs b/compiler/rustc_query_impl/src/plumbing.rs
index 005ce16dbb9b4..01d98c62aab83 100644
--- a/compiler/rustc_query_impl/src/plumbing.rs
+++ b/compiler/rustc_query_impl/src/plumbing.rs
@@ -52,7 +52,31 @@ impl<'tcx> HasDepContext for QueryCtxt<'tcx> {
     }
 }
 
-impl QueryContext for QueryCtxt<'_> {
+impl<'tcx> QueryCtxt<'tcx> {
+    // Define this closure separately from the one passed to `with_related_context` so it only has to be monomorphized once.
+    fn make_icx<'a>(
+        self,
+        token: QueryJobId,
+        depth_limit: bool,
+        diagnostics: Option<&'a Lock<ThinVec<Diagnostic>>>,
+        current_icx: &ImplicitCtxt<'a, 'tcx>,
+    ) -> ImplicitCtxt<'a, 'tcx> {
+        if depth_limit && !self.recursion_limit().value_within_limit(current_icx.query_depth) {
+            self.depth_limit_error(token);
+        }
+
+        // Update the `ImplicitCtxt` to point to our new query job.
+        ImplicitCtxt {
+            tcx: *self,
+            query: Some(token),
+            diagnostics,
+            query_depth: current_icx.query_depth + depth_limit as usize,
+            task_deps: current_icx.task_deps,
+        }
+    }
+}
+
+impl<'tcx> QueryContext for QueryCtxt<'tcx> {
     fn next_job_id(self) -> QueryJobId {
         QueryJobId(
             NonZeroU64::new(
@@ -110,21 +134,11 @@ impl QueryContext for QueryCtxt<'_> {
         // as `self`, so we use `with_related_context` to relate the 'tcx lifetimes
         // when accessing the `ImplicitCtxt`.
         tls::with_related_context(*self, move |current_icx| {
-            if depth_limit && !self.recursion_limit().value_within_limit(current_icx.query_depth) {
-                self.depth_limit_error(token);
-            }
-
-            // Update the `ImplicitCtxt` to point to our new query job.
-            let new_icx = ImplicitCtxt {
-                tcx: *self,
-                query: Some(token),
-                diagnostics,
-                query_depth: current_icx.query_depth + depth_limit as usize,
-                task_deps: current_icx.task_deps,
-            };
-
             // Use the `ImplicitCtxt` while we execute the query.
-            tls::enter_context(&new_icx, compute)
+            tls::enter_context(
+                &self.make_icx(token, depth_limit, diagnostics, current_icx),
+                compute,
+            )
         })
     }
 
diff --git a/compiler/rustc_query_system/src/dep_graph/graph.rs b/compiler/rustc_query_system/src/dep_graph/graph.rs
index 59e0c35974559..512f89fe01f49 100644
--- a/compiler/rustc_query_system/src/dep_graph/graph.rs
+++ b/compiler/rustc_query_system/src/dep_graph/graph.rs
@@ -377,14 +377,12 @@
 
     /// Executes something within an "anonymous" task, that is, a task the
     /// `DepNode` of which is determined by the list of inputs it read from.
-    pub fn with_anon_task<Tcx: DepContext<DepKind = K>, OP, R>(
+    pub fn with_anon_task<Tcx: DepContext<DepKind = K>, R>(
         &self,
         cx: Tcx,
         dep_kind: K,
-        op: OP,
+        op: &mut dyn FnMut() -> R,
     ) -> (R, DepNodeIndex)
-    where
-        OP: FnOnce() -> R,
     {
         debug_assert!(!cx.is_eval_always(dep_kind));
 
diff --git a/compiler/rustc_query_system/src/query/plumbing.rs b/compiler/rustc_query_system/src/query/plumbing.rs
index 005fcd8c4cc9d..c8e625009677e 100644
--- a/compiler/rustc_query_system/src/query/plumbing.rs
+++ b/compiler/rustc_query_system/src/query/plumbing.rs
@@ -2,7 +2,7 @@
 //! generate the actual methods on tcx which find and execute the provider,
 //! manage the caches, and so forth.
 
-use crate::dep_graph::HasDepContext;
+use crate::dep_graph::{HasDepContext, SerializedDepNodeIndex};
 use crate::dep_graph::{DepContext, DepKind, DepNode, DepNodeIndex, DepNodeParams};
 use crate::ich::StableHashingContext;
 use crate::query::caches::QueryCache;
@@ -458,7 +458,7 @@ where
     let (result, dep_node_index) =
         qcx.start_query(job_id, query.depth_limit(), Some(&diagnostics), || {
             if query.anon() {
-                return dep_graph.with_anon_task(*qcx.dep_context(), query.dep_kind(), || {
+                return dep_graph.with_anon_task(*qcx.dep_context(), query.dep_kind(), &mut || {
                     query.compute(qcx, key)
                 });
             }
@@ -492,6 +492,68 @@
     (result, dep_node_index)
 }
 
+#[inline(always)]
+fn load_from_disk_and_cache_in_memory<V, Qcx>(
+    qcx: Qcx,
+    dep_node: &DepNode<Qcx::DepKind>,
+    prev_dep_node_index: SerializedDepNodeIndex,
+    dep_node_index: DepNodeIndex,
+    try_load_from_disk: fn(Qcx, SerializedDepNodeIndex) -> Option<V>,
+    hash_result: Option<fn(&mut StableHashingContext<'_>, &V) -> Fingerprint>,
+) -> Option<(V, DepNodeIndex)>
+where
+    V: Debug,
+    Qcx: QueryContext,
+{
+    let dep_graph = qcx.dep_context().dep_graph();
+    let prof_timer = qcx.dep_context().profiler().incr_cache_loading();
+
+    // The call to `with_query_deserialization` enforces that no new `DepNodes`
+    // are created during deserialization. See the docs of that method for more
+    // details.
+    let result =
+        dep_graph.with_query_deserialization(|| try_load_from_disk(qcx, prev_dep_node_index));
+
+    prof_timer.finish_with_query_invocation_id(dep_node_index.into());
+
+    if let Some(result) = result {
+        if std::intrinsics::unlikely(
+            qcx.dep_context().sess().opts.unstable_opts.query_dep_graph,
+        ) {
+            dep_graph.mark_debug_loaded_from_disk(*dep_node)
+        }
+
+        let prev_fingerprint = qcx
+            .dep_context()
+            .dep_graph()
+            .prev_fingerprint_of(dep_node)
+            .unwrap_or(Fingerprint::ZERO);
+        // If `-Zincremental-verify-ich` is specified, re-hash results from
+        // the cache and make sure that they have the expected fingerprint.
+        //
+        // If not, we still seek to verify a subset of fingerprints loaded
+        // from disk. Re-hashing results is fairly expensive, so we can't
+        // currently afford to verify every hash. This subset should still
+        // give us some coverage of potential bugs though.
+        let try_verify = prev_fingerprint.as_value().1 % 32 == 0;
+        if std::intrinsics::unlikely(
+            try_verify || qcx.dep_context().sess().opts.unstable_opts.incremental_verify_ich,
+        ) {
+            incremental_verify_ich(*qcx.dep_context(), &result, dep_node, hash_result);
+        }
+
+        Some((result, dep_node_index))
+    } else {
+        // We always expect to find a cached result for things that
+        // can be forced from `DepNode`.
+        debug_assert!(
+            !qcx.dep_context().fingerprint_style(dep_node.kind).reconstructible(),
+            "missing on-disk cache entry for {dep_node:?}"
+        );
+        None
+    }
+}
+
 #[inline(always)]
 fn try_load_from_disk_and_cache_in_memory<Q, Qcx>(
     query: Q,
@@ -514,51 +570,16 @@ where
     // First we try to load the result from the on-disk cache.
     // Some things are never cached on disk.
     if let Some(try_load_from_disk) = query.try_load_from_disk(qcx, &key) {
-        let prof_timer = qcx.dep_context().profiler().incr_cache_loading();
-
-        // The call to `with_query_deserialization` enforces that no new `DepNodes`
-        // are created during deserialization. See the docs of that method for more
-        // details.
-        let result =
-            dep_graph.with_query_deserialization(|| try_load_from_disk(qcx, prev_dep_node_index));
-
-        prof_timer.finish_with_query_invocation_id(dep_node_index.into());
-
-        if let Some(result) = result {
-            if std::intrinsics::unlikely(
-                qcx.dep_context().sess().opts.unstable_opts.query_dep_graph,
-            ) {
-                dep_graph.mark_debug_loaded_from_disk(*dep_node)
-            }
-
-            let prev_fingerprint = qcx
-                .dep_context()
-                .dep_graph()
-                .prev_fingerprint_of(dep_node)
-                .unwrap_or(Fingerprint::ZERO);
-            // If `-Zincremental-verify-ich` is specified, re-hash results from
-            // the cache and make sure that they have the expected fingerprint.
-            //
-            // If not, we still seek to verify a subset of fingerprints loaded
-            // from disk. Re-hashing results is fairly expensive, so we can't
-            // currently afford to verify every hash. This subset should still
-            // give us some coverage of potential bugs though.
-            let try_verify = prev_fingerprint.as_value().1 % 32 == 0;
-            if std::intrinsics::unlikely(
-                try_verify || qcx.dep_context().sess().opts.unstable_opts.incremental_verify_ich,
-            ) {
-                incremental_verify_ich(*qcx.dep_context(), &result, dep_node, query.hash_result());
-            }
-
-            return Some((result, dep_node_index));
+        if let Some(value) = load_from_disk_and_cache_in_memory::<Q::Value, Qcx>(
+            qcx,
+            dep_node,
+            prev_dep_node_index,
+            dep_node_index,
+            try_load_from_disk,
+            query.hash_result(),
+        ) {
+            return Some(value);
         }
-
-        // We always expect to find a cached result for things that
-        // can be forced from `DepNode`.
-        debug_assert!(
-            !qcx.dep_context().fingerprint_style(dep_node.kind).reconstructible(),
-            "missing on-disk cache entry for {dep_node:?}"
-        );
     }
 
     // We could not load a result from the on-disk cache, so
diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs
index 48c3b3601b4d3..15eca2bba8e97 100644
--- a/compiler/rustc_trait_selection/src/traits/select/mod.rs
+++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs
@@ -1417,12 +1417,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
         self.check_recursion_depth(obligation.recursion_depth, error_obligation)
     }
 
-    fn in_task<OP, R>(&mut self, op: OP) -> (R, DepNodeIndex)
+    fn in_task<OP, R>(&mut self, mut op: OP) -> (R, DepNodeIndex)
     where
-        OP: FnOnce(&mut Self) -> R,
+        OP: FnMut(&mut Self) -> R,
     {
         let (result, dep_node) =
-            self.tcx().dep_graph.with_anon_task(self.tcx(), DepKind::TraitSelect, || op(self));
+            self.tcx().dep_graph.with_anon_task(self.tcx(), DepKind::TraitSelect, &mut || op(self));
         self.tcx().dep_graph.read_index(dep_node);
         (result, dep_node)
     }