Skip to content

Commit

Permalink
Move alignment checks to codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
saethlin committed Apr 7, 2024
1 parent 6f83750 commit 9b79098
Show file tree
Hide file tree
Showing 24 changed files with 300 additions and 344 deletions.
14 changes: 14 additions & 0 deletions compiler/rustc_codegen_cranelift/example/mini_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,20 @@ fn panic_cannot_unwind() -> ! {
}
}

#[lang = "panic_misaligned_pointer_dereference"]
#[track_caller]
fn panic_misaligned_pointer_dereference(required: usize, found: usize) -> ! {
unsafe {
libc::printf(
"misaligned pointer dereference: address must be a multiple of %d but is %d\n\0"
as *const str as *const i8,
required,
found,
);
intrinsics::abort();
}
}

#[lang = "eh_personality"]
fn eh_personality() -> ! {
loop {}
Expand Down
77 changes: 65 additions & 12 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use cranelift_codegen::ir::UserFuncName;
use cranelift_codegen::CodegenError;
use cranelift_module::ModuleError;
use rustc_ast::InlineAsmOptions;
use rustc_codegen_ssa::mir::pointers_to_check;
use rustc_index::IndexVec;
use rustc_middle::ty::adjustment::PointerCoercion;
use rustc_middle::ty::layout::FnAbiOf;
Expand Down Expand Up @@ -359,18 +360,6 @@ fn codegen_fn_body(fx: &mut FunctionCx<'_, '_, '_>, start_block: Block) {
Some(source_info.span),
);
}
AssertKind::MisalignedPointerDereference { ref required, ref found } => {
let required = codegen_operand(fx, required).load_scalar(fx);
let found = codegen_operand(fx, found).load_scalar(fx);
let location = fx.get_caller_location(source_info).load_scalar(fx);

codegen_panic_inner(
fx,
rustc_hir::LangItem::PanicMisalignedPointerDereference,
&[required, found, location],
Some(source_info.span),
);
}
_ => {
let location = fx.get_caller_location(source_info).load_scalar(fx);

Expand Down Expand Up @@ -513,6 +502,49 @@ fn codegen_fn_body(fx: &mut FunctionCx<'_, '_, '_>, start_block: Block) {
}
}

fn codegen_alignment_check<'tcx>(
fx: &mut FunctionCx<'_, '_, 'tcx>,
pointer: mir::Operand<'tcx>,
required_alignment: u64,
source_info: mir::SourceInfo,
) {
// Compute the alignment mask
let required_alignment = required_alignment as i64;
let mask = fx.bcx.ins().iconst(fx.pointer_type, required_alignment - 1);
let required = fx.bcx.ins().iconst(fx.pointer_type, required_alignment);

// And the pointer with the mask
let pointer = codegen_operand(fx, &pointer);
let pointer = match pointer.layout().abi {
Abi::Scalar(_) => pointer.load_scalar(fx),
Abi::ScalarPair(..) => pointer.load_scalar_pair(fx).0,
_ => unreachable!(),
};
let masked = fx.bcx.ins().band(pointer, mask);

// Branch on whether the masked value is zero
let is_zero = fx.bcx.ins().icmp_imm(IntCC::Equal, masked, 0);

// Create destination blocks, branching on is_zero
let panic = fx.bcx.create_block();
let success = fx.bcx.create_block();
fx.bcx.ins().brif(is_zero, success, &[], panic, &[]);

// Switch to the failure block and codegen a call to the panic intrinsic
fx.bcx.switch_to_block(panic);
let location = fx.get_caller_location(source_info).load_scalar(fx);
codegen_panic_inner(
fx,
rustc_hir::LangItem::PanicMisalignedPointerDereference,
&[required, pointer, location],
Some(source_info.span),
);

// Continue codegen in the success block
fx.bcx.switch_to_block(success);
fx.bcx.ins().nop();
}

fn codegen_stmt<'tcx>(
fx: &mut FunctionCx<'_, '_, 'tcx>,
#[allow(unused_variables)] cur_block: Block,
Expand All @@ -534,6 +566,27 @@ fn codegen_stmt<'tcx>(
}
}

let required_align_of = |pointer| {
let pointer_ty = fx.mir.local_decls[pointer].ty;
let pointer_ty = fx.monomorphize(pointer_ty);
if !pointer_ty.is_unsafe_ptr() {
return None;
}

let pointee_ty =
pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer").ty;
let pointee_layout = fx.layout_of(pointee_ty);

Some(pointee_layout.align.abi.bytes() as u64)
};

