forked from rust-lang/rust
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Auto merge of rust-lang#98112 - saethlin:mir-alignment-checks, r=oli-obk
Insert alignment checks for pointer dereferences when debug assertions are enabled Closes rust-lang#54915 - [x] Jake tells me this sounds like a place to use `MirPatch`, but I can't figure out how to insert a new basic block with a new terminator in the middle of an existing basic block, using `MirPatch`. (if nobody else backs up this point I'm checking this as "not actually a good idea" because the code looks pretty clean to me after rearranging it a bit) - [x] Using `CastKind::PointerExposeAddress` is definitely wrong, we don't want to expose. Calling a function to get the pointer address seems quite excessive. ~I'll see if I can add a new `CastKind`.~ `CastKind::Transmute` to the rescue! - [x] Implement a more helpful panic message like slice bounds checking. r? `@oli-obk`
- Loading branch information
Showing
35 changed files
with
372 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
use crate::MirPass; | ||
use rustc_hir::def_id::DefId; | ||
use rustc_index::vec::IndexVec; | ||
use rustc_middle::mir::*; | ||
use rustc_middle::mir::{ | ||
interpret::{ConstValue, Scalar}, | ||
visit::{PlaceContext, Visitor}, | ||
}; | ||
use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut}; | ||
use rustc_session::Session; | ||
|
||
pub struct CheckAlignment; | ||
|
||
impl<'tcx> MirPass<'tcx> for CheckAlignment { | ||
fn is_enabled(&self, sess: &Session) -> bool { | ||
sess.opts.debug_assertions | ||
} | ||
|
||
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | ||
let basic_blocks = body.basic_blocks.as_mut(); | ||
let local_decls = &mut body.local_decls; | ||
|
||
for block in (0..basic_blocks.len()).rev() { | ||
let block = block.into(); | ||
for statement_index in (0..basic_blocks[block].statements.len()).rev() { | ||
let location = Location { block, statement_index }; | ||
let statement = &basic_blocks[block].statements[statement_index]; | ||
let source_info = statement.source_info; | ||
|
||
let mut finder = PointerFinder { | ||
local_decls, | ||
tcx, | ||
pointers: Vec::new(), | ||
def_id: body.source.def_id(), | ||
}; | ||
for (pointer, pointee_ty) in finder.find_pointers(statement) { | ||
debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty); | ||
|
||
let new_block = split_block(basic_blocks, location); | ||
insert_alignment_check( | ||
tcx, | ||
local_decls, | ||
&mut basic_blocks[block], | ||
pointer, | ||
pointee_ty, | ||
source_info, | ||
new_block, | ||
); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
impl<'tcx, 'a> PointerFinder<'tcx, 'a> { | ||
fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> { | ||
self.pointers.clear(); | ||
self.visit_statement(statement, Location::START); | ||
core::mem::take(&mut self.pointers) | ||
} | ||
} | ||
|
||
struct PointerFinder<'tcx, 'a> { | ||
local_decls: &'a mut LocalDecls<'tcx>, | ||
tcx: TyCtxt<'tcx>, | ||
def_id: DefId, | ||
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>, | ||
} | ||
|
||
impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> { | ||
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { | ||
if let PlaceContext::NonUse(_) = context { | ||
return; | ||
} | ||
if !place.is_indirect() { | ||
return; | ||
} | ||
|
||
let pointer = Place::from(place.local); | ||
let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty; | ||
|
||
// We only want to check unsafe pointers | ||
if !pointer_ty.is_unsafe_ptr() { | ||
trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty); | ||
return; | ||
} | ||
|
||
let Some(pointee) = pointer_ty.builtin_deref(true) else { | ||
debug!("Indirect but no builtin deref: {:?}", pointer_ty); | ||
return; | ||
}; | ||
let mut pointee_ty = pointee.ty; | ||
if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() { | ||
pointee_ty = pointee_ty.sequence_element_type(self.tcx); | ||
} | ||
|
||
if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) { | ||
debug!("Unsafe pointer, but unsized: {:?}", pointer_ty); | ||
return; | ||
} | ||
|
||
if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_] | ||
.contains(&pointee_ty) | ||
{ | ||
debug!("Trivially aligned pointee type: {:?}", pointer_ty); | ||
return; | ||
} | ||
|
||
self.pointers.push((pointer, pointee_ty)) | ||
} | ||
} | ||
|
||
fn split_block( | ||
basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, | ||
location: Location, | ||
) -> BasicBlock { | ||
let block_data = &mut basic_blocks[location.block]; | ||
|
||
// Drain every statement after this one and move the current terminator to a new basic block | ||
let new_block = BasicBlockData { | ||
statements: block_data.statements.split_off(location.statement_index), | ||
terminator: block_data.terminator.take(), | ||
is_cleanup: block_data.is_cleanup, | ||
}; | ||
|
||
basic_blocks.push(new_block) | ||
} | ||
|
||
fn insert_alignment_check<'tcx>( | ||
tcx: TyCtxt<'tcx>, | ||
local_decls: &mut LocalDecls<'tcx>, | ||
block_data: &mut BasicBlockData<'tcx>, | ||
pointer: Place<'tcx>, | ||
pointee_ty: Ty<'tcx>, | ||
source_info: SourceInfo, | ||
new_block: BasicBlock, | ||
) { | ||
// Cast the pointer to a *const () | ||
let const_raw_ptr = tcx.mk_ptr(TypeAndMut { ty: tcx.types.unit, mutbl: Mutability::Not }); | ||
let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr); | ||
let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into(); | ||
block_data | ||
.statements | ||
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) }); | ||
|
||
// Transmute the pointer to a usize (equivalent to `ptr.addr()`) | ||
let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize); | ||
let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | ||
block_data | ||
.statements | ||
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) }); | ||
|
||
// Get the alignment of the pointee | ||
let alignment = | ||
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | ||
let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty); | ||
block_data.statements.push(Statement { | ||
source_info, | ||
kind: StatementKind::Assign(Box::new((alignment, rvalue))), | ||
}); | ||
|
||
// Subtract 1 from the alignment to get the alignment mask | ||
let alignment_mask = | ||
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | ||
let one = Operand::Constant(Box::new(Constant { | ||
span: source_info.span, | ||
user_ty: None, | ||
literal: ConstantKind::Val( | ||
ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)), | ||
tcx.types.usize, | ||
), | ||
})); | ||
block_data.statements.push(Statement { | ||
source_info, | ||
kind: StatementKind::Assign(Box::new(( | ||
alignment_mask, | ||
Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(alignment), one))), | ||
))), | ||
}); | ||
|
||
// BitAnd the alignment mask with the pointer | ||
let alignment_bits = | ||
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | ||
block_data.statements.push(Statement { | ||
source_info, | ||
kind: StatementKind::Assign(Box::new(( | ||
alignment_bits, | ||
Rvalue::BinaryOp( | ||
BinOp::BitAnd, | ||
Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))), | ||
), | ||
))), | ||
}); | ||
|
||
// Check if the alignment bits are all zero | ||
let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); | ||
let zero = Operand::Constant(Box::new(Constant { | ||
span: source_info.span, | ||
user_ty: None, | ||
literal: ConstantKind::Val( | ||
ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), | ||
tcx.types.usize, | ||
), | ||
})); | ||
block_data.statements.push(Statement { | ||
source_info, | ||
kind: StatementKind::Assign(Box::new(( | ||
is_ok, | ||
Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))), | ||
))), | ||
}); | ||
|
||
// Set this block's terminator to our assert, continuing to new_block if we pass | ||
block_data.terminator = Some(Terminator { | ||
source_info, | ||
kind: TerminatorKind::Assert { | ||
cond: Operand::Copy(is_ok), | ||
expected: true, | ||
target: new_block, | ||
msg: AssertKind::MisalignedPointerDereference { | ||
required: Operand::Copy(alignment), | ||
found: Operand::Copy(addr), | ||
}, | ||
cleanup: None, | ||
}, | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.