Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f9046e6

Browse files
committedJan 17, 2024
Add trait obligation tracking to FulfillCtxt and expose FnCtxt in rustc_infer using callback.
Pass each obligation to an fn callback with its respective inference context. This avoids needing to keep around copies of obligations or inference contexts.
1 parent 5113ed2 commit f9046e6

File tree

7 files changed

+166
-67
lines changed

7 files changed

+166
-67
lines changed
 

‎compiler/rustc_hir_typeck/src/lib.rs

+17-3
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ use rustc_hir::{HirIdMap, Node};
6060
use rustc_hir_analysis::astconv::AstConv;
6161
use rustc_hir_analysis::check::check_abi;
6262
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
63+
use rustc_infer::traits::ObligationInspector;
6364
use rustc_middle::query::Providers;
6465
use rustc_middle::traits;
6566
use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -139,7 +140,7 @@ fn used_trait_imports(tcx: TyCtxt<'_>, def_id: LocalDefId) -> &UnordSet<LocalDef
139140

140141
fn typeck<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> &ty::TypeckResults<'tcx> {
141142
let fallback = move || tcx.type_of(def_id.to_def_id()).instantiate_identity();
142-
typeck_with_fallback(tcx, def_id, fallback)
143+
typeck_with_fallback(tcx, def_id, fallback, None)
143144
}
144145

145146
/// Used only to get `TypeckResults` for type inference during error recovery.
@@ -149,14 +150,24 @@ fn diagnostic_only_typeck<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> &ty::T
149150
let span = tcx.hir().span(tcx.local_def_id_to_hir_id(def_id));
150151
Ty::new_error_with_message(tcx, span, "diagnostic only typeck table used")
151152
};
152-
typeck_with_fallback(tcx, def_id, fallback)
153+
typeck_with_fallback(tcx, def_id, fallback, None)
153154
}
154155

155-
#[instrument(level = "debug", skip(tcx, fallback), ret)]
156+
pub fn inspect_typeck<'tcx>(
157+
tcx: TyCtxt<'tcx>,
158+
def_id: LocalDefId,
159+
inspect: ObligationInspector<'tcx>,
160+
) -> &'tcx ty::TypeckResults<'tcx> {
161+
let fallback = move || tcx.type_of(def_id.to_def_id()).instantiate_identity();
162+
typeck_with_fallback(tcx, def_id, fallback, Some(inspect))
163+
}
164+
165+
#[instrument(level = "debug", skip(tcx, fallback, inspector), ret)]
156166
fn typeck_with_fallback<'tcx>(
157167
tcx: TyCtxt<'tcx>,
158168
def_id: LocalDefId,
159169
fallback: impl Fn() -> Ty<'tcx> + 'tcx,
170+
inspector: Option<ObligationInspector<'tcx>>,
160171
) -> &'tcx ty::TypeckResults<'tcx> {
161172
// Closures' typeck results come from their outermost function,
162173
// as they are part of the same "inference environment".
@@ -178,6 +189,9 @@ fn typeck_with_fallback<'tcx>(
178189
let param_env = tcx.param_env(def_id);
179190

180191
let inh = Inherited::new(tcx, def_id);
192+
if let Some(inspector) = inspector {
193+
inh.infcx.attach_obligation_inspector(inspector);
194+
}
181195
let mut fcx = FnCtxt::new(&inh, param_env, def_id);
182196

183197
if let Some(hir::FnSig { header, decl, .. }) = fn_sig {

‎compiler/rustc_infer/src/infer/at.rs

+1
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ impl<'tcx> InferCtxt<'tcx> {
9090
universe: self.universe.clone(),
9191
intercrate,
9292
next_trait_solver: self.next_trait_solver,
93+
obligation_inspector: self.obligation_inspector.clone(),
9394
}
9495
}
9596
}

‎compiler/rustc_infer/src/infer/mod.rs

+14-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use rustc_middle::infer::unify_key::{ConstVidKey, EffectVidKey};
1313
use self::opaque_types::OpaqueTypeStorage;
1414
pub(crate) use self::undo_log::{InferCtxtUndoLogs, Snapshot, UndoLog};
1515

16-
use crate::traits::{self, ObligationCause, PredicateObligations, TraitEngine, TraitEngineExt};
16+
use crate::traits::{
17+
self, ObligationCause, ObligationInspector, PredicateObligations, TraitEngine, TraitEngineExt,
18+
};
1719

