Skip to content

Commit

Permalink
Auto merge of #123939 - WaffleLapkin:never-fallback-unsafe-lint, r=co…
Browse files Browse the repository at this point in the history
…mpiler-errors

Add a lint against never type fallback affecting unsafe code

~~I'm not very happy with the code quality... `VecGraph` not allowing you to get predecessors is very annoying. This should work though, so there is that.~~ (ended up updating `VecGraph` to support getting predecessors)

~~First few commits are from #123934 #123980
  • Loading branch information
bors committed May 2, 2024
2 parents f92d49b + b562617 commit fcc06c8
Show file tree
Hide file tree
Showing 9 changed files with 574 additions and 21 deletions.
1 change: 1 addition & 0 deletions compiler/rustc_data_structures/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#![feature(lazy_cell)]
#![feature(lint_reasons)]
#![feature(macro_metavar_expr)]
#![feature(map_try_insert)]
#![feature(maybe_uninit_uninit_array)]
#![feature(min_specialization)]
#![feature(negative_impls)]
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_data_structures/src/unord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_macros::{Decodable_Generic, Encodable_Generic};
use std::collections::hash_map::OccupiedError;
use std::{
borrow::{Borrow, BorrowMut},
collections::hash_map::Entry,
Expand Down Expand Up @@ -469,6 +470,11 @@ impl<K: Eq + Hash, V> UnordMap<K, V> {
self.inner.insert(k, v)
}

#[inline]
pub fn try_insert(&mut self, k: K, v: V) -> Result<&mut V, OccupiedError<'_, K, V>> {
self.inner.try_insert(k, v)
}

#[inline]
pub fn contains_key<Q: ?Sized>(&self, k: &Q) -> bool
where
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_hir_typeck/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ hir_typeck_lossy_provenance_ptr2int =
hir_typeck_missing_parentheses_in_range = can't call method `{$method_name}` on type `{$ty_str}`
hir_typeck_never_type_fallback_flowing_into_unsafe_call = never type fallback affects this call to an `unsafe` function
.help = specify the type explicitly
hir_typeck_never_type_fallback_flowing_into_unsafe_deref = never type fallback affects this raw pointer dereference
.help = specify the type explicitly
hir_typeck_never_type_fallback_flowing_into_unsafe_method = never type fallback affects this call to an `unsafe` method
.help = specify the type explicitly
hir_typeck_never_type_fallback_flowing_into_unsafe_path = never type fallback affects this `unsafe` function
.help = specify the type explicitly
hir_typeck_never_type_fallback_flowing_into_unsafe_union_field = never type fallback affects this union access
.help = specify the type explicitly
hir_typeck_no_associated_item = no {$item_kind} named `{$item_name}` found for {$ty_prefix} `{$ty_str}`{$trait_missing_method ->
[true] {""}
*[other] {" "}in the current scope
Expand Down
20 changes: 19 additions & 1 deletion compiler/rustc_hir_typeck/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,25 @@ pub struct MissingParenthesesInRange {
pub add_missing_parentheses: Option<AddMissingParenthesesInRange>,
}

#[derive(LintDiagnostic)]
pub enum NeverTypeFallbackFlowingIntoUnsafe {
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_call)]
Call,
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_method)]
Method,
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_path)]
Path,
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_union_field)]
UnionField,
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_deref)]
Deref,
}

#[derive(Subdiagnostic)]
#[multipart_suggestion(
hir_typeck_add_missing_parentheses_in_range,
Expand Down Expand Up @@ -632,7 +651,6 @@ pub enum SuggestBoxingForReturnImplTrait {
ends: Vec<Span>,
},
}

#[derive(LintDiagnostic)]
#[diag(hir_typeck_dereferencing_mut_binding)]
pub struct DereferencingMutBinding {
Expand Down
245 changes: 226 additions & 19 deletions compiler/rustc_hir_typeck/src/fallback.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
use crate::FnCtxt;
use std::cell::OnceCell;

use crate::{errors, FnCtxt, TypeckRootCtxt};
use rustc_data_structures::{
graph::{self, iterate::DepthFirstSearch, vec_graph::VecGraph},
unord::{UnordBag, UnordMap, UnordSet},
};
use rustc_hir as hir;
use rustc_hir::intravisit::Visitor;
use rustc_hir::HirId;
use rustc_infer::infer::{DefineOpaqueTypes, InferOk};
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable};
use rustc_session::lint;
use rustc_span::DUMMY_SP;
use rustc_span::{def_id::LocalDefId, Span};

