Skip to content

Commit 36d1915

Browse files
authored
Rollup merge of #123518 - compiler-errors:by-move-fixes, r=oli-obk
Fix `ByMove` coroutine-closure shim (for 2021 precise closure capturing behavior) This PR reworks the way that we perform the `ByMove` coroutine-closure shim to account for the fact that the upvars of the outer coroutine-closure and the inner coroutine might not line up due to edition-2021 closure capture rules changes. Specifically, the number of upvars may differ *and/or* the inner coroutine may have additional projections applied to an upvar. This PR reworks the information we pass into the `ByMoveBody` MIR visitor to account for both of these facts. I tried to leave comments explaining exactly what everything is doing, but let me know if you have questions. r? oli-obk
2 parents 0e27c99 + ad0fcac commit 36d1915

8 files changed

+461
-35
lines changed

Diff for: compiler/rustc_mir_transform/src/coroutine/by_move_body.rs

+162-35
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,24 @@
5858
//! borrowing from the outer closure, and we simply peel off a `deref` projection
5959
//! from them. This second body is stored alongside the first body, and optimized
6060
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
61-
//! we use this "by move" body instead.
62-
63-
use itertools::Itertools;
61+
//! we use this "by-move" body instead.
62+
//!
63+
//! ## How does this work?
64+
//!
65+
//! This pass essentially remaps the body of the (child) closure of the coroutine-closure
66+
//! to take the set of upvars of the parent closure by value. This at least requires
67+
//! changing a by-ref upvar to be by-value in the case that the outer coroutine-closure
68+
//! captures something by value; however, it may also require renumbering field indices
69+
//! in case precise captures (edition 2021 closure capture rules) caused the inner coroutine
70+
//! to split one field capture into two.
6471
65-
use rustc_data_structures::unord::UnordSet;
72+
use rustc_data_structures::unord::UnordMap;
6673
use rustc_hir as hir;
74+
use rustc_middle::hir::place::{PlaceBase, Projection, ProjectionKind};
6775
use rustc_middle::mir::visit::MutVisitor;
6876
use rustc_middle::mir::{self, dump_mir, MirPass};
6977
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
70-
use rustc_target::abi::FieldIdx;
78+
use rustc_target::abi::{FieldIdx, VariantIdx};
7179

7280
pub struct ByMoveBody;
7381

@@ -116,32 +124,116 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
116124
.tuple_fields()
117125
.len();
118126

