Skip to content

Commit 22a7a19

Browse files
committed
Auto merge of #98112 - saethlin:mir-alignment-checks, r=oli-obk
Insert alignment checks for pointer dereferences when debug assertions are enabled Closes #54915 - [x] Jake tells me this sounds like a place to use `MirPatch`, but I can't figure out how to insert a new basic block with a new terminator in the middle of an existing basic block, using `MirPatch`. (if nobody else backs up this point I'm checking this as "not actually a good idea" because the code looks pretty clean to me after rearranging it a bit) - [x] Using `CastKind::PointerExposeAddress` is definitely wrong, we don't want to expose. Calling a function to get the pointer address seems quite excessive. ~I'll see if I can add a new `CastKind`.~ `CastKind::Transmute` to the rescue! - [x] Implement a more helpful panic message like slice bounds checking. r? `@oli-obk`
2 parents ec7bb8d + 7507078 commit 22a7a19

35 files changed

+372
-21
lines changed

Diff for: compiler/rustc_codegen_cranelift/src/base.rs

+12
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,18 @@ fn codegen_fn_body(fx: &mut FunctionCx<'_, '_, '_>, start_block: Block) {
379379
source_info.span,
380380
);
381381
}
382+
AssertKind::MisalignedPointerDereference { ref required, ref found } => {
383+
let required = codegen_operand(fx, required).load_scalar(fx);
384+
let found = codegen_operand(fx, found).load_scalar(fx);
385+
let location = fx.get_caller_location(source_info).load_scalar(fx);
386+
387+
codegen_panic_inner(
388+
fx,
389+
rustc_hir::LangItem::PanicBoundsCheck,
390+
&[required, found, location],
391+
source_info.span,
392+
);
393+
}
382394
_ => {
383395
let msg_str = msg.description();
384396
codegen_panic(fx, msg_str, source_info);

Diff for: compiler/rustc_codegen_ssa/src/mir/block.rs

+7
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,13 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
600600
// and `#[track_caller]` adds an implicit third argument.
601601
(LangItem::PanicBoundsCheck, vec![index, len, location])
602602
}
603+
AssertKind::MisalignedPointerDereference { ref required, ref found } => {
604+
let required = self.codegen_operand(bx, required).immediate();
605+
let found = self.codegen_operand(bx, found).immediate();
606+
// It's `fn panic_bounds_check(index: usize, len: usize)`,
607+
// and `#[track_caller]` adds an implicit third argument.
608+
(LangItem::PanicMisalignedPointerDereference, vec![required, found, location])
609+
}
603610
_ => {
604611
let msg = bx.const_str(msg.description());
605612
// It's `pub fn panic(expr: &str)`, with the wide reference being passed

Diff for: compiler/rustc_const_eval/src/const_eval/machine.rs

+6
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,12 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for CompileTimeInterpreter<'mir,
544544
RemainderByZero(op) => RemainderByZero(eval_to_int(op)?),
545545
ResumedAfterReturn(generator_kind) => ResumedAfterReturn(*generator_kind),
546546
ResumedAfterPanic(generator_kind) => ResumedAfterPanic(*generator_kind),
547+
MisalignedPointerDereference { ref required, ref found } => {
548+
MisalignedPointerDereference {
549+
required: eval_to_int(required)?,
550+
found: eval_to_int(found)?,
551+
}
552+
}
547553
};
548554
Err(ConstEvalErrKind::AssertFailure(err).into())
549555
}

Diff for: compiler/rustc_hir/src/lang_items.rs

