Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve MIR modification #137203

Merged
merged 5 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 1 addition & 61 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

use std::borrow::Cow;
use std::fmt::{self, Debug, Formatter};
use std::iter;
use std::ops::{Index, IndexMut};
use std::{iter, mem};

pub use basic_blocks::BasicBlocks;
use either::Either;
Expand Down Expand Up @@ -1365,66 +1365,6 @@ impl<'tcx> BasicBlockData<'tcx> {
self.terminator.as_mut().expect("invalid terminator state")
}

pub fn retain_statements<F>(&mut self, mut f: F)
where
F: FnMut(&mut Statement<'_>) -> bool,
{
for s in &mut self.statements {
if !f(s) {
s.make_nop();
}
}
}

pub fn expand_statements<F, I>(&mut self, mut f: F)
where
F: FnMut(&mut Statement<'tcx>) -> Option<I>,
I: iter::TrustedLen<Item = Statement<'tcx>>,
{
// Gather all the iterators we'll need to splice in, and their positions.
let mut splices: Vec<(usize, I)> = vec![];
let mut extra_stmts = 0;
for (i, s) in self.statements.iter_mut().enumerate() {
if let Some(mut new_stmts) = f(s) {
if let Some(first) = new_stmts.next() {
// We can already store the first new statement.
*s = first;

// Save the other statements for optimized splicing.
let remaining = new_stmts.size_hint().0;
if remaining > 0 {
splices.push((i + 1 + extra_stmts, new_stmts));
extra_stmts += remaining;
}
} else {
s.make_nop();
}
}
}

// Splice in the new statements, from the end of the block.
// FIXME(eddyb) This could be more efficient with a "gap buffer"
// where a range of elements ("gap") is left uninitialized, with
// splicing adding new elements to the end of that gap and moving
// existing elements from before the gap to the end of the gap.
// For now, this is safe code, emulating a gap but initializing it.
let mut gap = self.statements.len()..self.statements.len() + extra_stmts;
self.statements.resize(
gap.end,
Statement { source_info: SourceInfo::outermost(DUMMY_SP), kind: StatementKind::Nop },
);
for (splice_start, new_stmts) in splices.into_iter().rev() {
let splice_end = splice_start + new_stmts.size_hint().0;
while gap.end > splice_end {
gap.start -= 1;
gap.end -= 1;
self.statements.swap(gap.start, gap.end);
}
self.statements.splice(splice_start..splice_end, new_stmts);
gap.end = splice_start;
}
}

pub fn visitable(&self, index: usize) -> &dyn MirVisitable<'tcx> {
if index < self.statements.len() { &self.statements[index] } else { &self.terminator }
}
Expand Down
9 changes: 0 additions & 9 deletions compiler/rustc_middle/src/mir/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,6 @@ impl Statement<'_> {
pub fn make_nop(&mut self) {
self.kind = StatementKind::Nop
}

/// Changes a statement to a nop and returns the original statement.
#[must_use = "If you don't need the statement, use `make_nop` instead"]
pub fn replace_nop(&mut self) -> Self {
Statement {
source_info: self.source_info,
kind: mem::replace(&mut self.kind, StatementKind::Nop),
}
}
}

impl<'tcx> StatementKind<'tcx> {
Expand Down
11 changes: 6 additions & 5 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,13 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {

fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
// Remove StorageLive and StorageDead statements for remapped locals
data.retain_statements(|s| match s.kind {
StatementKind::StorageLive(l) | StatementKind::StorageDead(l) => {
!self.remap.contains(l)
for s in &mut data.statements {
if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = s.kind
&& self.remap.contains(l)
{
s.make_nop();
}
_ => true,
});
}

let ret_val = match data.terminator().kind {
TerminatorKind::Return => {
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_mir_transform/src/elaborate_drops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ impl<'a, 'tcx> ElaborateDropsCtxt<'a, 'tcx> {
..
} = data.terminator().kind
{
assert!(!self.patch.is_patched(bb));
assert!(!self.patch.is_term_patched(bb));

let loc = Location { block: tgt, statement_index: 0 };
let path = self.move_data().rev_lookup.find(destination.as_ref());
Expand Down Expand Up @@ -462,7 +462,7 @@ impl<'a, 'tcx> ElaborateDropsCtxt<'a, 'tcx> {
// a Goto; see `MirPatch::new`).
}
_ => {
assert!(!self.patch.is_patched(bb));
assert!(!self.patch.is_term_patched(bb));
}
}
}
Expand All @@ -486,7 +486,7 @@ impl<'a, 'tcx> ElaborateDropsCtxt<'a, 'tcx> {
..
} = data.terminator().kind
{
assert!(!self.patch.is_patched(bb));
assert!(!self.patch.is_term_patched(bb));

let loc = Location { block: bb, statement_index: data.statements.len() };
let path = self.move_data().rev_lookup.find(destination.as_ref());
Expand Down
186 changes: 74 additions & 112 deletions compiler/rustc_mir_transform/src/large_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use rustc_middle::ty::util::IntTypeExt;
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
use rustc_session::Session;

use crate::patch::MirPatch;

/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
/// enough discrepancy between them.
///
Expand Down Expand Up @@ -41,31 +43,34 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
let mut alloc_cache = FxHashMap::default();
let typing_env = body.typing_env(tcx);

let blocks = body.basic_blocks.as_mut();
let local_decls = &mut body.local_decls;
let mut patch = MirPatch::new(body);

for bb in blocks {
bb.expand_statements(|st| {
for (block, data) in body.basic_blocks.as_mut().iter_enumerated_mut() {
for (statement_index, st) in data.statements.iter_mut().enumerate() {
let StatementKind::Assign(box (
lhs,
Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
)) = &st.kind
else {
return None;
continue;
};

let ty = lhs.ty(local_decls, tcx).ty;
let location = Location { block, statement_index };

let (adt_def, num_variants, alloc_id) =
self.candidate(tcx, typing_env, ty, &mut alloc_cache)?;
let ty = lhs.ty(&body.local_decls, tcx).ty;

let source_info = st.source_info;
let span = source_info.span;
let Some((adt_def, num_variants, alloc_id)) =
self.candidate(tcx, typing_env, ty, &mut alloc_cache)
else {
continue;
};

let span = st.source_info.span;

let tmp_ty = Ty::new_array(tcx, tcx.types.usize, num_variants as u64);
let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
let store_live =
Statement { source_info, kind: StatementKind::StorageLive(size_array_local) };
let size_array_local = patch.new_temp(tmp_ty, span);

let store_live = StatementKind::StorageLive(size_array_local);

let place = Place::from(size_array_local);
let constant_vals = ConstOperand {
Expand All @@ -77,108 +82,63 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
),
};
let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)));
let const_assign =
Statement { source_info, kind: StatementKind::Assign(Box::new((place, rval))) };

let discr_place = Place::from(
local_decls.push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
);
let store_discr = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
discr_place,
Rvalue::Discriminant(*rhs),
))),
};

let discr_cast_place =
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
let cast_discr = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
discr_cast_place,
Rvalue::Cast(
CastKind::IntToInt,
Operand::Copy(discr_place),
tcx.types.usize,
),
))),
};

