Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a lint against never type fallback affecting unsafe code #123939

Merged
merged 7 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/rustc_data_structures/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#![feature(lazy_cell)]
#![feature(lint_reasons)]
#![feature(macro_metavar_expr)]
#![feature(map_try_insert)]
#![feature(maybe_uninit_uninit_array)]
#![feature(min_specialization)]
#![feature(negative_impls)]
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_data_structures/src/unord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use rustc_hash::{FxHashMap, FxHashSet};
use rustc_macros::{Decodable_Generic, Encodable_Generic};
use std::collections::hash_map::OccupiedError;
use std::{
borrow::{Borrow, BorrowMut},
collections::hash_map::Entry,
Expand Down Expand Up @@ -469,6 +470,11 @@ impl<K: Eq + Hash, V> UnordMap<K, V> {
self.inner.insert(k, v)
}

#[inline]
pub fn try_insert(&mut self, k: K, v: V) -> Result<&mut V, OccupiedError<'_, K, V>> {
self.inner.try_insert(k, v)
}

#[inline]
pub fn contains_key<Q: ?Sized>(&self, k: &Q) -> bool
where
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_hir_typeck/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ hir_typeck_lossy_provenance_ptr2int =

hir_typeck_missing_parentheses_in_range = can't call method `{$method_name}` on type `{$ty_str}`

hir_typeck_never_type_fallback_flowing_into_unsafe_call = never type fallback affects this call to an `unsafe` function
.help = specify the type explicitly
hir_typeck_never_type_fallback_flowing_into_unsafe_deref = never type fallback affects this raw pointer dereference
.help = specify the type explicitly
hir_typeck_never_type_fallback_flowing_into_unsafe_method = never type fallback affects this call to an `unsafe` method
.help = specify the type explicitly
hir_typeck_never_type_fallback_flowing_into_unsafe_path = never type fallback affects this `unsafe` function
.help = specify the type explicitly
hir_typeck_never_type_fallback_flowing_into_unsafe_union_field = never type fallback affects this union access
.help = specify the type explicitly

hir_typeck_no_associated_item = no {$item_kind} named `{$item_name}` found for {$ty_prefix} `{$ty_str}`{$trait_missing_method ->
[true] {""}
*[other] {" "}in the current scope
Expand Down
20 changes: 19 additions & 1 deletion compiler/rustc_hir_typeck/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,25 @@ pub struct MissingParenthesesInRange {
pub add_missing_parentheses: Option<AddMissingParenthesesInRange>,
}

#[derive(LintDiagnostic)]
pub enum NeverTypeFallbackFlowingIntoUnsafe {
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_call)]
Call,
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_method)]
Method,
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_path)]
Path,
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_union_field)]
UnionField,
#[help]
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_deref)]
Deref,
}

#[derive(Subdiagnostic)]
#[multipart_suggestion(
hir_typeck_add_missing_parentheses_in_range,
Expand Down Expand Up @@ -632,7 +651,6 @@ pub enum SuggestBoxingForReturnImplTrait {
ends: Vec<Span>,
},
}

#[derive(LintDiagnostic)]
#[diag(hir_typeck_dereferencing_mut_binding)]
pub struct DereferencingMutBinding {
Expand Down
245 changes: 226 additions & 19 deletions compiler/rustc_hir_typeck/src/fallback.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
use crate::FnCtxt;
use std::cell::OnceCell;

use crate::{errors, FnCtxt, TypeckRootCtxt};
use rustc_data_structures::{
graph::{self, iterate::DepthFirstSearch, vec_graph::VecGraph},
unord::{UnordBag, UnordMap, UnordSet},
};
use rustc_hir as hir;
use rustc_hir::intravisit::Visitor;
use rustc_hir::HirId;
use rustc_infer::infer::{DefineOpaqueTypes, InferOk};
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable};
use rustc_session::lint;
use rustc_span::DUMMY_SP;
use rustc_span::{def_id::LocalDefId, Span};

