diff --git a/compiler/rustc_mir_build/src/check_unsafety.rs b/compiler/rustc_mir_build/src/check_unsafety.rs index 0b6b36640e92b..f9f8e873cd535 100644 --- a/compiler/rustc_mir_build/src/check_unsafety.rs +++ b/compiler/rustc_mir_build/src/check_unsafety.rs @@ -45,6 +45,8 @@ struct UnsafetyVisitor<'a, 'tcx> { /// Flag to ensure that we only suggest wrapping the entire function body in /// an unsafe block once. suggest_unsafe_block: bool, + /// Controls how union field accesses are checked + union_field_access_mode: UnionFieldAccessMode, } impl<'tcx> UnsafetyVisitor<'_, 'tcx> { @@ -223,6 +225,7 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> { inside_adt: false, warnings: self.warnings, suggest_unsafe_block: self.suggest_unsafe_block, + union_field_access_mode: UnionFieldAccessMode::Normal, }; // params in THIR may be unsafe, e.g. a union pattern. for param in &inner_thir.params { @@ -545,6 +548,20 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> { } } ExprKind::RawBorrow { arg, .. } => { + // Handle the case where we're taking a raw pointer to a union field + if let ExprKind::Scope { value: arg, .. } = self.thir[arg].kind { + if self.is_union_field_access(arg) { + // Taking a raw pointer to a union field is safe - just check the base expression + // but skip the union field safety check + self.visit_union_field_for_raw_borrow(arg); + return; + } + } else if self.is_union_field_access(arg) { + // Direct raw borrow of union field + self.visit_union_field_for_raw_borrow(arg); + return; + } + if let ExprKind::Scope { value: arg, .. } = self.thir[arg].kind && let ExprKind::Deref { arg } = self.thir[arg].kind { @@ -649,17 +666,27 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> { if adt_def.variant(variant_index).fields[name].safety.is_unsafe() { self.requires_unsafe(expr.span, UseOfUnsafeField); } else if adt_def.is_union() { - if let Some(assigned_ty) = self.assignment_info { - if assigned_ty.needs_drop(self.tcx, self.typing_env) { - // This would be unsafe, but should be outright impossible since we - // reject such unions. - assert!( - self.tcx.dcx().has_errors().is_some(), - "union fields that need dropping should be impossible: {assigned_ty}" - ); + // Check if this field access is part of a raw borrow operation + // If so, we've already handled it above and shouldn't reach here + match self.union_field_access_mode { + UnionFieldAccessMode::SuppressUnionFieldAccessError => { + // Suppress AccessToUnionField error for union fields chains + } + UnionFieldAccessMode::Normal => { + if let Some(assigned_ty) = self.assignment_info { + if assigned_ty.needs_drop(self.tcx, self.typing_env) { + // This would be unsafe, but should be outright impossible since we + // reject such unions. + assert!( + self.tcx.dcx().has_errors().is_some(), + "union fields that need dropping should be impossible: {assigned_ty}" + ); + } + } else { + // Only require unsafe if this is not a raw borrow operation + self.requires_unsafe(expr.span, AccessToUnionField); + } } - } else { - self.requires_unsafe(expr.span, AccessToUnionField); } } } @@ -712,6 +739,46 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> { } } +impl<'a, 'tcx> UnsafetyVisitor<'a, 'tcx> { + /// Check if an expression is a union field access + fn is_union_field_access(&self, expr_id: ExprId) -> bool { + match self.thir[expr_id].kind { + ExprKind::Field { lhs, .. } => { + let lhs = &self.thir[lhs]; + matches!(lhs.ty.kind(), ty::Adt(adt_def, _) if adt_def.is_union()) + } + _ => false, + } + } + + /// Visit a union field access in the context of a raw borrow operation + /// This ensures we still check safety of nested operations while allowing + /// the raw pointer creation itself + fn visit_union_field_for_raw_borrow(&mut self, mut expr_id: ExprId) { + let prev = self.union_field_access_mode; + self.union_field_access_mode = UnionFieldAccessMode::SuppressUnionFieldAccessError; + // Walk through the chain of union field accesses using while let + while let ExprKind::Field { lhs, variant_index, name } = self.thir[expr_id].kind { + let lhs_expr = &self.thir[lhs]; + if let ty::Adt(adt_def, _) = lhs_expr.ty.kind() { + // Check for unsafe fields but skip the union access check + if adt_def.variant(variant_index).fields[name].safety.is_unsafe() { + self.requires_unsafe(self.thir[expr_id].span, UseOfUnsafeField); + } + // If the LHS is also a union field access, keep walking + expr_id = lhs; + } else { + // Not a union, use normal visiting + visit::walk_expr(self, &self.thir[expr_id]); + return; + } + } + // Visit the base expression for any nested safety checks + self.visit_expr(&self.thir[expr_id]); + self.union_field_access_mode = prev; + } +} + #[derive(Clone)] enum SafetyContext { Safe, @@ -720,6 +787,13 @@ enum SafetyContext { UnsafeBlock { span: Span, hir_id: HirId, used: bool, nested_used_blocks: Vec }, } +/// Controls how union field accesses are checked +#[derive(Clone, Copy)] +enum UnionFieldAccessMode { + Normal, + SuppressUnionFieldAccessError, +} + #[derive(Clone, Copy)] struct NestedUsedBlock { hir_id: HirId, @@ -1199,6 +1273,7 @@ pub(crate) fn check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) { inside_adt: false, warnings: &mut warnings, suggest_unsafe_block: true, + union_field_access_mode: UnionFieldAccessMode::Normal, }; // params in THIR may be unsafe, e.g. a union pattern. for param in &thir.params { diff --git a/src/tools/miri/tests/pass/both_borrows/smallvec.rs b/src/tools/miri/tests/pass/both_borrows/smallvec.rs index f48815e37be34..fa5cfb03de2a9 100644 --- a/src/tools/miri/tests/pass/both_borrows/smallvec.rs +++ b/src/tools/miri/tests/pass/both_borrows/smallvec.rs @@ -25,7 +25,7 @@ impl RawSmallVec { } const fn as_mut_ptr_inline(&mut self) -> *mut T { - (unsafe { &raw mut self.inline }) as *mut T + &raw mut self.inline as *mut T } const unsafe fn as_mut_ptr_heap(&mut self) -> *mut T { diff --git a/tests/ui/union/union-unsafe.rs b/tests/ui/union/union-unsafe.rs index bd3946686be36..bbf01ac14c74f 100644 --- a/tests/ui/union/union-unsafe.rs +++ b/tests/ui/union/union-unsafe.rs @@ -17,6 +17,10 @@ union U4 { a: T, } +union U5 { + a: usize, +} + union URef { p: &'static mut i32, } @@ -31,6 +35,11 @@ fn deref_union_field(mut u: URef) { *(u.p) = 13; //~ ERROR access to union field is unsafe } +fn raw_deref_union_field(mut u: URef) { + // This is unsafe because we first dereference u.p (reading uninitialized memory) + let _p = &raw const *(u.p); //~ ERROR access to union field is unsafe +} + fn assign_noncopy_union_field(mut u: URefCell) { u.a = (ManuallyDrop::new(RefCell::new(0)), 1); // OK (assignment does not drop) u.a.0 = ManuallyDrop::new(RefCell::new(0)); // OK (assignment does not drop) @@ -57,6 +66,20 @@ fn main() { let a = u1.a; //~ ERROR access to union field is unsafe u1.a = 11; // OK + let mut u2 = U1 { a: 10 }; + let a = &raw mut u2.a; // OK + unsafe { *a = 3 }; + + let mut u3 = U1 { a: 10 }; + let a = std::ptr::addr_of_mut!(u3.a); // OK + unsafe { *a = 14 }; + + let u4 = U5 { a: 2 }; + let vec = vec![1, 2, 3]; + // This is unsafe because we read u4.a (potentially uninitialized memory) + // to use as an array index + let _a = &raw const vec[u4.a]; //~ ERROR access to union field is unsafe + let U1 { a } = u1; //~ ERROR access to union field is unsafe if let U1 { a: 12 } = u1 {} //~ ERROR access to union field is unsafe if let Some(U1 { a: 13 }) = Some(u1) {} //~ ERROR access to union field is unsafe @@ -73,4 +96,38 @@ fn main() { let mut u3 = U3 { a: ManuallyDrop::new(String::from("old")) }; // OK u3.a = ManuallyDrop::new(String::from("new")); // OK (assignment does not drop) *u3.a = String::from("new"); //~ ERROR access to union field is unsafe + + let mut unions = [U1 { a: 1 }, U1 { a: 2 }]; + + // Array indexing + union field raw borrow - should be OK + let ptr = &raw mut unions[0].a; // OK + let ptr2 = &raw const unions[1].a; // OK + + // Test for union fields chain, this should be allowed + #[derive(Copy, Clone)] + union Inner { + a: u8, + } + + union MoreInner { + moreinner: ManuallyDrop, + } + + union LessOuter { + lessouter: ManuallyDrop, + } + + union Outer { + outer: ManuallyDrop, + } + + let super_outer = Outer { + outer: ManuallyDrop::new(LessOuter { + lessouter: ManuallyDrop::new(MoreInner { + moreinner: ManuallyDrop::new(Inner { a: 42 }), + }), + }), + }; + + let ptr = &raw const super_outer.outer.lessouter.moreinner.a; } diff --git a/tests/ui/union/union-unsafe.stderr b/tests/ui/union/union-unsafe.stderr index 82b3f897167c7..ced9be73022c1 100644 --- a/tests/ui/union/union-unsafe.stderr +++ b/tests/ui/union/union-unsafe.stderr @@ -1,5 +1,5 @@ error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:31:6 + --> $DIR/union-unsafe.rs:35:6 | LL | *(u.p) = 13; | ^^^^^ access to union field @@ -7,7 +7,15 @@ LL | *(u.p) = 13; = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:43:6 + --> $DIR/union-unsafe.rs:40:26 + | +LL | let _p = &raw const *(u.p); + | ^^^^^ access to union field + | + = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior + +error[E0133]: access to union field is unsafe and requires unsafe function or block + --> $DIR/union-unsafe.rs:52:6 | LL | *u3.a = T::default(); | ^^^^ access to union field @@ -15,7 +23,7 @@ LL | *u3.a = T::default(); = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:49:6 + --> $DIR/union-unsafe.rs:58:6 | LL | *u3.a = T::default(); | ^^^^ access to union field @@ -23,7 +31,7 @@ LL | *u3.a = T::default(); = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:57:13 + --> $DIR/union-unsafe.rs:66:13 | LL | let a = u1.a; | ^^^^ access to union field @@ -31,7 +39,15 @@ LL | let a = u1.a; = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:60:14 + --> $DIR/union-unsafe.rs:81:29 + | +LL | let _a = &raw const vec[u4.a]; + | ^^^^ access to union field + | + = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior + +error[E0133]: access to union field is unsafe and requires unsafe function or block + --> $DIR/union-unsafe.rs:83:14 | LL | let U1 { a } = u1; | ^ access to union field @@ -39,7 +55,7 @@ LL | let U1 { a } = u1; = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:61:20 + --> $DIR/union-unsafe.rs:84:20 | LL | if let U1 { a: 12 } = u1 {} | ^^ access to union field @@ -47,7 +63,7 @@ LL | if let U1 { a: 12 } = u1 {} = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:62:25 + --> $DIR/union-unsafe.rs:85:25 | LL | if let Some(U1 { a: 13 }) = Some(u1) {} | ^^ access to union field @@ -55,7 +71,7 @@ LL | if let Some(U1 { a: 13 }) = Some(u1) {} = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:67:6 + --> $DIR/union-unsafe.rs:90:6 | LL | *u2.a = String::from("new"); | ^^^^ access to union field @@ -63,7 +79,7 @@ LL | *u2.a = String::from("new"); = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:71:6 + --> $DIR/union-unsafe.rs:94:6 | LL | *u3.a = 1; | ^^^^ access to union field @@ -71,13 +87,13 @@ LL | *u3.a = 1; = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior error[E0133]: access to union field is unsafe and requires unsafe function or block - --> $DIR/union-unsafe.rs:75:6 + --> $DIR/union-unsafe.rs:98:6 | LL | *u3.a = String::from("new"); | ^^^^ access to union field | = note: the field may not be properly initialized: using uninitialized data will cause undefined behavior -error: aborting due to 10 previous errors +error: aborting due to 12 previous errors For more information about this error, try `rustc --explain E0133`.