+1
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ language_item_table! {
240240
PanicDisplay, sym::panic_display, panic_display, Target::Fn, GenericRequirement::None;
241241
ConstPanicFmt, sym::const_panic_fmt, const_panic_fmt, Target::Fn, GenericRequirement::None;
242242
PanicBoundsCheck, sym::panic_bounds_check, panic_bounds_check_fn, Target::Fn, GenericRequirement::Exact(0);
243+
PanicMisalignedPointerDereference, sym::panic_misaligned_pointer_dereference, panic_misaligned_pointer_dereference_fn, Target::Fn, GenericRequirement::Exact(0);
243244
PanicInfo, sym::panic_info, panic_info, Target::Struct, GenericRequirement::None;
244245
PanicLocation, sym::panic_location, panic_location, Target::Struct, GenericRequirement::None;
245246
PanicImpl, sym::panic_impl, panic_impl, Target::Fn, GenericRequirement::None;

Diff for: compiler/rustc_middle/src/mir/mod.rs

+18-2
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,7 @@ impl<O> AssertKind<O> {
12771277

12781278
/// Getting a description does not require `O` to be printable, and does not
12791279
/// require allocation.
1280-
/// The caller is expected to handle `BoundsCheck` separately.
1280+
/// The caller is expected to handle `BoundsCheck` and `MisalignedPointerDereference` separately.
12811281
pub fn description(&self) -> &'static str {
12821282
use AssertKind::*;
12831283
match self {
@@ -1296,7 +1296,9 @@ impl<O> AssertKind<O> {
12961296
ResumedAfterReturn(GeneratorKind::Async(_)) => "`async fn` resumed after completion",
12971297
ResumedAfterPanic(GeneratorKind::Gen) => "generator resumed after panicking",
12981298
ResumedAfterPanic(GeneratorKind::Async(_)) => "`async fn` resumed after panicking",
1299-
BoundsCheck { .. } => bug!("Unexpected AssertKind"),
1299+
BoundsCheck { .. } | MisalignedPointerDereference { .. } => {
1300+
bug!("Unexpected AssertKind")
1301+
}
13001302
}
13011303
}
13021304

@@ -1353,6 +1355,13 @@ impl<O> AssertKind<O> {
13531355
Overflow(BinOp::Shl, _, r) => {
13541356
write!(f, "\"attempt to shift left by `{{}}`, which would overflow\", {:?}", r)
13551357
}
1358+
MisalignedPointerDereference { required, found } => {
1359+
write!(
1360+
f,
1361+
"\"misaligned pointer dereference: address must be a multiple of {{}} but is {{}}\", {:?}, {:?}",
1362+
required, found
1363+
)
1364+
}
13561365
_ => write!(f, "\"{}\"", self.description()),
13571366
}
13581367
}
@@ -1397,6 +1406,13 @@ impl<O: fmt::Debug> fmt::Debug for AssertKind<O> {
13971406
Overflow(BinOp::Shl, _, r) => {
13981407
write!(f, "attempt to shift left by `{:#?}`, which would overflow", r)
13991408
}
1409+
MisalignedPointerDereference { required, found } => {
1410+
write!(
1411+
f,
1412+
"misaligned pointer dereference: address must be a multiple of {:?} but is {:?}",
1413+
required, found
1414+
)
1415+
}
14001416
_ => write!(f, "{}", self.description()),
14011417
}
14021418
}

Diff for: compiler/rustc_middle/src/mir/syntax.rs

+1
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,7 @@ pub enum AssertKind<O> {
760760
RemainderByZero(O),
761761
ResumedAfterReturn(GeneratorKind),
762762
ResumedAfterPanic(GeneratorKind),
763+
MisalignedPointerDereference { required: O, found: O },
763764
}
764765

765766
#[derive(Clone, Debug, PartialEq, TyEncodable, TyDecodable, Hash, HashStable)]

Diff for: compiler/rustc_middle/src/mir/visit.rs

+4
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,10 @@ macro_rules! make_mir_visitor {
610610
ResumedAfterReturn(_) | ResumedAfterPanic(_) => {
611611
// Nothing to visit
612612
}
613+
MisalignedPointerDereference { required, found } => {
614+
self.visit_operand(required, location);
615+
self.visit_operand(found, location);
616+
}
613617
}
614618
}
615619

Diff for: compiler/rustc_mir_transform/src/check_alignment.rs