#[derive(Copy, Clone)]
pub enum DivergingFallbackBehavior {
Expand Down Expand Up @@ -335,6 +342,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
// reach a member of N. If so, it falls back to `()`. Else
// `!`.
let mut diverging_fallback = UnordMap::with_capacity(diverging_vids.len());
let unsafe_infer_vars = OnceCell::new();
for &diverging_vid in &diverging_vids {
let diverging_ty = Ty::new_var(self.tcx, diverging_vid);
let root_vid = self.root_var(diverging_vid);
Expand All @@ -354,11 +362,51 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
output: infer_var_infos.items().any(|info| info.output),
};

let mut fallback_to = |ty| {
let unsafe_infer_vars = unsafe_infer_vars.get_or_init(|| {
let unsafe_infer_vars = compute_unsafe_infer_vars(self.root_ctxt, self.body_id);
debug!(?unsafe_infer_vars);
unsafe_infer_vars
});

let affected_unsafe_infer_vars =
graph::depth_first_search_as_undirected(&coercion_graph, root_vid)
.filter_map(|x| unsafe_infer_vars.get(&x).copied())
.collect::<Vec<_>>();

for (hir_id, span, reason) in affected_unsafe_infer_vars {
self.tcx.emit_node_span_lint(
lint::builtin::NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE,
hir_id,
span,
match reason {
UnsafeUseReason::Call => {
errors::NeverTypeFallbackFlowingIntoUnsafe::Call
}
UnsafeUseReason::Method => {
errors::NeverTypeFallbackFlowingIntoUnsafe::Method
}
UnsafeUseReason::Path => {
errors::NeverTypeFallbackFlowingIntoUnsafe::Path
}
UnsafeUseReason::UnionField => {
errors::NeverTypeFallbackFlowingIntoUnsafe::UnionField
}
UnsafeUseReason::Deref => {
errors::NeverTypeFallbackFlowingIntoUnsafe::Deref
}
},
);
}

diverging_fallback.insert(diverging_ty, ty);
};

use DivergingFallbackBehavior::*;
match behavior {
FallbackToUnit => {
debug!("fallback to () - legacy: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
fallback_to(self.tcx.types.unit);
}
FallbackToNiko => {
if found_infer_var_info.self_in_trait && found_infer_var_info.output {
Expand Down Expand Up @@ -387,21 +435,21 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
// set, see the relationship finding module in
// compiler/rustc_trait_selection/src/traits/relationships.rs.
debug!("fallback to () - found trait and projection: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
fallback_to(self.tcx.types.unit);
} else if can_reach_non_diverging {
debug!("fallback to () - reached non-diverging: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
fallback_to(self.tcx.types.unit);
} else {
debug!("fallback to ! - all diverging: {:?}", diverging_vid);
diverging_fallback.insert(diverging_ty, self.tcx.types.never);
fallback_to(self.tcx.types.never);
}
}
FallbackToNever => {
debug!(
"fallback to ! - `rustc_never_type_mode = \"fallback_to_never\")`: {:?}",
diverging_vid
);
diverging_fallback.insert(diverging_ty, self.tcx.types.never);
fallback_to(self.tcx.types.never);
}
NoFallback => {
debug!(
Expand All @@ -417,7 +465,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {

/// Returns a graph whose nodes are (unresolved) inference variables and where
/// an edge `?A -> ?B` indicates that the variable `?A` is coerced to `?B`.
fn create_coercion_graph(&self) -> VecGraph<ty::TyVid> {
fn create_coercion_graph(&self) -> VecGraph<ty::TyVid, true> {
let pending_obligations = self.fulfillment_cx.borrow_mut().pending_obligations();
debug!("create_coercion_graph: pending_obligations={:?}", pending_obligations);
let coercion_edges: Vec<(ty::TyVid, ty::TyVid)> = pending_obligations
Expand All @@ -436,17 +484,12 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
//
// In practice currently the two ways that this happens is
// coercion and subtyping.
let (a, b) = if let ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) = atom {
(a, b)
} else if let ty::PredicateKind::Subtype(ty::SubtypePredicate {
a_is_expected: _,
a,
b,
}) = atom
{
(a, b)
} else {
return None;
let (a, b) = match atom {
ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) => (a, b),
ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: _, a, b }) => {
(a, b)
}
_ => return None,
};

let a_vid = self.root_vid(a)?;
Expand All @@ -456,6 +499,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
.collect();
debug!("create_coercion_graph: coercion_edges={:?}", coercion_edges);
let num_ty_vars = self.num_ty_vars();

VecGraph::new(num_ty_vars, coercion_edges)
}

Expand All @@ -464,3 +508,166 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
Some(self.root_var(self.shallow_resolve(ty).ty_vid()?))
}
}

#[derive(Debug, Copy, Clone)]
pub(crate) enum UnsafeUseReason {
Call,
Method,
Path,
UnionField,
Deref,
}

/// Finds all type variables which are passed to an `unsafe` operation.
///
/// For example, for this function `f`:
/// ```ignore (demonstrative)
/// fn f() {
/// unsafe {
/// let x /* ?X */ = core::mem::zeroed();
/// // ^^^^^^^^^^^^^^^^^^^ -- hir_id, span, reason
///
/// let y = core::mem::zeroed::<Option<_ /* ?Y */>>();
/// // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- hir_id, span, reason
/// }
/// }
/// ```
///
/// `compute_unsafe_infer_vars` will return `{ id(?X) -> (hir_id, span, Call) }`
fn compute_unsafe_infer_vars<'a, 'tcx>(
root_ctxt: &'a TypeckRootCtxt<'tcx>,
body_id: LocalDefId,
) -> UnordMap<ty::TyVid, (HirId, Span, UnsafeUseReason)> {
let body_id =
root_ctxt.tcx.hir().maybe_body_owned_by(body_id).expect("body id must have an owner");
let body = root_ctxt.tcx.hir().body(body_id);
let mut res = UnordMap::default();

struct UnsafeInferVarsVisitor<'a, 'tcx, 'r> {
root_ctxt: &'a TypeckRootCtxt<'tcx>,
res: &'r mut UnordMap<ty::TyVid, (HirId, Span, UnsafeUseReason)>,
}

impl Visitor<'_> for UnsafeInferVarsVisitor<'_, '_, '_> {
fn visit_expr(&mut self, ex: &'_ hir::Expr<'_>) {
let typeck_results = self.root_ctxt.typeck_results.borrow();

match ex.kind {
hir::ExprKind::MethodCall(..) => {
if let Some(def_id) = typeck_results.type_dependent_def_id(ex.hir_id)
&& let method_ty = self.root_ctxt.tcx.type_of(def_id).instantiate_identity()
&& let sig = method_ty.fn_sig(self.root_ctxt.tcx)
&& let hir::Unsafety::Unsafe = sig.unsafety()
{
let mut collector = InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::Method),
res: self.res,
};

// Collect generic arguments (incl. `Self`) of the method
typeck_results
.node_args(ex.hir_id)
.types()
.for_each(|t| t.visit_with(&mut collector));
}
}

hir::ExprKind::Call(func, ..) => {
let func_ty = typeck_results.expr_ty(func);

if func_ty.is_fn()
&& let sig = func_ty.fn_sig(self.root_ctxt.tcx)
&& let hir::Unsafety::Unsafe = sig.unsafety()
{
let mut collector = InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::Call),
res: self.res,
};

// Try collecting generic arguments of the function.
// Note that we do this below for any paths (that don't have to be called),
// but there we do it with a different span/reason.
// This takes priority.
typeck_results
.node_args(func.hir_id)
.types()
.for_each(|t| t.visit_with(&mut collector));

// Also check the return type, for cases like `returns_unsafe_fn_ptr()()`
sig.output().visit_with(&mut collector);
}
}

// Check paths which refer to functions.
// We do this, instead of only checking `Call` to make sure the lint can't be
// avoided by storing unsafe function in a variable.
hir::ExprKind::Path(_) => {
let ty = typeck_results.expr_ty(ex);

// If this path refers to an unsafe function, collect inference variables which may affect it.
// `is_fn` excludes closures, but those can't be unsafe.
if ty.is_fn()
&& let sig = ty.fn_sig(self.root_ctxt.tcx)
&& let hir::Unsafety::Unsafe = sig.unsafety()
{
let mut collector = InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::Path),
res: self.res,
};

// Collect generic arguments of the function
typeck_results
.node_args(ex.hir_id)
.types()
.for_each(|t| t.visit_with(&mut collector));
}
}

hir::ExprKind::Unary(hir::UnOp::Deref, pointer) => {
if let ty::RawPtr(pointee, _) = typeck_results.expr_ty(pointer).kind() {
pointee.visit_with(&mut InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::Deref),
res: self.res,
});
}
}

hir::ExprKind::Field(base, _) => {
let base_ty = typeck_results.expr_ty(base);

if base_ty.is_union() {
typeck_results.expr_ty(ex).visit_with(&mut InferVarCollector {
value: (ex.hir_id, ex.span, UnsafeUseReason::UnionField),
res: self.res,
});
}
}

_ => (),
};

hir::intravisit::walk_expr(self, ex);
}
}

struct InferVarCollector<'r, V> {
value: V,
res: &'r mut UnordMap<ty::TyVid, V>,
}

impl<'tcx, V: Copy> ty::TypeVisitor<TyCtxt<'tcx>> for InferVarCollector<'_, V> {
fn visit_ty(&mut self, t: Ty<'tcx>) {
if let Some(vid) = t.ty_vid() {
_ = self.res.try_insert(vid, self.value);
} else {
t.super_visit_with(self)
}
}
}

UnsafeInferVarsVisitor { root_ctxt, res: &mut res }.visit_expr(&body.value);

debug!(?res, "collected the following unsafe vars for {body_id:?}");

res
}
Loading
Loading