if fx.tcx.may_insert_alignment_checks() {
for (pointer, required_alignment) in pointers_to_check(stmt, required_align_of) {
let pointer = mir::Operand::Copy(pointer.into());
codegen_alignment_check(fx, pointer, required_alignment, stmt.source_info);
}
}

match &stmt.kind {
StatementKind::SetDiscriminant { place, variant_index } => {
let place = codegen_place(fx, **place);
Expand Down
14 changes: 14 additions & 0 deletions compiler/rustc_codegen_gcc/example/mini_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,20 @@ fn panic_bounds_check(index: usize, len: usize) -> ! {
}
}

#[lang = "panic_misaligned_pointer_dereference"]
#[track_caller]
fn panic_misaligned_pointer_dereference(required: usize, found: usize) -> ! {
unsafe {
libc::printf(
"misaligned pointer dereference: address must be a multiple of %d but is %d\n\0"
as *const str as *const i8,
required,
found,
);
intrinsics::abort();
}
}

#[lang = "eh_personality"]
fn eh_personality() -> ! {
loop {}
Expand Down
9 changes: 1 addition & 8 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -674,13 +674,6 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
// and `#[track_caller]` adds an implicit third argument.
(LangItem::PanicBoundsCheck, vec![index, len, location])
}
AssertKind::MisalignedPointerDereference { ref required, ref found } => {
let required = self.codegen_operand(bx, required).immediate();
let found = self.codegen_operand(bx, found).immediate();
// It's `fn panic_misaligned_pointer_dereference(required: usize, found: usize)`,
// and `#[track_caller]` adds an implicit third argument.
(LangItem::PanicMisalignedPointerDereference, vec![required, found, location])
}
_ => {
// It's `pub fn panic_...()` and `#[track_caller]` adds an implicit argument.
(msg.panic_function(), vec![location])
Expand Down Expand Up @@ -1583,7 +1576,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
tuple.layout.fields.count()
}

fn get_caller_location(
pub fn get_caller_location(
&mut self,
bx: &mut Bx,
source_info: mir::SourceInfo,
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ mod intrinsic;
mod locals;
pub mod operand;
pub mod place;
mod pointer_alignment_check;
mod rvalue;
mod statement;

use self::debuginfo::{FunctionDebugContext, PerLocalVarDebugInfo};
use self::operand::{OperandRef, OperandValue};
use self::place::PlaceRef;
pub use self::pointer_alignment_check::pointers_to_check;

// Used for tracking the state of generated basic blocks.
enum CachedLlbb<T> {
Expand Down
152 changes: 152 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/pointer_alignment_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
use rustc_hir::LangItem;
use rustc_middle::mir;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext};
use rustc_span::Span;

use super::FunctionCx;
use crate::base;
use crate::common;
use crate::mir::OperandValue;
use crate::traits::*;

pub fn pointers_to_check<F>(
statement: &mir::Statement<'_>,
required_align_of: F,
) -> Vec<(mir::Local, u64)>
where
F: Fn(mir::Local) -> Option<u64>,
{
let mut finder = PointerFinder { required_align_of, pointers: Vec::new() };
finder.visit_statement(statement, rustc_middle::mir::Location::START);
finder.pointers
}

struct PointerFinder<F> {
pointers: Vec<(mir::Local, u64)>,
required_align_of: F,
}

impl<'tcx, F> Visitor<'tcx> for PointerFinder<F>
where
F: Fn(mir::Local) -> Option<u64>,
{
fn visit_place(
&mut self,
place: &mir::Place<'tcx>,
context: PlaceContext,
location: mir::Location,
) {
// We want to only check reads and writes to Places, so we specifically exclude
// Borrows and AddressOf.
match context {
PlaceContext::MutatingUse(
MutatingUseContext::Store
| MutatingUseContext::AsmOutput
| MutatingUseContext::Call
| MutatingUseContext::Yield
| MutatingUseContext::Drop,
) => {}
PlaceContext::NonMutatingUse(
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
) => {}
_ => {
return;
}
}

if !place.is_indirect() {
return;
}

let pointer = place.local;
let Some(required_alignment) = (self.required_align_of)(pointer) else {
return;
};

if required_alignment == 1 {
return;
}

// Ensure that this place is based on an aligned pointer.
self.pointers.push((pointer, required_alignment));

self.super_place(place, context, location);
}
}

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
#[instrument(level = "debug", skip(self, bx))]
pub fn codegen_alignment_check(
&mut self,
bx: &mut Bx,
pointer: mir::Operand<'tcx>,
required_alignment: u64,
source_info: mir::SourceInfo,
) {
// Compute the alignment mask
let mask = bx.const_usize(required_alignment - 1);
let zero = bx.const_usize(0);
let required_alignment = bx.const_usize(required_alignment);

// And the pointer with the mask
let pointer = match self.codegen_operand(bx, &pointer).val {
OperandValue::Immediate(imm) => imm,
OperandValue::Pair(ptr, _) => ptr,
_ => {
unreachable!("{pointer:?}");
}
};
let addr = bx.ptrtoint(pointer, bx.cx().type_isize());
let masked = bx.and(addr, mask);

// Branch on whether the masked value is zero
let is_zero = bx.icmp(
base::bin_op_to_icmp_predicate(mir::BinOp::Eq.to_hir_binop(), false),
masked,
zero,
);

// Create destination blocks, branching on is_zero
let panic = bx.append_sibling_block("panic");
let success = bx.append_sibling_block("success");
bx.cond_br(is_zero, success, panic);

// Switch to the failure block and codegen a call to the panic intrinsic
bx.switch_to_block(panic);
self.set_debug_loc(bx, source_info);
let location = self.get_caller_location(bx, source_info).immediate();
self.codegen_panic(
bx,
LangItem::PanicMisalignedPointerDereference,
&[required_alignment, addr, location],
source_info.span,
);

// Continue codegen in the success block.
bx.switch_to_block(success);
self.set_debug_loc(bx, source_info);
}

#[instrument(level = "debug", skip(self, bx))]
fn codegen_panic(&mut self, bx: &mut Bx, lang_item: LangItem, args: &[Bx::Value], span: Span) {
let (fn_abi, fn_ptr, instance) = common::build_langcall(bx, Some(span), lang_item);
let fn_ty = bx.fn_decl_backend_type(&fn_abi);
let fn_attrs = if bx.tcx().def_kind(self.instance.def_id()).has_codegen_attrs() {
Some(bx.tcx().codegen_fn_attrs(self.instance.def_id()))
} else {
None
};

bx.call(
fn_ty,
fn_attrs,
Some(&fn_abi),
fn_ptr,
args,
None, /* funclet */
Some(instance),
);
bx.unreachable();
}
}
31 changes: 31 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use rustc_middle::mir;
use rustc_middle::mir::NonDivergingIntrinsic;
use rustc_session::config::OptLevel;

use super::pointers_to_check;
use super::FunctionCx;
use super::LocalRef;
use crate::traits::*;
Expand All @@ -10,6 +11,36 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
#[instrument(level = "debug", skip(self, bx))]
pub fn codegen_statement(&mut self, bx: &mut Bx, statement: &mir::Statement<'tcx>) {
self.set_debug_loc(bx, statement.source_info);

let required_align_of = |local| {
// Since Deref projections must come first and only once, the pointer for an indirect place
// is the Local that the Place is based on.
let pointer_ty = self.mir.local_decls[local].ty;
let pointer_ty = self.monomorphize(pointer_ty);

// We only want to check places based on unsafe pointers
if !pointer_ty.is_unsafe_ptr() {
return None;
}

let pointee_ty =
pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer").ty;
let pointee_layout = bx.layout_of(pointee_ty);

Some(pointee_layout.layout.align.abi.bytes())
};

if bx.tcx().may_insert_alignment_checks() {
for (pointer, required_alignment) in pointers_to_check(statement, required_align_of) {
let pointer = mir::Operand::Copy(pointer.into());
self.codegen_alignment_check(
bx,
pointer,
required_alignment,
statement.source_info,
);
}
}
match statement.kind {
mir::StatementKind::Assign(box (ref place, ref rvalue)) => {
if let Some(index) = place.as_local() {
Expand Down
6 changes: 0 additions & 6 deletions compiler/rustc_const_eval/src/const_eval/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,12 +567,6 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for CompileTimeInterpreter<'mir,
RemainderByZero(op) => RemainderByZero(eval_to_int(op)?),
ResumedAfterReturn(coroutine_kind) => ResumedAfterReturn(*coroutine_kind),
ResumedAfterPanic(coroutine_kind) => ResumedAfterPanic(*coroutine_kind),
MisalignedPointerDereference { ref required, ref found } => {
MisalignedPointerDereference {
required: eval_to_int(required)?,
found: eval_to_int(found)?,
}
}
};
Err(ConstEvalErrKind::AssertFailure(err).into())
}
Expand Down
Loading

0 comments on commit 9b79098

Please sign in to comment.