Skip to content

Mir-Opt for copying enums with large discrepancies #85158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 298 additions & 0 deletions compiler/rustc_mir_transform/src/large_enums.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
use crate::rustc_middle::ty::util::IntTypeExt;
use crate::MirPass;
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::mir::interpret::AllocId;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, AdtDef, Const, ParamEnv, Ty, TyCtxt};
use rustc_session::Session;
use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants};

/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
/// enough discrepancy between the sizes of their variants.
///
/// For example, given an enum with two variants:
/// ```
/// enum Example {
///     Small,
///     Large([u32; 1024]),
/// }
/// ```
/// instead of emitting a move sized for the large variant,
/// perform a memcpy of only the active variant's bytes.
/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
///
/// In summary, what this does is at runtime determine which enum variant is active,
/// and instead of copying all the bytes of the largest possible variant,
/// copy only the bytes for the currently active variant.
pub struct EnumSizeOpt {
/// Minimum difference, in bytes, between the largest and the smallest variant layout
/// for an enum to be considered by this pass.
pub(crate) discrepancy: u64,
}

impl<'tcx> MirPass<'tcx> for EnumSizeOpt {
    /// The pass runs when unsound MIR optimizations were requested explicitly, or when
    /// the MIR optimization level is at least 3.
    fn is_enabled(&self, sess: &Session) -> bool {
        if sess.opts.unstable_opts.unsound_mir_opts {
            return true;
        }
        sess.mir_opt_level() >= 3
    }

    /// Entry point of the pass: forwards to `EnumSizeOpt::optim`, which rewrites the
    /// statements of `body` in place.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        // NOTE: The MIR produced here may differ depending on the alignment of the target
        // platform; it is still valid MIR in every case.
        self.optim(tcx, body);
    }
}

impl EnumSizeOpt {
/// Decides whether copies of `ty` should be rewritten by this pass and, if so, returns
/// the data needed to do it.
///
/// A type qualifies when all of the following hold:
/// * it is an enum ADT whose layout can be computed,
/// * its layout has more than one variant and does not use a niche tag encoding,
/// * the largest and the smallest variant differ by at least `self.discrepancy` bytes,
/// * every discriminant value is smaller than the number of discriminants, so that it
///   can be used directly as an index into the size table built below.
///
/// # Returns
///
/// `Some((adt_def, num_discrs, alloc_id))`, where `alloc_id` names an immutable
/// allocation holding `num_discrs` pointer-sized integers in target endianness; the
/// entry at index `d` is the size in bytes of the variant whose discriminant is `d`.
/// Allocations are memoized per type in `alloc_cache`. Returns `None` when the type
/// does not qualify.
fn candidate<'tcx>(
    &self,
    tcx: TyCtxt<'tcx>,
    param_env: ParamEnv<'tcx>,
    ty: Ty<'tcx>,
    alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
    let adt_def = match ty.kind() {
        ty::Adt(adt_def, _substs) if adt_def.is_enum() => adt_def,
        _ => return None,
    };
    let layout = tcx.layout_of(param_env.and(ty)).ok()?;
    let variants = match &layout.variants {
        Variants::Single { .. } => return None,
        Variants::Multiple { tag_encoding: TagEncoding::Niche { .. }, .. } => return None,
        Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
        Variants::Multiple { variants, .. } => variants,
    };

    // `variants` is known to be non-empty here, so `min`/`max` cannot fail.
    let min = variants.iter().map(|v| v.size).min().unwrap();
    let max = variants.iter().map(|v| v.size).max().unwrap();
    if max.bytes() - min.bytes() < self.discrepancy {
        return None;
    }

