@@ -1296,10 +1296,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1296
1296
}
1297
1297
let upvar_locals = upvar_collector. finish ( ) ;
1298
1298
tracing:: info!( "Upvar locals: {:?}" , upvar_locals) ;
1299
- tracing:: info!( "Upvar count: {:?}" , upvars. len( ) ) ;
1300
- if upvar_locals. len ( ) != upvars. len ( ) {
1301
- eprintln ! ( "{:#?}" , body) ;
1302
- assert_eq ! ( upvar_locals. len( ) , upvars. len( ) ) ;
1299
+ tracing:: info!( "Expected upvar count: {:?}" , upvars. len( ) ) ;
1300
+
1301
+ let mut replacer = ReplaceLocalWithGeneratorFieldAccess { tcx, upvar_locals } ;
1302
+ for ( block, data) in body. basic_blocks_mut ( ) . iter_enumerated_mut ( ) {
1303
+ replacer. visit_basic_block_data ( block, data) ;
1303
1304
}
1304
1305
1305
1306
let first_block = & mut body. basic_blocks_mut ( ) [ BasicBlock :: new ( 0 ) ] ;
@@ -1388,23 +1389,26 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
1388
1389
1389
1390
/// Finds locals that are assigned from generator upvars.
1390
1391
#[ derive( Default ) ]
1391
- struct ExtractGeneratorUpvarLocals {
1392
- upvar_locals : FxHashMap < Field , Vec < Local > > ,
1392
+ struct ExtractGeneratorUpvarLocals < ' tcx > {
1393
+ upvar_locals : FxHashMap < Local , Rvalue < ' tcx > > ,
1393
1394
}
1394
1395
1395
- impl ExtractGeneratorUpvarLocals {
1396
- fn finish ( self ) -> FxHashMap < Field , Vec < Local > > {
1396
+ impl < ' tcx > ExtractGeneratorUpvarLocals < ' tcx > {
1397
+ fn finish ( self ) -> FxHashMap < Local , Rvalue < ' tcx > > {
1397
1398
self . upvar_locals
1398
1399
}
1399
1400
}
1400
1401
1401
- impl < ' tcx > Visitor < ' tcx > for ExtractGeneratorUpvarLocals {
1402
+ impl < ' tcx > Visitor < ' tcx > for ExtractGeneratorUpvarLocals < ' tcx > {
1402
1403
fn visit_assign ( & mut self , place : & Place < ' tcx > , rvalue : & Rvalue < ' tcx > , location : Location ) {
1403
1404
let mut visitor = FindGeneratorFieldAccess { field_index : None } ;
1404
1405
visitor. visit_rvalue ( rvalue, location) ;
1405
1406
1406
- if let Some ( field_index) = visitor. field_index {
1407
- self . upvar_locals . entry ( field_index) . or_insert_with ( || vec ! [ ] ) . push ( place. local ) ;
1407
+ if let Some ( _) = visitor. field_index {
1408
+ if !place. projection . is_empty ( ) {
1409
+ panic ! ( "Non-empty projectsion: {place:#?}" ) ;
1410
+ }
1411
+ self . upvar_locals . insert ( place. local , rvalue. clone ( ) ) ;
1408
1412
}
1409
1413
}
1410
1414
}
@@ -1420,7 +1424,7 @@ impl<'tcx> Visitor<'tcx> for FindGeneratorFieldAccess {
1420
1424
_context : PlaceContext ,
1421
1425
_location : Location ,
1422
1426
) {
1423
- tracing:: info!( "visit_projection, place_ref={:#?}" , place_ref ) ;
1427
+ tracing:: info!( "visit_projection, place_ref={place_ref :#?}" ) ;
1424
1428
1425
1429
if place_ref. local . as_usize ( ) == 1 {
1426
1430
if !place_ref. projection . is_empty ( ) {
@@ -1433,6 +1437,42 @@ impl<'tcx> Visitor<'tcx> for FindGeneratorFieldAccess {
1433
1437
}
1434
1438
}
1435
1439
1440
+ struct ReplaceLocalWithGeneratorFieldAccess < ' tcx > {
1441
+ tcx : TyCtxt < ' tcx > ,
1442
+ upvar_locals : FxHashMap < Local , Rvalue < ' tcx > > ,
1443
+ }
1444
+
1445
+ impl < ' tcx > MutVisitor < ' tcx > for ReplaceLocalWithGeneratorFieldAccess < ' tcx > {
1446
+ fn tcx ( & self ) -> TyCtxt < ' tcx > {
1447
+ self . tcx
1448
+ }
1449
+
1450
+ fn visit_basic_block_data ( & mut self , block : BasicBlock , data : & mut BasicBlockData < ' tcx > ) {
1451
+ for ( statement_index, statement) in data. statements . iter_mut ( ) . enumerate ( ) {
1452
+ if let StatementKind :: Assign ( box ( place, _rvalue) ) = & statement. kind {
1453
+ // Upvar was stored into a local => turn into nop
1454
+ if self . upvar_locals . contains_key ( & place. local ) {
1455
+ * statement = Statement { source_info : statement. source_info , kind : StatementKind :: Nop } ;
1456
+ }
1457
+ }
1458
+ self . visit_statement ( statement, Location { block, statement_index } ) ;
1459
+ }
1460
+ }
1461
+
1462
+ fn visit_place ( & mut self , source : & mut Place < ' tcx > , _context : PlaceContext , _location : Location ) {
1463
+ if let Some ( rvalue) = self . upvar_locals . get ( & source. local ) {
1464
+ match rvalue {
1465
+ Rvalue :: Use ( Operand :: Copy ( place) | Operand :: Move ( place) ) => {
1466
+ tracing:: info!( "Replacing {source:#?} with {place:#?}" ) ;
1467
+ * source = * place;
1468
+ }
1469
+ _ => { }
1470
+ }
1471
+ }
1472
+ }
1473
+ }
1474
+
1475
+
1436
1476
/// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields
1437
1477
/// in the generator state machine but whose storage is not marked as conflicting
1438
1478
///
0 commit comments