119-
let mut by_ref_fields = UnordSet::default();
120-
for (idx, (coroutine_capture, parent_capture)) in tcx
127+
let mut field_remapping = UnordMap::default();
128+
129+
// One parent capture may correspond to several child captures if we end up
130+
// refining the set of captures via edition-2021 precise captures. We want to
131+
// match up any number of child captures with one parent capture, so we keep
132+
// peeking off this `Peekable` until the child doesn't match anymore.
133+
let mut parent_captures =
134+
tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable();
135+
// Make sure we use every field at least once, b/c why are we capturing something
136+
// if it's not used in the inner coroutine.
137+
let mut field_used_at_least_once = false;
138+
139+
for (child_field_idx, child_capture) in tcx
121140
.closure_captures(coroutine_def_id)
122141
.iter()
142+
.copied()
123143
// By construction we capture all the args first.
124144
.skip(num_args)
125-
.zip_eq(tcx.closure_captures(parent_def_id))
126145
.enumerate()
127146
{
128-
// This upvar is captured by-move from the parent closure, but by-ref
129-
// from the inner async block. That means that it's being borrowed from
130-
// the outer closure body -- we need to change the coroutine to take the
131-
// upvar by value.
132-
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
133-
assert_ne!(
134-
coroutine_kind,
135-
ty::ClosureKind::FnOnce,
136-
"`FnOnce` coroutine-closures return coroutines that capture from \
137-
their body; it will always result in a borrowck error!"
147+
loop {
148+
let Some(&(parent_field_idx, parent_capture)) = parent_captures.peek() else {
149+
bug!("we ran out of parent captures!")
150+
};
151+
152+
let PlaceBase::Upvar(parent_base) = parent_capture.place.base else {
153+
bug!("expected capture to be an upvar");
154+
};
155+
let PlaceBase::Upvar(child_base) = child_capture.place.base else {
156+
bug!("expected capture to be an upvar");
157+
};
158+
159+
assert!(
160+
child_capture.place.projections.len() >= parent_capture.place.projections.len()
138161
);
139-
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
162+
// A parent matches a child they share the same prefix of projections.
163+
// The child may have more, if it is capturing sub-fields out of
164+
// something that is captured by-move in the parent closure.
165+
if parent_base.var_path.hir_id != child_base.var_path.hir_id
166+
|| !std::iter::zip(
167+
&child_capture.place.projections,
168+
&parent_capture.place.projections,
169+
)
170+
.all(|(child, parent)| child.kind == parent.kind)
171+
{
172+
// Make sure the field was used at least once.
173+
assert!(
174+
field_used_at_least_once,
175+
"we captured {parent_capture:#?} but it was not used in the child coroutine?"
176+
);
177+
field_used_at_least_once = false;
178+
// Skip this field.
179+
let _ = parent_captures.next().unwrap();
180+
continue;
181+
}
182+
183+
// Store this set of additional projections (fields and derefs).
184+
// We need to re-apply them later.
185+
let child_precise_captures =
186+
&child_capture.place.projections[parent_capture.place.projections.len()..];
187+
188+
// If the parent captures by-move, and the child captures by-ref, then we
189+
// need to peel an additional `deref` off of the body of the child.
190+
let needs_deref = child_capture.is_by_ref() && !parent_capture.is_by_ref();
191+
if needs_deref {
192+
assert_ne!(
193+
coroutine_kind,
194+
ty::ClosureKind::FnOnce,
195+
"`FnOnce` coroutine-closures return coroutines that capture from \
196+
their body; it will always result in a borrowck error!"
197+
);
198+
}
199+
200+
// Finally, store the type of the parent's captured place. We need
201+
// this when building the field projection in the MIR body later on.
202+
let mut parent_capture_ty = parent_capture.place.ty();
203+
parent_capture_ty = match parent_capture.info.capture_kind {
204+
ty::UpvarCapture::ByValue => parent_capture_ty,
205+
ty::UpvarCapture::ByRef(kind) => Ty::new_ref(
206+
tcx,
207+
tcx.lifetimes.re_erased,
208+
parent_capture_ty,
209+
kind.to_mutbl_lossy(),
210+
),
211+
};
212+
213+
field_remapping.insert(
214+
FieldIdx::from_usize(child_field_idx + num_args),
215+
(
216+
FieldIdx::from_usize(parent_field_idx + num_args),
217+
parent_capture_ty,
218+
needs_deref,
219+
child_precise_captures,
220+
),
221+
);
222+
223+
field_used_at_least_once = true;
224+
break;
140225
}
226+
}
227+
228+
// Pop the last parent capture
229+
if field_used_at_least_once {
230+
let _ = parent_captures.next().unwrap();
231+
}
232+
assert_eq!(parent_captures.next(), None, "leftover parent captures?");
141233

142-
// Make sure we're actually talking about the same capture.
143-
// FIXME(async_closures): We could look at the `hir::Upvar` instead?
144-
assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
234+
if coroutine_kind == ty::ClosureKind::FnOnce {
235+
assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
236+
return;
145237
}
146238

147239
let by_move_coroutine_ty = tcx
@@ -157,7 +249,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
157249
);
158250

159251
let mut by_move_body = body.clone();
160-
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
252+
MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body);
161253
dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
162254
by_move_body.source = mir::MirSource::from_instance(InstanceDef::CoroutineKindShim {
163255
coroutine_def_id: coroutine_def_id.to_def_id(),
@@ -168,7 +260,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
168260

169261
struct MakeByMoveBody<'tcx> {
170262
tcx: TyCtxt<'tcx>,
171-
by_ref_fields: UnordSet<FieldIdx>,
263+
field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, &'tcx [Projection<'tcx>])>,
172264
by_move_coroutine_ty: Ty<'tcx>,
173265
}
174266