    // The size table has one slot per discriminant and is indexed by the discriminant
    // value itself, so bail out if any discriminant would fall outside of the table.
    let num_discrs = adt_def.discriminants(tcx).count();
    if variants.iter_enumerated().any(|(var_idx, _)| {
        let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
        (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
    }) {
        return None;
    }
    if let Some(alloc_id) = alloc_cache.get(&ty) {
        return Some((*adt_def, num_discrs, *alloc_id));
    }

    let data_layout = tcx.data_layout();
    let ptr_sized_int = data_layout.ptr_sized_integer();
    let target_bytes = ptr_sized_int.size().bytes() as usize;
    let mut data = vec![0u8; target_bytes * num_discrs];

    for (var_idx, variant_layout) in variants.iter_enumerated() {
        // In range thanks to the discriminant check above.
        let start = target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
        let slot = &mut data[start..start + target_bytes];
        let size = variant_layout.size.bytes();
        // Store the low `target_bytes` bytes of the size as a target `usize`. Slicing the
        // 64-bit encoding works for every pointer width (16, 32 and 64 bit) instead of
        // panicking on targets whose pointers are neither 32 nor 64 bits wide.
        match data_layout.endian {
            rustc_target::abi::Endian::Little => {
                slot.copy_from_slice(&size.to_le_bytes()[..target_bytes]);
            }
            rustc_target::abi::Endian::Big => {
                let be = size.to_be_bytes();
                slot.copy_from_slice(&be[be.len() - target_bytes..]);
            }
        }
    }

    let alloc = interpret::Allocation::from_bytes(
        data,
        ptr_sized_int.align(data_layout).abi,
        Mutability::Not,
    );
    let alloc = tcx.create_memory_alloc(tcx.intern_const_alloc(alloc));
    Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
}
fn optim<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let mut alloc_cache = FxHashMap::default();
let body_did = body.source.def_id();
let param_env = tcx.param_env(body_did);

let blocks = body.basic_blocks.as_mut();
let local_decls = &mut body.local_decls;

for bb in blocks {
bb.expand_statements(|st| {
if let StatementKind::Assign(box (
lhs,
Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
)) = &st.kind
{
let ty = lhs.ty(local_decls, tcx).ty;

let source_info = st.source_info;
let span = source_info.span;

let (adt_def, num_variants, alloc_id) =
self.candidate(tcx, param_env, ty, &mut alloc_cache)?;
let alloc = tcx.global_alloc(alloc_id).unwrap_memory();

let tmp_ty = tcx.mk_ty(ty::Array(
tcx.types.usize,
Const::from_usize(tcx, num_variants as u64),
));

let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
let store_live = Statement {
source_info,
kind: StatementKind::StorageLive(size_array_local),
};

let place = Place::from(size_array_local);
let constant_vals = Constant {
span,
user_ty: None,
literal: ConstantKind::Val(
interpret::ConstValue::ByRef { alloc, offset: Size::ZERO },
tmp_ty,
),
};
let rval = Rvalue::Use(Operand::Constant(box (constant_vals)));

let const_assign =
Statement { source_info, kind: StatementKind::Assign(box (place, rval)) };

let discr_place = Place::from(
local_decls
.push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
);

let store_discr = Statement {
source_info,
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(*rhs))),
};

let discr_cast_place =
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));

let cast_discr = Statement {
source_info,
kind: StatementKind::Assign(box (
discr_cast_place,
Rvalue::Cast(
CastKind::IntToInt,
Operand::Copy(discr_place),
tcx.types.usize,
),
)),
};

let size_place =
Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));

let store_size = Statement {
source_info,
kind: StatementKind::Assign(box (
size_place,
Rvalue::Use(Operand::Copy(Place {
local: size_array_local,
projection: tcx.intern_place_elems(&[PlaceElem::Index(
discr_cast_place.local,
)]),
})),
)),
};

let dst =
Place::from(local_decls.push(LocalDecl::new(tcx.mk_mut_ptr(ty), span)));

let dst_ptr = Statement {
source_info,
kind: StatementKind::Assign(box (
dst,
Rvalue::AddressOf(Mutability::Mut, *lhs),
)),
};

let dst_cast_ty = tcx.mk_mut_ptr(tcx.types.u8);
let dst_cast_place =
Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));

let dst_cast = Statement {
source_info,
kind: StatementKind::Assign(box (
dst_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
)),
};

let src =
Place::from(local_decls.push(LocalDecl::new(tcx.mk_imm_ptr(ty), span)));

let src_ptr = Statement {
source_info,
kind: StatementKind::Assign(box (
src,
Rvalue::AddressOf(Mutability::Not, *rhs),
)),
};

let src_cast_ty = tcx.mk_imm_ptr(tcx.types.u8);
let src_cast_place =
Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));

let src_cast = Statement {
source_info,
kind: StatementKind::Assign(box (
src_cast_place,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
)),
};

let deinit_old =
Statement { source_info, kind: StatementKind::Deinit(box dst) };

let copy_bytes = Statement {
source_info,
kind: StatementKind::Intrinsic(
box NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
src: Operand::Copy(src_cast_place),
dst: Operand::Copy(dst_cast_place),
count: Operand::Copy(size_place),
}),
),
};

let store_dead = Statement {
source_info,
kind: StatementKind::StorageDead(size_array_local),
};
let iter = [
store_live,
const_assign,
store_discr,
cast_discr,
store_size,
dst_ptr,
dst_cast,
src_ptr,
src_cast,
deinit_old,
copy_bytes,
store_dead,
]
.into_iter();