+227
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
use crate::MirPass;
2+
use rustc_hir::def_id::DefId;
3+
use rustc_index::vec::IndexVec;
4+
use rustc_middle::mir::*;
5+
use rustc_middle::mir::{
6+
interpret::{ConstValue, Scalar},
7+
visit::{PlaceContext, Visitor},
8+
};
9+
use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
10+
use rustc_session::Session;
11+
12+
pub struct CheckAlignment;
13+
14+
impl<'tcx> MirPass<'tcx> for CheckAlignment {
15+
fn is_enabled(&self, sess: &Session) -> bool {
16+
sess.opts.debug_assertions
17+
}
18+
19+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
20+
let basic_blocks = body.basic_blocks.as_mut();
21+
let local_decls = &mut body.local_decls;
22+
23+
for block in (0..basic_blocks.len()).rev() {
24+
let block = block.into();
25+
for statement_index in (0..basic_blocks[block].statements.len()).rev() {
26+
let location = Location { block, statement_index };
27+
let statement = &basic_blocks[block].statements[statement_index];
28+
let source_info = statement.source_info;
29+
30+
let mut finder = PointerFinder {
31+
local_decls,
32+
tcx,
33+
pointers: Vec::new(),
34+
def_id: body.source.def_id(),
35+
};
36+
for (pointer, pointee_ty) in finder.find_pointers(statement) {
37+
debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty);
38+
39+
let new_block = split_block(basic_blocks, location);
40+
insert_alignment_check(
41+
tcx,
42+
local_decls,
43+
&mut basic_blocks[block],
44+
pointer,
45+
pointee_ty,
46+
source_info,
47+
new_block,
48+
);
49+
}
50+
}
51+
}
52+
}
53+
}
54+
55+
impl<'tcx, 'a> PointerFinder<'tcx, 'a> {
56+
fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
57+
self.pointers.clear();
58+
self.visit_statement(statement, Location::START);
59+
core::mem::take(&mut self.pointers)
60+
}
61+
}
62+
63+
struct PointerFinder<'tcx, 'a> {
64+
local_decls: &'a mut LocalDecls<'tcx>,
65+
tcx: TyCtxt<'tcx>,
66+
def_id: DefId,
67+
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
68+
}
69+
70+
impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> {
71+
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
72+
if let PlaceContext::NonUse(_) = context {
73+
return;
74+
}
75+
if !place.is_indirect() {
76+
return;
77+
}
78+
79+
let pointer = Place::from(place.local);
80+
let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty;
81+
82+
// We only want to check unsafe pointers
83+
if !pointer_ty.is_unsafe_ptr() {
84+
trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty);
85+
return;
86+
}
87+
88+
let Some(pointee) = pointer_ty.builtin_deref(true) else {
89+
debug!("Indirect but no builtin deref: {:?}", pointer_ty);
90+
return;
91+
};
92+
let mut pointee_ty = pointee.ty;
93+
if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() {
94+
pointee_ty = pointee_ty.sequence_element_type(self.tcx);
95+
}
96+
97+
if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) {
98+
debug!("Unsafe pointer, but unsized: {:?}", pointer_ty);
99+
return;
100+
}
101+
102+
if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_]
103+
.contains(&pointee_ty)
104+
{
105+
debug!("Trivially aligned pointee type: {:?}", pointer_ty);
106+
return;
107+
}
108+
109+
self.pointers.push((pointer, pointee_ty))
110+
}
111+
}
112+
113+
fn split_block(
114+
basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
115+
location: Location,
116+
) -> BasicBlock {
117+
let block_data = &mut basic_blocks[location.block];
118+
119+
// Drain every statement after this one and move the current terminator to a new basic block
120+
let new_block = BasicBlockData {
121+
statements: block_data.statements.split_off(location.statement_index),
122+
terminator: block_data.terminator.take(),
123+
is_cleanup: block_data.is_cleanup,
124+
};
125+
126+
basic_blocks.push(new_block)
127+
}
128+
129+
fn insert_alignment_check<'tcx>(
130+
tcx: TyCtxt<'tcx>,
131+
local_decls: &mut LocalDecls<'tcx>,
132+
block_data: &mut BasicBlockData<'tcx>,
133+
pointer: Place<'tcx>,
134+
pointee_ty: Ty<'tcx>,
135+
source_info: SourceInfo,
136+
new_block: BasicBlock,
137+
) {
138+
// Cast the pointer to a *const ()
139+
let const_raw_ptr = tcx.mk_ptr(TypeAndMut { ty: tcx.types.unit, mutbl: Mutability::Not });
140+
let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr);
141+
let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into();
142+
block_data
143+
.statements
144+
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) });
145+
146+
// Transmute the pointer to a usize (equivalent to `ptr.addr()`)
147+
let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize);
148+
let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
149+
block_data
150+
.statements
151+
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
152+
153+
// Get the alignment of the pointee
154+
let alignment =
155+
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
156+
let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty);
157+
block_data.statements.push(Statement {
158+
source_info,
159+
kind: StatementKind::Assign(Box::new((alignment, rvalue))),
160+
});
161+
162+
// Subtract 1 from the alignment to get the alignment mask
163+
let alignment_mask =
164+
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
165+
let one = Operand::Constant(Box::new(Constant {
166+
span: source_info.span,
167+
user_ty: None,
168+
literal: ConstantKind::Val(
169+
ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)),
170+
tcx.types.usize,
171+
),
172+
}));
173+
block_data.statements.push(Statement {
174+
source_info,
175+
kind: StatementKind::Assign(Box::new((
176+
alignment_mask,
177+
Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(alignment), one))),
178+
))),
179+
});
180+
181+
// BitAnd the alignment mask with the pointer
182+
let alignment_bits =
183+
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
184+
block_data.statements.push(Statement {
185+
source_info,
186+
kind: StatementKind::Assign(Box::new((
187+
alignment_bits,
188+
Rvalue::BinaryOp(
189+
BinOp::BitAnd,
190+
Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))),
191+
),
192+
))),
193+
});
194+
195+
// Check if the alignment bits are all zero
196+
let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
197+
let zero = Operand::Constant(Box::new(Constant {
198+
span: source_info.span,
199+
user_ty: None,
200+
literal: ConstantKind::Val(
201+
ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)),
202+
tcx.types.usize,
203+
),
204+
}));
205+
block_data.statements.push(Statement {
206+
source_info,
207+
kind: StatementKind::Assign(Box::new((
208+
is_ok,
209+
Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))),
210+
))),
211+
});
212+
213+
// Set this block's terminator to our assert, continuing to new_block if we pass
214+
block_data.terminator = Some(Terminator {
215+
source_info,
216+
kind: TerminatorKind::Assert {
217+
cond: Operand::Copy(is_ok),
218+
expected: true,
219+
target: new_block,
220+
msg: AssertKind::MisalignedPointerDereference {
221+
required: Operand::Copy(alignment),
222+
found: Operand::Copy(addr),
223+
},
224+
cleanup: None,
225+
},
226+
});
227+
}

