Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions compiler/rustc_mir_transform/src/patch.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use rustc_index::{Idx, IndexVec};
use rustc_data_structures::fx::FxHashMap;
use rustc_index::Idx;
use rustc_middle::mir::*;
use rustc_middle::ty::Ty;
use rustc_span::Span;
Expand All @@ -9,7 +10,7 @@ use tracing::debug;
/// and replacement of terminators, and then apply the queued changes all at
/// once with `apply`. This is useful for MIR transformation passes.
pub(crate) struct MirPatch<'tcx> {
term_patch_map: IndexVec<BasicBlock, Option<TerminatorKind<'tcx>>>,
term_patch_map: FxHashMap<BasicBlock, TerminatorKind<'tcx>>,
new_blocks: Vec<BasicBlockData<'tcx>>,
new_statements: Vec<(Location, StatementKind<'tcx>)>,
new_locals: Vec<LocalDecl<'tcx>>,
Expand All @@ -22,17 +23,21 @@ pub(crate) struct MirPatch<'tcx> {
terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
body_span: Span,
next_local: usize,
/// The number of blocks at the start of the transformation. New blocks
/// get appended at the end.
next_block: usize,
}

impl<'tcx> MirPatch<'tcx> {
/// Creates a new, empty patch.
pub(crate) fn new(body: &Body<'tcx>) -> Self {
let mut result = MirPatch {
term_patch_map: IndexVec::from_elem(None, &body.basic_blocks),
term_patch_map: Default::default(),
new_blocks: vec![],
new_statements: vec![],
new_locals: vec![],
next_local: body.local_decls.len(),
next_block: body.basic_blocks.len(),
resume_block: None,
unreachable_cleanup_block: None,
unreachable_no_cleanup_block: None,
Expand Down Expand Up @@ -141,7 +146,7 @@ impl<'tcx> MirPatch<'tcx> {

/// Has a replacement of this block's terminator been queued in this patch?
pub(crate) fn is_term_patched(&self, bb: BasicBlock) -> bool {
self.term_patch_map[bb].is_some()
self.term_patch_map.contains_key(&bb)
}

/// Universal getter for block data, either it is in 'old' blocks or in patched ones
Expand Down Expand Up @@ -194,18 +199,17 @@ impl<'tcx> MirPatch<'tcx> {

/// Queues the addition of a new basic block.
pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
let block = self.term_patch_map.next_index();
let block = BasicBlock::from_usize(self.next_block + self.new_blocks.len());
debug!("MirPatch: new_block: {:?}: {:?}", block, data);
self.new_blocks.push(data);
self.term_patch_map.push(None);
block
}

/// Queues the replacement of a block's terminator.
pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
assert!(self.term_patch_map[block].is_none());
assert!(!self.term_patch_map.contains_key(&block));
debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
self.term_patch_map[block] = Some(new);
self.term_patch_map.insert(block, new);
}

/// Queues the insertion of a statement at a given location. The statement
Expand Down Expand Up @@ -244,18 +248,20 @@ impl<'tcx> MirPatch<'tcx> {
self.new_blocks.len(),
body.basic_blocks.len()
);
debug_assert_eq!(self.next_block, body.basic_blocks.len());
let bbs = if self.term_patch_map.is_empty() && self.new_blocks.is_empty() {
body.basic_blocks.as_mut_preserves_cfg()
} else {
body.basic_blocks.as_mut()
};
bbs.extend(self.new_blocks);
body.local_decls.extend(self.new_locals);
for (src, patch) in self.term_patch_map.into_iter_enumerated() {
if let Some(patch) = patch {
debug!("MirPatch: patching block {:?}", src);
bbs[src].terminator_mut().kind = patch;
}

// The order in which we patch terminators does not change the result.
#[allow(rustc::potential_query_instability)]
for (src, patch) in self.term_patch_map {
debug!("MirPatch: patching block {:?}", src);
bbs[src].terminator_mut().kind = patch;
}

let mut new_statements = self.new_statements;
Expand All @@ -273,8 +279,8 @@ impl<'tcx> MirPatch<'tcx> {
}
debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
loc.statement_index += delta;
let source_info = Self::source_info_for_index(&body[loc.block], loc);
body[loc.block]
let source_info = Self::source_info_for_index(&bbs[loc.block], loc);
bbs[loc.block]
.statements
.insert(loc.statement_index, Statement::new(source_info, stmt));
delta += 1;
Expand Down
Loading