let size_place =
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
let store_size = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
size_place,
Rvalue::Use(Operand::Copy(Place {
local: size_array_local,
projection: tcx
.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
})),
))),
};

let dst =
Place::from(local_decls.push(LocalDecl::new(Ty::new_mut_ptr(tcx, ty), span)));
let dst_ptr = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
dst,
Rvalue::RawPtr(RawPtrKind::Mut, *lhs),
))),
};
let const_assign = StatementKind::Assign(Box::new((place, rval)));

let discr_place =
Place::from(patch.new_temp(adt_def.repr().discr_type().to_ty(tcx), span));
let store_discr =
StatementKind::Assign(Box::new((discr_place, Rvalue::Discriminant(*rhs))));

let discr_cast_place = Place::from(patch.new_temp(tcx.types.usize, span));
let cast_discr = StatementKind::Assign(Box::new((
discr_cast_place,
Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_place), tcx.types.usize),
)));

let size_place = Place::from(patch.new_temp(tcx.types.usize, span));
let store_size = StatementKind::Assign(Box::new((
size_place,
Rvalue::Use(Operand::Copy(Place {
local: size_array_local,
projection: tcx.mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
})),
)));

let dst = Place::from(patch.new_temp(Ty::new_mut_ptr(tcx, ty), span));
let dst_ptr =
StatementKind::Assign(Box::new((dst, Rvalue::RawPtr(RawPtrKind::Mut, *lhs))));

let dst_cast_ty = Ty::new_mut_ptr(tcx, tcx.types.u8);
let dst_cast_place =
Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
let dst_cast = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
dst_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
))),
};
let dst_cast_place = Place::from(patch.new_temp(dst_cast_ty, span));
let dst_cast = StatementKind::Assign(Box::new((
dst_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
)));

let src =
Place::from(local_decls.push(LocalDecl::new(Ty::new_imm_ptr(tcx, ty), span)));
let src_ptr = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
src,
Rvalue::RawPtr(RawPtrKind::Const, *rhs),
))),
};
let src = Place::from(patch.new_temp(Ty::new_imm_ptr(tcx, ty), span));
let src_ptr =
StatementKind::Assign(Box::new((src, Rvalue::RawPtr(RawPtrKind::Const, *rhs))));

let src_cast_ty = Ty::new_imm_ptr(tcx, tcx.types.u8);
let src_cast_place =
Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
let src_cast = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
src_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
))),
};
let src_cast_place = Place::from(patch.new_temp(src_cast_ty, span));
let src_cast = StatementKind::Assign(Box::new((
src_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
)));

let deinit_old =
Statement { source_info, kind: StatementKind::Deinit(Box::new(dst)) };

let copy_bytes = Statement {
source_info,
kind: StatementKind::Intrinsic(Box::new(
NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
src: Operand::Copy(src_cast_place),
dst: Operand::Copy(dst_cast_place),
count: Operand::Copy(size_place),
}),
)),
};
let deinit_old = StatementKind::Deinit(Box::new(dst));

let copy_bytes = StatementKind::Intrinsic(Box::new(
NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
src: Operand::Copy(src_cast_place),
dst: Operand::Copy(dst_cast_place),
count: Operand::Copy(size_place),
}),
));

let store_dead =
Statement { source_info, kind: StatementKind::StorageDead(size_array_local) };
let store_dead = StatementKind::StorageDead(size_array_local);

let iter = [
let stmts = [
store_live,
const_assign,
store_discr,
Expand All @@ -191,14 +151,16 @@ impl<'tcx> crate::MirPass<'tcx> for EnumSizeOpt {
deinit_old,
copy_bytes,
store_dead,
]
.into_iter();
];
for stmt in stmts {
patch.add_statement(location, stmt);
}

st.make_nop();

Some(iter)
});
}
}

patch.apply(body);
}

fn is_required(&self) -> bool {
Expand Down
Loading
Loading