Skip to content

Commit

Permalink
Flatten aggregates into locals.
Browse files Browse the repository at this point in the history
  • Loading branch information
cjgillot committed Sep 3, 2021
1 parent ad3407f commit c1accaf
Show file tree
Hide file tree
Showing 27 changed files with 987 additions and 390 deletions.
211 changes: 211 additions & 0 deletions compiler/rustc_mir/src/transform/flatten_locals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
use crate::transform::MirPass;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_index::vec::{Idx, IndexVec};
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;
use std::collections::hash_map::Entry;

pub struct FlattenLocals;

impl<'tcx> MirPass<'tcx> for FlattenLocals {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
if tcx.sess.mir_opt_level() < 4 {
return;
}

let replacements = compute_flattening(tcx, body);
let mut all_dead_locals = FxHashSet::default();
all_dead_locals.extend(replacements.fields.keys().map(|p| p.local));
if all_dead_locals.is_empty() {
return;
}

ReplacementVisitor { tcx, map: &replacements, all_dead_locals }.visit_body(body);

let mut replaced_locals: IndexVec<_, _> = IndexVec::new();
for (k, v) in replacements.fields {
replaced_locals.ensure_contains_elem(k.local, || Vec::new());
replaced_locals[k.local].push(v)
}
// Sort locals to avoid depending on FxHashMap order.
for v in replaced_locals.iter_mut() {
v.sort_unstable()
}
for bbdata in body.basic_blocks_mut().iter_mut() {
bbdata.expand_statements(|stmt| {
let source_info = stmt.source_info;
let (live, origin_local) = match &stmt.kind {
StatementKind::StorageLive(l) => (true, *l),
StatementKind::StorageDead(l) => (false, *l),
_ => return None,
};
replaced_locals.get(origin_local).map(move |final_locals| {
final_locals.iter().map(move |&l| {
let kind = if live {
StatementKind::StorageLive(l)
} else {
StatementKind::StorageDead(l)
};
Statement { source_info, kind }
})
})
});
}
}
}

fn escaping_locals(body: &Body<'_>) -> FxHashSet<Local> {
let mut set: FxHashSet<_> = (0..body.arg_count + 1).map(Local::new).collect();
for (local, decl) in body.local_decls().iter_enumerated() {
if decl.ty.is_union() || decl.ty.is_enum() {
set.insert(local);
}
}
let mut visitor = EscapeVisitor { set };
visitor.visit_body(body);
return visitor.set;

struct EscapeVisitor {
set: FxHashSet<Local>,
}

impl Visitor<'_> for EscapeVisitor {
fn visit_local(&mut self, local: &Local, _: PlaceContext, _: Location) {
self.set.insert(*local);
}

fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
// Mirror the implementation in PreFlattenVisitor.
if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
return;
}
self.super_place(place, context, location);
}

fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
if !place.is_indirect() {
// Raw pointers may be used to access anything inside the enclosing place.
self.set.insert(place.local);
return;
}
}
self.super_rvalue(rvalue, location)
}

fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
if let StatementKind::StorageLive(..) | StatementKind::StorageDead(..) = statement.kind
{
// Storage statements are expanded in run_pass.
return;
}
self.super_statement(statement, location)
}

fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
if let TerminatorKind::Drop { place, .. }
| TerminatorKind::DropAndReplace { place, .. } = terminator.kind
{
if !place.is_indirect() {
// Raw pointers may be used to access anything inside the enclosing place.
self.set.insert(place.local);
return;
}
}
self.super_terminator(terminator, location);
}
}
}

#[derive(Default, Debug)]
struct ReplacementMap<'tcx> {
fields: FxHashMap<PlaceRef<'tcx>, Local>,
}