1820
use rustc_data_structures::fx::FxIndexMap;
1921
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
@@ -334,6 +336,8 @@ pub struct InferCtxt<'tcx> {
334336
pub intercrate: bool,
335337

336338
next_trait_solver: bool,
339+
340+
pub obligation_inspector: Cell<Option<ObligationInspector<'tcx>>>,
337341
}
338342

339343
impl<'tcx> ty::InferCtxtLike for InferCtxt<'tcx> {
@@ -708,6 +712,7 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
708712
universe: Cell::new(ty::UniverseIndex::ROOT),
709713
intercrate,
710714
next_trait_solver,
715+
obligation_inspector: Cell::new(None),
711716
}
712717
}
713718
}
@@ -1726,6 +1731,14 @@ impl<'tcx> InferCtxt<'tcx> {
17261731
}
17271732
}
17281733
}
1734+
1735+
pub fn attach_obligation_inspector(&self, inspector: ObligationInspector<'tcx>) {
1736+
debug_assert!(
1737+
self.obligation_inspector.get().is_none(),
1738+
"shouldn't override a set obligation inspector"
1739+
);
1740+
self.obligation_inspector.set(Some(inspector));
1741+
}
17291742
}
17301743

17311744
impl<'tcx> TypeErrCtxt<'_, 'tcx> {

‎compiler/rustc_infer/src/traits/mod.rs

+9
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@ use std::hash::{Hash, Hasher};
1313

1414
use hir::def_id::LocalDefId;
1515
use rustc_hir as hir;
16+
use rustc_middle::traits::query::NoSolution;
17+
use rustc_middle::traits::solve::Certainty;
1618
use rustc_middle::ty::error::{ExpectedFound, TypeError};
1719
use rustc_middle::ty::{self, Const, ToPredicate, Ty, TyCtxt};
1820
use rustc_span::Span;
1921

2022
pub use self::FulfillmentErrorCode::*;
2123
pub use self::ImplSource::*;
2224
pub use self::SelectionError::*;
25+
use crate::infer::InferCtxt;
2326

2427
pub use self::engine::{TraitEngine, TraitEngineExt};
2528
pub use self::project::MismatchedProjectionTypes;
@@ -117,6 +120,12 @@ pub type PredicateObligations<'tcx> = Vec<PredicateObligation<'tcx>>;
117120

118121
pub type Selection<'tcx> = ImplSource<'tcx, PredicateObligation<'tcx>>;
119122

123+
/// A callback that can be provided to `inspect_typeck`. Invoked on evaluation
124+
/// of root obligations.
125+
pub type ObligationInspector<'tcx> =
126+
fn(&InferCtxt<'tcx>, &PredicateObligation<'tcx>, Result<Certainty, NoSolution>);
127+
128+
#[derive(Clone)]
120129
pub struct FulfillmentError<'tcx> {
121130
pub obligation: PredicateObligation<'tcx>,
122131
pub code: FulfillmentErrorCode<'tcx>,

‎compiler/rustc_session/src/options.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1943,6 +1943,9 @@ written to standard error output)"),
19431943
"for every macro invocation, print its name and arguments (default: no)"),
19441944
track_diagnostics: bool = (false, parse_bool, [UNTRACKED],
19451945
"tracks where in rustc a diagnostic was emitted"),
1946+
track_trait_obligations: bool = (false, parse_bool, [TRACKED],
1947+
"tracks evaluated obligations while trait solving, option is only \
1948+
valid when -Z next-solver=globally (default: no)"),
19461949
// Diagnostics are considered side-effects of a query (see `QuerySideEffects`) and are saved
19471950
// alongside query results and changes to translation options can affect diagnostics - so
19481951
// translation options should be tracked.

‎compiler/rustc_trait_selection/src/solve/fulfill.rs

+90-63
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use rustc_middle::ty;
1111
use rustc_middle::ty::error::{ExpectedFound, TypeError};
1212

1313
use super::eval_ctxt::GenerateProofTree;
14-
use super::{Certainty, InferCtxtEvalExt};
14+
use super::{Certainty, Goal, InferCtxtEvalExt};
1515

1616
/// A trait engine using the new trait solver.
1717
///
@@ -32,11 +32,34 @@ pub struct FulfillmentCtxt<'tcx> {
3232
/// gets rolled back. Because of this we explicitly check that we only
3333
/// use the context in exactly this snapshot.
3434
usable_in_snapshot: usize,
35+
36+
track_obligations: bool,
3537
}
3638