st.make_nop();
Some(iter)
} else {
None
}
});
}
}
}
3 changes: 3 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(rustc::potential_query_instability)]
#![feature(box_patterns)]
#![feature(drain_filter)]
#![feature(box_syntax)]
#![feature(let_chains)]
#![feature(map_try_insert)]
#![feature(min_specialization)]
Expand Down Expand Up @@ -73,6 +74,7 @@ mod function_item_references;
mod generator;
mod inline;
mod instcombine;
mod large_enums;
mod lower_intrinsics;
mod lower_slice_len;
mod match_branches;
Expand Down Expand Up @@ -583,6 +585,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&simplify::SimplifyLocals::new("final"),
&multiple_return_terminators::MultipleReturnTerminators,
&deduplicate_blocks::DeduplicateBlocks,
&large_enums::EnumSizeOpt { discrepancy: 128 },
// Some cleanup necessary at least for LLVM and potentially other codegen backends.
&add_call_guards::CriticalCallEdges,
// Dump the end result for testing and debugging purposes.
Expand Down
68 changes: 68 additions & 0 deletions tests/mir-opt/enum_opt.cand.EnumSizeOpt.32bit.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
- // MIR for `cand` before EnumSizeOpt
+ // MIR for `cand` after EnumSizeOpt

fn cand() -> Candidate {
let mut _0: Candidate; // return place in scope 0 at $DIR/enum_opt.rs:+0:18: +0:27
let mut _1: Candidate; // in scope 0 at $DIR/enum_opt.rs:+1:7: +1:12
let mut _2: Candidate; // in scope 0 at $DIR/enum_opt.rs:+2:7: +2:34
let mut _3: [u8; 8196]; // in scope 0 at $DIR/enum_opt.rs:+2:24: +2:33
+ let mut _4: [usize; 2]; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
+ let mut _5: isize; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
+ let mut _6: usize; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
+ let mut _7: usize; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
+ let mut _8: *mut Candidate; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
+ let mut _9: *mut u8; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
+ let mut _10: *const Candidate; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
+ let mut _11: *const u8; // in scope 0 at $DIR/enum_opt.rs:+2:3: +2:34
+ let mut _12: [usize; 2]; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
+ let mut _13: isize; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
+ let mut _14: usize; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
+ let mut _15: usize; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
+ let mut _16: *mut Candidate; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
+ let mut _17: *mut u8; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
+ let mut _18: *const Candidate; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
+ let mut _19: *const u8; // in scope 0 at $DIR/enum_opt.rs:+3:3: +3:4
scope 1 {
debug a => _1; // in scope 1 at $DIR/enum_opt.rs:+1:7: +1:12
}

bb0: {
StorageLive(_1); // scope 0 at $DIR/enum_opt.rs:+1:7: +1:12
_1 = Candidate::Small(const 1_u8); // scope 0 at $DIR/enum_opt.rs:+1:15: +1:34
StorageLive(_2); // scope 1 at $DIR/enum_opt.rs:+2:7: +2:34
StorageLive(_3); // scope 1 at $DIR/enum_opt.rs:+2:24: +2:33
_3 = [const 1_u8; 8196]; // scope 1 at $DIR/enum_opt.rs:+2:24: +2:33
_2 = Candidate::Large(move _3); // scope 1 at $DIR/enum_opt.rs:+2:7: +2:34
StorageDead(_3); // scope 1 at $DIR/enum_opt.rs:+2:33: +2:34
- _1 = move _2; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ StorageLive(_4); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ _4 = const [2_usize, 8197_usize]; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ _5 = discriminant(_2); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ _6 = _5 as usize (IntToInt); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ _7 = _4[_6]; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ _8 = &raw mut _1; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ _9 = _8 as *mut u8 (PtrToPtr); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ _10 = &raw const _2; // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ _11 = _10 as *const u8 (PtrToPtr); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ Deinit(_8); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ copy_nonoverlapping(dst = _9, src = _11, count = _7); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
+ StorageDead(_4); // scope 1 at $DIR/enum_opt.rs:+2:3: +2:34
StorageDead(_2); // scope 1 at $DIR/enum_opt.rs:+2:33: +2:34
- _0 = move _1; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ StorageLive(_12); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ _12 = const [2_usize, 8197_usize]; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ _13 = discriminant(_1); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ _14 = _13 as usize (IntToInt); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ _15 = _12[_14]; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ _16 = &raw mut _0; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ _17 = _16 as *mut u8 (PtrToPtr); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ _18 = &raw const _1; // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ _19 = _18 as *const u8 (PtrToPtr); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ Deinit(_16); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ copy_nonoverlapping(dst = _17, src = _19, count = _15); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
+ StorageDead(_12); // scope 1 at $DIR/enum_opt.rs:+3:3: +3:4
StorageDead(_1); // scope 0 at $DIR/enum_opt.rs:+4:1: +4:2
return; // scope 0 at $DIR/enum_opt.rs:+4:2: +4:2
}
}

Loading