@@ -183,24 +275,59 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
183275
context: mir::visit::PlaceContext,
184276
location: mir::Location,
185277
) {
278+
// Initializing an upvar local always starts with `CAPTURE_STRUCT_LOCAL` and a
279+
// field projection. If this is in `field_remapping`, then it must not be an
280+
// arg from calling the closure, but instead an upvar.
186281
if place.local == ty::CAPTURE_STRUCT_LOCAL
187-
&& let Some((&mir::ProjectionElem::Field(idx, ty), projection)) =
282+
&& let Some((&mir::ProjectionElem::Field(idx, _), projection)) =
188283
place.projection.split_first()
189-
&& self.by_ref_fields.contains(&idx)
284+
&& let Some(&(remapped_idx, remapped_ty, needs_deref, additional_projections)) =
285+
self.field_remapping.get(&idx)
190286
{
191-
let (begin, end) = projection.split_first().unwrap();
192-
// FIXME(async_closures): I'm actually a bit surprised to see that we always
193-
// initially deref the by-ref upvars. If this is not actually true, then we
194-
// will at least get an ICE that explains why this isn't true :^)
195-
assert_eq!(*begin, mir::ProjectionElem::Deref);
196-
// Peel one ref off of the ty.
197-
let peeled_ty = ty.builtin_deref(true).unwrap().ty;
287+
// As noted before, if the parent closure captures a field by value, and
288+
// the child captures a field by ref, then for the by-move body we're
289+
// generating, we also are taking that field by value. Peel off a deref,
290+
// since a layer of reffing has now become redundant.
291+
let final_deref = if needs_deref {
292+
let Some((mir::ProjectionElem::Deref, projection)) = projection.split_first()
293+
else {
294+
bug!(
295+
"There should be at least a single deref for an upvar local initialization, found {projection:#?}"
296+
);
297+
};
298+
// There may be more derefs, since we may also implicitly reborrow
299+
// a captured mut pointer.
300+
projection
301+
} else {
302+
projection
303+
};
304+
305+
// The only thing that should be left is a deref, if the parent captured
306+
// an upvar by-ref.
307+
std::assert_matches::assert_matches!(final_deref, [] | [mir::ProjectionElem::Deref]);
308+
309+
// For all of the additional projections that come out of precise capturing,
310+
// re-apply these projections.
311+
let additional_projections =
312+
additional_projections.iter().map(|elem| match elem.kind {
313+
ProjectionKind::Deref => mir::ProjectionElem::Deref,
314+
ProjectionKind::Field(idx, VariantIdx::ZERO) => {
315+
mir::ProjectionElem::Field(idx, elem.ty)
316+
}
317+
_ => unreachable!("precise captures only through fields and derefs"),
318+
});
319+
320+
// We start out with an adjusted field index (and ty), representing the
321+
// upvar that we get from our parent closure. We apply any of the additional
322+
// projections to make sure that to the rest of the body of the closure, the
323+
// place looks the same, and then apply that final deref if necessary.
198324
*place = mir::Place {
199325
local: place.local,
200326
projection: self.tcx.mk_place_elems_from_iter(
201-
[mir::ProjectionElem::Field(idx, peeled_ty)]
327+
[mir::ProjectionElem::Field(remapped_idx, remapped_ty)]
202328
.into_iter()
203-
.chain(end.iter().copied()),
329+
.chain(additional_projections)
330+
.chain(final_deref.iter().copied()),
204331
),
205332
};
206333
}
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//@ aux-build:block-on.rs
2+
//@ run-pass
3+
//@ check-run-results
4+
//@ revisions: e2021 e2018
5+
//@[e2018] edition:2018
6+
//@[e2021] edition:2021
7+
8+
#![feature(async_closure)]
9+
10+
extern crate block_on;
11+
12+
async fn call_once(f: impl async FnOnce()) { f().await; }
13+
14+
pub async fn async_closure(x: &mut i32) {
15+
let c = async move || {
16+
*x += 1;
17+
};
18+
call_once(c).await;
19+
}
20+
21+
fn main() {
22+
block_on::block_on(async {
23+
let mut x = 0;
24+
async_closure(&mut x).await;
25+
assert_eq!(x, 1);
26+
});
27+
}
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//@ aux-build:block-on.rs
2+
//@ edition:2021
3+
//@ run-pass
4+
//@ check-run-results
5+
6+
#![feature(async_closure)]
7+
8+
extern crate block_on;
9+
10+
async fn call_once(f: impl async FnOnce()) {
11+
f().await;
12+
}
13+
14+
async fn async_main() {
15+
let x = &mut 0;
16+
let y = &mut 0;
17+
let c = async || {
18+
*x = 1;
19+
*y = 2;
20+
};
21+
call_once(c).await;
22+
println!("{x} {y}");
23+
}
24+
25+
fn main() {
26+
block_on::block_on(async_main());
27+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1 2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
after call
2+
after await
3+
fixed
4+
uncaptured
5+
6+
after call
7+
after await
8+
fixed
9+
uncaptured
10+
11+
after call
12+
after await
13+
fixed
14+
uncaptured
15+
16+
after call
17+
after await
18+
fixed
19+
untouched
20+
21+
after call
22+
drop first
23+
after await
24+
uncaptured
25+
26+
after call
27+
drop first
28+
after await
29+
uncaptured
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
after call
2+
after await
3+
fixed
4+
uncaptured
5+
6+
after call
7+
after await
8+
fixed
9+
uncaptured
10+
11+
after call
12+
fixed
13+
after await
14+
uncaptured
15+
16+
after call
17+
after await
18+
fixed
19+
untouched
20+
21+
after call
22+
drop first
23+
after await
24+
uncaptured
25+
26+
after call
27+
drop first
28+
after await
29+
uncaptured

0 commit comments

Comments
 (0)