3739
impl<'tcx> FulfillmentCtxt<'tcx> {
3840
pub fn new(infcx: &InferCtxt<'tcx>) -> FulfillmentCtxt<'tcx> {
39-
FulfillmentCtxt { obligations: Vec::new(), usable_in_snapshot: infcx.num_open_snapshots() }
41+
FulfillmentCtxt {
42+
obligations: Vec::new(),
43+
usable_in_snapshot: infcx.num_open_snapshots(),
44+
track_obligations: infcx.tcx.sess.opts.unstable_opts.track_trait_obligations,
45+
}
46+
}
47+
48+
fn track_evaluated_obligation(
49+
&self,
50+
infcx: &InferCtxt<'tcx>,
51+
obligation: &PredicateObligation<'tcx>,
52+
result: &Result<(bool, Certainty, Vec<Goal<'tcx, ty::Predicate<'tcx>>>), NoSolution>,
53+
) {
54+
if self.track_obligations {
55+
if let Some(inspector) = infcx.obligation_inspector.get() {
56+
let result = match result {
57+
Ok((_, c, _)) => Ok(*c),
58+
Err(NoSolution) => Err(NoSolution),
59+
};
60+
(inspector)(infcx, &obligation, result);
61+
}
62+
}
4063
}
4164
}
4265

@@ -52,7 +75,8 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
5275
}
5376