fn compute_flattening<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> ReplacementMap<'tcx> {
let escaping = escaping_locals(&*body);
let (basic_blocks, local_decls, var_debug_info) =
body.basic_blocks_local_decls_mut_and_var_debug_info();
let mut visitor =
PreFlattenVisitor { tcx, escaping, local_decls: local_decls, map: Default::default() };
for (block, bbdata) in basic_blocks.iter_enumerated() {
visitor.visit_basic_block_data(block, bbdata);
}
for var_debug_info in &*var_debug_info {
visitor.visit_var_debug_info(var_debug_info);
}
return visitor.map;

struct PreFlattenVisitor<'tcx, 'll> {
tcx: TyCtxt<'tcx>,
local_decls: &'ll mut LocalDecls<'tcx>,
escaping: FxHashSet<Local>,
map: ReplacementMap<'tcx>,
}

impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> {
fn create_place(&mut self, place: PlaceRef<'tcx>) {
if self.escaping.contains(&place.local) {
return;
}

match self.map.fields.entry(place) {
Entry::Occupied(_) => {}
Entry::Vacant(v) => {
let ty = place.ty(&*self.local_decls, self.tcx).ty;
let local = self.local_decls.push(LocalDecl {
ty,
user_ty: None,
..self.local_decls[place.local].clone()
});
v.insert(local);
}
}
}
}

impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> {
fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) {
if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
self.create_place(pr)
}
}
}
}

struct ReplacementVisitor<'tcx, 'll> {
tcx: TyCtxt<'tcx>,
map: &'ll ReplacementMap<'tcx>,
all_dead_locals: FxHashSet<Local>,
}

impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
if let StatementKind::StorageLive(..) | StatementKind::StorageDead(..) = statement.kind {
// Storage statements are expanded in run_pass.
return;
}
self.super_statement(statement, location)
}

fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
if let &[PlaceElem::Field(..), ref rest @ ..] = &place.projection[..] {
let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
if let Some(local) = self.map.fields.get(&pr) {
*place = Place { local: *local, projection: self.tcx.intern_place_elems(&rest) };
return;
}
}
self.super_place(place, context, location)
}

fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
assert!(!self.all_dead_locals.contains(local),);
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_mir/src/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub mod dest_prop;
pub mod dump_mir;
pub mod early_otherwise_branch;
pub mod elaborate_drops;
pub mod flatten_locals;
pub mod function_item_references;
pub mod generator;
pub mod inline;
Expand Down Expand Up @@ -499,6 +500,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// The main optimizations that we do on MIR.
let optimizations: &[&dyn MirPass<'tcx>] = &[
&remove_storage_markers::RemoveStorageMarkers,
&flatten_locals::FlattenLocals,
&remove_zsts::RemoveZsts,
&const_goto::ConstGoto,
&remove_unneeded_drops::RemoveUnneededDrops,
Expand Down
27 changes: 17 additions & 10 deletions src/test/mir-opt/const_prop/aggregate.main.ConstProp.diff
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,32 @@
let _1: i32; // in scope 0 at $DIR/aggregate.rs:5:9: 5:10
let mut _2: i32; // in scope 0 at $DIR/aggregate.rs:5:13: 5:24
let mut _3: (i32, i32, i32); // in scope 0 at $DIR/aggregate.rs:5:13: 5:22
let mut _4: i32; // in scope 0 at $DIR/aggregate.rs:5:13: 5:22
let mut _5: i32; // in scope 0 at $DIR/aggregate.rs:5:13: 5:22
let mut _6: i32; // in scope 0 at $DIR/aggregate.rs:5:13: 5:22
scope 1 {
debug x => _1; // in scope 1 at $DIR/aggregate.rs:5:9: 5:10
}

bb0: {
StorageLive(_1); // scope 0 at $DIR/aggregate.rs:5:9: 5:10
StorageLive(_2); // scope 0 at $DIR/aggregate.rs:5:13: 5:24
StorageLive(_3); // scope 0 at $DIR/aggregate.rs:5:13: 5:22
(_3.0: i32) = const 0_i32; // scope 0 at $DIR/aggregate.rs:5:13: 5:22
(_3.1: i32) = const 1_i32; // scope 0 at $DIR/aggregate.rs:5:13: 5:22
(_3.2: i32) = const 2_i32; // scope 0 at $DIR/aggregate.rs:5:13: 5:22
- _2 = (_3.1: i32); // scope 0 at $DIR/aggregate.rs:5:13: 5:24
nop; // scope 0 at $DIR/aggregate.rs:5:9: 5:10
nop; // scope 0 at $DIR/aggregate.rs:5:13: 5:24
StorageLive(_4); // scope 0 at $DIR/aggregate.rs:5:13: 5:22
StorageLive(_5); // scope 0 at $DIR/aggregate.rs:5:13: 5:22
StorageLive(_6); // scope 0 at $DIR/aggregate.rs:5:13: 5:22
_4 = const 0_i32; // scope 0 at $DIR/aggregate.rs:5:13: 5:22
_5 = const 1_i32; // scope 0 at $DIR/aggregate.rs:5:13: 5:22
_6 = const 2_i32; // scope 0 at $DIR/aggregate.rs:5:13: 5:22
- _2 = _5; // scope 0 at $DIR/aggregate.rs:5:13: 5:24
- _1 = Add(move _2, const 0_i32); // scope 0 at $DIR/aggregate.rs:5:13: 5:28
+ _2 = const 1_i32; // scope 0 at $DIR/aggregate.rs:5:13: 5:24
+ _1 = const 1_i32; // scope 0 at $DIR/aggregate.rs:5:13: 5:28
StorageDead(_2); // scope 0 at $DIR/aggregate.rs:5:27: 5:28
StorageDead(_3); // scope 0 at $DIR/aggregate.rs:5:28: 5:29
nop; // scope 0 at $DIR/aggregate.rs:5:27: 5:28
StorageDead(_4); // scope 0 at $DIR/aggregate.rs:5:28: 5:29
StorageDead(_5); // scope 0 at $DIR/aggregate.rs:5:28: 5:29
StorageDead(_6); // scope 0 at $DIR/aggregate.rs:5:28: 5:29
nop; // scope 0 at $DIR/aggregate.rs:4:11: 6:2
StorageDead(_1); // scope 0 at $DIR/aggregate.rs:6:1: 6:2
nop; // scope 0 at $DIR/aggregate.rs:6:1: 6:2
return; // scope 0 at $DIR/aggregate.rs:6:2: 6:2
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
let mut _6: usize; // in scope 0 at $DIR/optimizes_into_variable.rs:13:13: 13:34
let mut _7: bool; // in scope 0 at $DIR/optimizes_into_variable.rs:13:13: 13:34
let mut _9: Point; // in scope 0 at $DIR/optimizes_into_variable.rs:14:13: 14:36
let mut _10: u32; // in scope 0 at $DIR/optimizes_into_variable.rs:14:13: 14:36
let mut _11: u32; // in scope 0 at $DIR/optimizes_into_variable.rs:14:13: 14:36
scope 1 {
debug x => _1; // in scope 1 at $DIR/optimizes_into_variable.rs:12:9: 12:10
let _3: i32; // in scope 1 at $DIR/optimizes_into_variable.rs:13:9: 13:10
Expand All @@ -23,7 +25,7 @@
}

bb0: {
StorageLive(_1); // scope 0 at $DIR/optimizes_into_variable.rs:12:9: 12:10
nop; // scope 0 at $DIR/optimizes_into_variable.rs:12:9: 12:10
- _2 = CheckedAdd(const 2_i32, const 2_i32); // scope 0 at $DIR/optimizes_into_variable.rs:12:13: 12:18
- assert(!move (_2.1: bool), "attempt to compute `{} + {}`, which would overflow", const 2_i32, const 2_i32) -> bb1; // scope 0 at $DIR/optimizes_into_variable.rs:12:13: 12:18
+ _2 = const (4_i32, false); // scope 0 at $DIR/optimizes_into_variable.rs:12:13: 12:18
Expand All @@ -33,10 +35,10 @@
bb1: {
- _1 = move (_2.0: i32); // scope 0 at $DIR/optimizes_into_variable.rs:12:13: 12:18
+ _1 = const 4_i32; // scope 0 at $DIR/optimizes_into_variable.rs:12:13: 12:18
StorageLive(_3); // scope 1 at $DIR/optimizes_into_variable.rs:13:9: 13:10
StorageLive(_4); // scope 1 at $DIR/optimizes_into_variable.rs:13:13: 13:31
nop; // scope 1 at $DIR/optimizes_into_variable.rs:13:9: 13:10
nop; // scope 1 at $DIR/optimizes_into_variable.rs:13:13: 13:31
_4 = [const 0_i32, const 1_i32, const 2_i32, const 3_i32, const 4_i32, const 5_i32]; // scope 1 at $DIR/optimizes_into_variable.rs:13:13: 13:31
StorageLive(_5); // scope 1 at $DIR/optimizes_into_variable.rs:13:32: 13:33
nop; // scope 1 at $DIR/optimizes_into_variable.rs:13:32: 13:33
_5 = const 3_usize; // scope 1 at $DIR/optimizes_into_variable.rs:13:32: 13:33
_6 = const 6_usize; // scope 1 at $DIR/optimizes_into_variable.rs:13:13: 13:34
- _7 = Lt(_5, _6); // scope 1 at $DIR/optimizes_into_variable.rs:13:13: 13:34
Expand All @@ -48,19 +50,21 @@
bb2: {
- _3 = _4[_5]; // scope 1 at $DIR/optimizes_into_variable.rs:13:13: 13:34
+ _3 = const 3_i32; // scope 1 at $DIR/optimizes_into_variable.rs:13:13: 13:34
StorageDead(_5); // scope 1 at $DIR/optimizes_into_variable.rs:13:34: 13:35
StorageDead(_4); // scope 1 at $DIR/optimizes_into_variable.rs:13:34: 13:35
StorageLive(_8); // scope 2 at $DIR/optimizes_into_variable.rs:14:9: 14:10
StorageLive(_9); // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:36
(_9.0: u32) = const 12_u32; // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:36
(_9.1: u32) = const 42_u32; // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:36
- _8 = (_9.1: u32); // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:38
nop; // scope 1 at $DIR/optimizes_into_variable.rs:13:34: 13:35
nop; // scope 1 at $DIR/optimizes_into_variable.rs:13:34: 13:35
nop; // scope 2 at $DIR/optimizes_into_variable.rs:14:9: 14:10
StorageLive(_10); // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:36
StorageLive(_11); // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:36
_10 = const 12_u32; // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:36
_11 = const 42_u32; // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:36
- _8 = _11; // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:38
+ _8 = const 42_u32; // scope 2 at $DIR/optimizes_into_variable.rs:14:13: 14:38
StorageDead(_9); // scope 2 at $DIR/optimizes_into_variable.rs:14:38: 14:39
StorageDead(_10); // scope 2 at $DIR/optimizes_into_variable.rs:14:38: 14:39
StorageDead(_11); // scope 2 at $DIR/optimizes_into_variable.rs:14:38: 14:39
nop; // scope 0 at $DIR/optimizes_into_variable.rs:11:11: 15:2
StorageDead(_8); // scope 2 at $DIR/optimizes_into_variable.rs:15:1: 15:2
StorageDead(_3); // scope 1 at $DIR/optimizes_into_variable.rs:15:1: 15:2
StorageDead(_1); // scope 0 at $DIR/optimizes_into_variable.rs:15:1: 15:2
nop; // scope 2 at $DIR/optimizes_into_variable.rs:15:1: 15:2
nop; // scope 1 at $DIR/optimizes_into_variable.rs:15:1: 15:2
nop; // scope 0 at $DIR/optimizes_into_variable.rs:15:1: 15:2
return; // scope 0 at $DIR/optimizes_into_variable.rs:15:2: 15:2
}
}
Expand Down
Loading

0 comments on commit c1accaf

Please sign in to comment.