#[derive(Copy, Clone)]
pub enum DivergingFallbackBehavior {
Expand Down Expand Up @@ -335,6 +342,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
// reach a member of N. If so, it falls back to `()`. Else
// `!`.
let mut diverging_fallback = UnordMap::with_capacity(diverging_vids.len());
let unsafe_infer_vars = OnceCell::new();
for &diverging_vid in &diverging_vids {
let diverging_ty = Ty::new_var(self.tcx, diverging_vid);
let root_vid = self.root_var(diverging_vid);
Expand All @@ -354,11 +362,51 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
output: infer_var_infos.items().any(|info| info.output),
};

let mut fallback_to = |ty| {
let unsafe_infer_vars = unsafe_infer_vars.get_or_init(|| {
let unsafe_infer_vars = compute_unsafe_infer_vars(self.root_ctxt, self.body_id);
debug!(?unsafe_infer_vars);
unsafe_infer_vars
});

let affected_unsafe_infer_vars =
graph::depth_first_search_as_undirected(&coercion_graph, root_vid)
.filter_map(|x| unsafe_infer_vars.get(&x).copied())
.collect::<Vec<_>>();

for (hir_id, span, reason) in affected_unsafe_infer_vars {
self.tcx.emit_node_span_lint(
lint::builtin::NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE,
hir_id,
span,
match reason {
UnsafeUseReason::Call => {
errors::NeverTypeFallbackFlowingIntoUnsafe::Call
}
UnsafeUseReason::Method => {
errors::NeverTypeFallbackFlowingIntoUnsafe::Method
}
UnsafeUseReason::Path => {
errors::NeverTypeFallbackFlowingIntoUnsafe::Path
}
UnsafeUseReason::UnionField => {
errors::NeverTypeFallbackFlowingIntoUnsafe::UnionField
}
UnsafeUseReason::Deref => {
errors::NeverTypeFallbackFlowingIntoUnsafe::Deref
}
},
);
}

diverging_fallback.insert(diverging_ty, ty);
};

use DivergingFallbackBehavior::*;
match behavior {
FallbackToUnit => {
debug!("fallback to () - legacy: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
fallback_to(self.tcx.types.unit);
}
FallbackToNiko => {
if found_infer_var_info.self_in_trait && found_infer_var_info.output {
Expand Down Expand Up @@ -387,21 +435,21 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
// set, see the relationship finding module in
// compiler/rustc_trait_selection/src/traits/relationships.rs.
debug!("fallback to () - found trait and projection: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
fallback_to(self.tcx.types.unit);
} else if can_reach_non_diverging {
debug!("fallback to () - reached non-diverging: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
fallback_to(self.tcx.types.unit);
} else {
debug!("fallback to ! - all diverging: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.never);
fallback_to(self.tcx.types.never);
}
}
FallbackToNever => {
debug!(
"fallback to ! - `rustc_never_type_mode = \"fallback_to_never\")`: {:?}",
diverging_vid
);
diverging_fallback.insert(diverging_ty, self.tcx.types.never);
fallback_to(self.tcx.types.never);
}
NoFallback => {
debug!(
Expand All @@ -417,7 +465,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {

/// Returns a graph whose nodes are (unresolved) inference variables and where
/// an edge `?A -> ?B` indicates that the variable `?A` is coerced to `?B`.
fn create_coercion_graph(&self) -> VecGraph<ty::TyVid> {
fn create_coercion_graph(&self) -> VecGraph<ty::TyVid, true> {
let pending_obligations = self.fulfillment_cx.borrow_mut().pending_obligations();
debug!("create_coercion_graph: pending_obligations={:?}", pending_obligations);
let coercion_edges: Vec<(ty::TyVid, ty::TyVid)> = pending_obligations
Expand All @@ -436,17 +484,12 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
//
// In practice currently the two ways that this happens is
// coercion and subtyping.
let (a, b) = if let ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) = atom {
(a, b)
} else if let ty::PredicateKind::Subtype(ty::SubtypePredicate {
a_is_expected: _,
a,
b,
}) = atom
{
(a, b)
} else {
return None;
let (a, b) = match atom {
ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) => (a, b),
ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: _, a, b }) => {
(a, b)
}
_ => return None,
};

let a_vid = self.root_vid(a)?;
Expand All @@ -456,6 +499,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
.collect();
debug!("create_coercion_graph: coercion_edges={:?}", coercion_edges);
let num_ty_vars = self.num_ty_vars();

VecGraph::new(num_ty_vars, coercion_edges)
}

Expand All @@ -464,3 +508,166 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
Some(self.root_var(self.shallow_resolve(ty).ty_vid()?))
}
}

#[derive(Debug, Copy, Clone)]
pub(crate) enum UnsafeUseReason {
Call,
Method,
Path,
UnionField,
Deref,
}

/// Finds all type variables which are passed to an `unsafe` operation.
///
/// For example, for this function `f`:
/// ```ignore (demonstrative)
/// fn f() {
/// unsafe {
/// let x /* ?X */ = core::mem::zeroed();
/// // ^^^^^^^^^^^^^^^^^^^ -- hir_id, span, reason
///
/// let y = core::mem::zeroed::<Option<_ /* ?Y */>>();
/// // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- hir_id, span, reason
/// }
/// }
/// ```
///
/// `compute_unsafe_infer_vars` will return `{ id(?X) -> (hir_id, span, Call) }`
fn compute_unsafe_infer_vars<'a, 'tcx>(
root_ctxt: &'a TypeckRootCtxt<'tcx>,
body_id: LocalDefId,
) -> UnordMap<ty::TyVid, (HirId, Span, UnsafeUseReason)> {
let body_id =
root_ctxt.tcx.hir().maybe_body_owned_by(body_id).expect("body id must have an owner");
let body = root_ctxt.tcx.hir().body(body_id);
let mut res = UnordMap::default();

struct UnsafeInferVarsVisitor<'a, 'tcx, 'r> {
root_ctxt: &'a TypeckRootCtxt<'tcx>,
res: &'r mut UnordMap<ty::TyVid, (HirId, Span, UnsafeUseReason)>,
}

impl Visitor<'_> for UnsafeInferVarsVisitor<'_, '_, '_> {
fn visit_expr(&mut self, ex: &'_ hir::Expr<'_>) {
let typeck_results = self.root_ctxt.typeck_results.borrow();

match ex.kind {
hir::ExprKind::MethodCall(..) => {
if let Some(def_id) = typeck_results.type_dependent_def_id(ex.hir_id)
&& let method_ty = self.root_ctxt.tcx.type_of(def_id).instantiate_identity()
&& let sig = method_ty.fn_sig(self.root_ctxt.tcx)
&& let hir::Unsafety::Unsafe = sig.unsafety()
{
let mut collector = InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::Method),
res: self.res,
};

// Collect generic arguments (incl. `Self`) of the method
typeck_results
.node_args(ex.hir_id)
.types()
.for_each(|t| t.visit_with(&mut collector));
}
}

hir::ExprKind::Call(func, ..) => {
let func_ty = typeck_results.expr_ty(func);

if func_ty.is_fn()
&& let sig = func_ty.fn_sig(self.root_ctxt.tcx)
&& let hir::Unsafety::Unsafe = sig.unsafety()
{
let mut collector = InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::Call),
res: self.res,
};

// Try collecting generic arguments of the function.
// Note that we do this below for any paths (that don't have to be called),
// but there we do it with a different span/reason.
// This takes priority.
typeck_results
.node_args(func.hir_id)
.types()
.for_each(|t| t.visit_with(&mut collector));

// Also check the return type, for cases like `returns_unsafe_fn_ptr()()`
sig.output().visit_with(&mut collector);
}
}

// Check paths which refer to functions.
// We do this, instead of only checking `Call` to make sure the lint can't be
// avoided by storing unsafe function in a variable.
hir::ExprKind::Path(_) => {
let ty = typeck_results.expr_ty(ex);

// If this path refers to an unsafe function, collect inference variables which may affect it.
// `is_fn` excludes closures, but those can't be unsafe.
if ty.is_fn()
&& let sig = ty.fn_sig(self.root_ctxt.tcx)
&& let hir::Unsafety::Unsafe = sig.unsafety()
{
let mut collector = InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::Path),
res: self.res,
};

// Collect generic arguments of the function
typeck_results
.node_args(ex.hir_id)
.types()
.for_each(|t| t.visit_with(&mut collector));
}
}

hir::ExprKind::Unary(hir::UnOp::Deref, pointer) => {
if let ty::RawPtr(pointee, _) = typeck_results.expr_ty(pointer).kind() {
pointee.visit_with(&mut InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::Deref),
res: self.res,
});
}
}

hir::ExprKind::Field(base, _) => {
let base_ty = typeck_results.expr_ty(base);

if base_ty.is_union() {
typeck_results.expr_ty(ex).visit_with(&mut InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::UnionField),
res: self.res,
});
}
}

_ => (),
};

hir::intravisit::walk_expr(self, ex);
}
}

struct InferVarCollector<'r, V> {
value: V,
res: &'r mut UnordMap<ty::TyVid, V>,
}

impl<'tcx, V: Copy> ty::TypeVisitor<TyCtxt<'tcx>> for InferVarCollector<'_, V> {
fn visit_ty(&mut self, t: Ty<'tcx>) {
if let Some(vid) = t.ty_vid() {
_ = self.res.try_insert(vid, self.value);
} else {
t.super_visit_with(self)
}
}
}

UnsafeInferVarsVisitor { root_ctxt, res: &mut res }.visit_expr(&body.value);

debug!(?res, "collected the following unsafe vars for {body_id:?}");

res
}
Loading

0 comments on commit fcc06c8

Please sign in to comment.