Diff for: compiler/rustc_mir_transform/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ mod separate_const_switch;
9090
mod shim;
9191
mod ssa;
9292
// This pass is public to allow external drivers to perform MIR cleanup
93+
mod check_alignment;
9394
pub mod simplify;
9495
mod simplify_branches;
9596
mod simplify_comparison_integral;
@@ -545,6 +546,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
545546
tcx,
546547
body,
547548
&[
549+
&check_alignment::CheckAlignment,
548550
&reveal_all::RevealAll, // has to be done before inlining, since inlined code is in RevealAll mode.
549551
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
550552
&unreachable_prop::UnreachablePropagation,

Diff for: compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,7 @@ symbols! {
10691069
panic_implementation,
10701070
panic_info,
10711071
panic_location,
1072+
panic_misaligned_pointer_dereference,
10721073
panic_nounwind,
10731074
panic_runtime,
10741075
panic_str,

Diff for: library/core/src/panicking.rs

+14
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,20 @@ fn panic_bounds_check(index: usize, len: usize) -> ! {
162162
panic!("index out of bounds: the len is {len} but the index is {index}")
163163
}
164164

165+
#[cold]
166+
#[cfg_attr(not(feature = "panic_immediate_abort"), inline(never))]
167+
#[track_caller]
168+
#[cfg_attr(not(bootstrap), lang = "panic_misaligned_pointer_dereference")] // needed by codegen for panic on misaligned pointer deref
169+
fn panic_misaligned_pointer_dereference(required: usize, found: usize) -> ! {
170+
if cfg!(feature = "panic_immediate_abort") {
171+
super::intrinsics::abort()
172+
}
173+
174+
panic!(
175+
"misaligned pointer dereference: address must be a multiple of {required:#x} but is {found:#x}"
176+
)
177+
}
178+
165179
/// Panic because we cannot unwind out of a function.
166180
///
167181
/// This function is called directly by the codegen backend, and must not have

0 commit comments

Comments
 (0)