5477
fn collect_remaining_errors(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>> {
55-
self.obligations
78+
let errors = self
79+
.obligations
5680
.drain(..)
5781
.map(|obligation| {
5882
let code = infcx.probe(|_| {
@@ -81,7 +105,9 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
81105
root_obligation: obligation,
82106
}
83107
})
84-
.collect()
108+
.collect();
109+
110+
errors
85111
}
86112

87113
fn select_where_possible(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>> {
@@ -95,65 +121,66 @@ impl<'tcx> TraitEngine<'tcx> for FulfillmentCtxt<'tcx> {
95121
let mut has_changed = false;
96122
for obligation in mem::take(&mut self.obligations) {
97123
let goal = obligation.clone().into();
98-
let (changed, certainty, nested_goals) =
99-
match infcx.evaluate_root_goal(goal, GenerateProofTree::IfEnabled).0 {
100-
Ok(result) => result,
101-
Err(NoSolution) => {
102-
errors.push(FulfillmentError {
103-
obligation: obligation.clone(),
104-
code: match goal.predicate.kind().skip_binder() {
105-
ty::PredicateKind::Clause(ty::ClauseKind::Projection(_)) => {
106-
FulfillmentErrorCode::CodeProjectionError(
107-
// FIXME: This could be a `Sorts` if the term is a type
108-
MismatchedProjectionTypes { err: TypeError::Mismatch },
109-
)
110-
}
111-
ty::PredicateKind::NormalizesTo(..) => {
112-
FulfillmentErrorCode::CodeProjectionError(
113-
MismatchedProjectionTypes { err: TypeError::Mismatch },
114-
)
115-
}
116-
ty::PredicateKind::AliasRelate(_, _, _) => {
117-
FulfillmentErrorCode::CodeProjectionError(
118-
MismatchedProjectionTypes { err: TypeError::Mismatch },
119-
)
120-
}
121-
ty::PredicateKind::Subtype(pred) => {
122-
let (a, b) = infcx.instantiate_binder_with_placeholders(
123-
goal.predicate.kind().rebind((pred.a, pred.b)),
124-
);
125-
let expected_found = ExpectedFound::new(true, a, b);
126-
FulfillmentErrorCode::CodeSubtypeError(
127-
expected_found,
128-
TypeError::Sorts(expected_found),
129-
)
130-
}
131-
ty::PredicateKind::Coerce(pred) => {
132-
let (a, b) = infcx.instantiate_binder_with_placeholders(
133-
goal.predicate.kind().rebind((pred.a, pred.b)),
134-
);
135-
let expected_found = ExpectedFound::new(false, a, b);
136-
FulfillmentErrorCode::CodeSubtypeError(
137-
expected_found,
138-
TypeError::Sorts(expected_found),
139-
)
140-
}
141-
ty::PredicateKind::Clause(_)
142-
| ty::PredicateKind::ObjectSafe(_)
143-
| ty::PredicateKind::Ambiguous => {
144-
FulfillmentErrorCode::CodeSelectionError(
145-
SelectionError::Unimplemented,
146-
)
147-
}
148-
ty::PredicateKind::ConstEquate(..) => {
149-
bug!("unexpected goal: {goal:?}")
150-
}
151-
},
152-
root_obligation: obligation,
153-
});
154-
continue;
155-
}
156-
};
124+
let result = infcx.evaluate_root_goal(goal, GenerateProofTree::IfEnabled).0;
125+
self.track_evaluated_obligation(infcx, &obligation, &result);
126+
let (changed, certainty, nested_goals) = match result {
127+
Ok(result) => result,
128+
Err(NoSolution) => {
129+
errors.push(FulfillmentError {
130+
obligation: obligation.clone(),
131+
code: match goal.predicate.kind().skip_binder() {
132+
ty::PredicateKind::Clause(ty::ClauseKind::Projection(_)) => {
133+
FulfillmentErrorCode::CodeProjectionError(
134+
// FIXME: This could be a `Sorts` if the term is a type
135+
MismatchedProjectionTypes { err: TypeError::Mismatch },
136+
)
137+
}
138+
ty::PredicateKind::NormalizesTo(..) => {
139+
FulfillmentErrorCode::CodeProjectionError(
140+
MismatchedProjectionTypes { err: TypeError::Mismatch },
141+
)
142+
}
143+
ty::PredicateKind::AliasRelate(_, _, _) => {
144+
FulfillmentErrorCode::CodeProjectionError(
145+
MismatchedProjectionTypes { err: TypeError::Mismatch },
146+
)
147+
}
148+
ty::PredicateKind::Subtype(pred) => {
149+
let (a, b) = infcx.instantiate_binder_with_placeholders(
150+
goal.predicate.kind().rebind((pred.a, pred.b)),
151+
);
152+
let expected_found = ExpectedFound::new(true, a, b);
153+
FulfillmentErrorCode::CodeSubtypeError(
154+
expected_found,
155+
TypeError::Sorts(expected_found),
156+
)
157+
}
158+
ty::PredicateKind::Coerce(pred) => {
159+
let (a, b) = infcx.instantiate_binder_with_placeholders(
160+
goal.predicate.kind().rebind((pred.a, pred.b)),
161+
);
162+
let expected_found = ExpectedFound::new(false, a, b);
163+
FulfillmentErrorCode::CodeSubtypeError(
164+
expected_found,
165+
TypeError::Sorts(expected_found),
166+
)
167+
}
168+
ty::PredicateKind::Clause(_)
169+
| ty::PredicateKind::ObjectSafe(_)
170+
| ty::PredicateKind::Ambiguous => {
171+
FulfillmentErrorCode::CodeSelectionError(
172+
SelectionError::Unimplemented,
173+
)
174+
}
175+
ty::PredicateKind::ConstEquate(..) => {
176+
bug!("unexpected goal: {goal:?}")
177+
}
178+
},
179+
root_obligation: obligation,
180+
});
181+
continue;
182+
}
183+
};
157184
// Push any nested goals that we get from unifying our canonical response
158185
// with our obligation onto the fulfillment context.
159186
self.obligations.extend(nested_goals.into_iter().map(|goal| {
+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// compile-flags: -Ztrack-trait-obligations
2+
// run-pass
3+
4+
// Just making sure this flag is accepted and doesn't crash the compiler
5+
use traits::IntoString;
6+
7+
fn does_impl_into_string<T: IntoString>(_: T) {}
8+
9+
fn main() {
10+
let v = vec![(0, 1), (2, 3)];
11+
12+
does_impl_into_string(v);
13+
}
14+
15+
mod traits {
16+
pub trait IntoString {
17+
fn to_string(&self) -> String;
18+
}
19+
20+
impl IntoString for (i32, i32) {
21+
fn to_string(&self) -> String {
22+
format!("({}, {})", self.0, self.1)
23+
}
24+
}
25+
26+
impl<T: IntoString> IntoString for Vec<T> {
27+
fn to_string(&self) -> String {
28+
let s = self.iter().map(|v| v.to_string()).collect::<Vec<_>>().join(", ");
29+
format!("[{s}]")
30+
}
31+
}
32+
}

0 commit comments

Comments
 (0)
Please sign in to comment.