Skip to content

Commit 0af63de

Browse files
committed
WIP: replace upvar locals with generator state field accesses
1 parent 833afb0 commit 0af63de

File tree

1 file changed

+52
-12
lines changed

1 file changed

+52
-12
lines changed

compiler/rustc_mir_transform/src/generator.rs

+52-12
Original file line numberDiff line numberDiff line change
@@ -1296,10 +1296,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
12961296
}
12971297
let upvar_locals = upvar_collector.finish();
12981298
tracing::info!("Upvar locals: {:?}", upvar_locals);
1299-
tracing::info!("Upvar count: {:?}", upvars.len());
1300-
if upvar_locals.len() != upvars.len() {
1301-
eprintln!("{:#?}", body);
1302-
assert_eq!(upvar_locals.len(), upvars.len());
1299+
tracing::info!("Expected upvar count: {:?}", upvars.len());
1300+
1301+
let mut replacer = ReplaceLocalWithGeneratorFieldAccess { tcx, upvar_locals };
1302+
for (block, data) in body.basic_blocks_mut().iter_enumerated_mut() {
1303+
replacer.visit_basic_block_data(block, data);
13031304
}
13041305

13051306
let first_block = &mut body.basic_blocks_mut()[BasicBlock::new(0)];
@@ -1388,23 +1389,26 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
13881389

13891390
/// Finds locals that are assigned from generator upvars.
13901391
#[derive(Default)]
1391-
struct ExtractGeneratorUpvarLocals {
1392-
upvar_locals: FxHashMap<Field, Vec<Local>>,
1392+
struct ExtractGeneratorUpvarLocals<'tcx> {
1393+
upvar_locals: FxHashMap<Local, Rvalue<'tcx>>,
13931394
}
13941395

1395-
impl ExtractGeneratorUpvarLocals {
1396-
fn finish(self) -> FxHashMap<Field, Vec<Local>> {
1396+
impl<'tcx> ExtractGeneratorUpvarLocals<'tcx> {
1397+
fn finish(self) -> FxHashMap<Local, Rvalue<'tcx>> {
13971398
self.upvar_locals
13981399
}
13991400
}
14001401

1401-
impl<'tcx> Visitor<'tcx> for ExtractGeneratorUpvarLocals {
1402+
impl<'tcx> Visitor<'tcx> for ExtractGeneratorUpvarLocals<'tcx> {
14021403
fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) {
14031404
let mut visitor = FindGeneratorFieldAccess { field_index: None };
14041405
visitor.visit_rvalue(rvalue, location);
14051406

1406-
if let Some(field_index) = visitor.field_index {
1407-
self.upvar_locals.entry(field_index).or_insert_with(|| vec![]).push(place.local);
1407+
if let Some(_) = visitor.field_index {
1408+
if !place.projection.is_empty() {
1409+
panic!("Non-empty projectsion: {place:#?}");
1410+
}
1411+
self.upvar_locals.insert(place.local, rvalue.clone());
14081412
}
14091413
}
14101414
}
@@ -1420,7 +1424,7 @@ impl<'tcx> Visitor<'tcx> for FindGeneratorFieldAccess {
14201424
_context: PlaceContext,
14211425
_location: Location,
14221426
) {
1423-
tracing::info!("visit_projection, place_ref={:#?}", place_ref);
1427+
tracing::info!("visit_projection, place_ref={place_ref:#?}");
14241428

14251429
if place_ref.local.as_usize() == 1 {
14261430
if !place_ref.projection.is_empty() {
@@ -1433,6 +1437,42 @@ impl<'tcx> Visitor<'tcx> for FindGeneratorFieldAccess {
14331437
}
14341438
}
14351439

1440+
struct ReplaceLocalWithGeneratorFieldAccess<'tcx> {
1441+
tcx: TyCtxt<'tcx>,
1442+
upvar_locals: FxHashMap<Local, Rvalue<'tcx>>,
1443+
}
1444+
1445+
impl<'tcx> MutVisitor<'tcx> for ReplaceLocalWithGeneratorFieldAccess<'tcx> {
1446+
fn tcx(&self) -> TyCtxt<'tcx> {
1447+
self.tcx
1448+
}
1449+
1450+
fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
1451+
for (statement_index, statement) in data.statements.iter_mut().enumerate() {
1452+
if let StatementKind::Assign(box (place, _rvalue)) = &statement.kind {
1453+
// Upvar was stored into a local => turn into nop
1454+
if self.upvar_locals.contains_key(&place.local) {
1455+
*statement = Statement { source_info: statement.source_info, kind: StatementKind::Nop };
1456+
}
1457+
}
1458+
self.visit_statement(statement, Location { block, statement_index });
1459+
}
1460+
}
1461+
1462+
fn visit_place(&mut self, source: &mut Place<'tcx>, _context: PlaceContext, _location: Location) {
1463+
if let Some(rvalue) = self.upvar_locals.get(&source.local) {
1464+
match rvalue {
1465+
Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => {
1466+
tracing::info!("Replacing {source:#?} with {place:#?}");
1467+
*source = *place;
1468+
}
1469+
_ => {}
1470+
}
1471+
}
1472+
}
1473+
}
1474+
1475+
14361476
/// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields
14371477
/// in the generator state machine but whose storage is not marked as conflicting
14381478
///

0 commit comments

Comments
 (0)