Skip to content

Commit fcc06c8

Browse files
committed
Auto merge of #123939 - WaffleLapkin:never-fallback-unsafe-lint, r=compiler-errors
Add a lint against never type fallback affecting unsafe code ~~I'm not very happy with the code quality... `VecGraph` not allowing you to get predecessors is very annoying. This should work though, so there is that.~~ (ended up updating `VecGraph` to support getting predecessors) ~~First few commits are from #123934 #123980
2 parents f92d49b + b562617 commit fcc06c8

File tree

9 files changed

+574
-21
lines changed

9 files changed

+574
-21
lines changed

compiler/rustc_data_structures/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#![feature(lazy_cell)]
2626
#![feature(lint_reasons)]
2727
#![feature(macro_metavar_expr)]
28+
#![feature(map_try_insert)]
2829
#![feature(maybe_uninit_uninit_array)]
2930
#![feature(min_specialization)]
3031
#![feature(negative_impls)]

compiler/rustc_data_structures/src/unord.rs

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
55
use rustc_hash::{FxHashMap, FxHashSet};
66
use rustc_macros::{Decodable_Generic, Encodable_Generic};
7+
use std::collections::hash_map::OccupiedError;
78
use std::{
89
borrow::{Borrow, BorrowMut},
910
collections::hash_map::Entry,
@@ -469,6 +470,11 @@ impl<K: Eq + Hash, V> UnordMap<K, V> {
469470
self.inner.insert(k, v)
470471
}
471472

473+
#[inline]
474+
pub fn try_insert(&mut self, k: K, v: V) -> Result<&mut V, OccupiedError<'_, K, V>> {
475+
self.inner.try_insert(k, v)
476+
}
477+
472478
#[inline]
473479
pub fn contains_key<Q: ?Sized>(&self, k: &Q) -> bool
474480
where

compiler/rustc_hir_typeck/messages.ftl

+11
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,17 @@ hir_typeck_lossy_provenance_ptr2int =
9999
100100
hir_typeck_missing_parentheses_in_range = can't call method `{$method_name}` on type `{$ty_str}`
101101
102+
hir_typeck_never_type_fallback_flowing_into_unsafe_call = never type fallback affects this call to an `unsafe` function
103+
.help = specify the type explicitly
104+
hir_typeck_never_type_fallback_flowing_into_unsafe_deref = never type fallback affects this raw pointer dereference
105+
.help = specify the type explicitly
106+
hir_typeck_never_type_fallback_flowing_into_unsafe_method = never type fallback affects this call to an `unsafe` method
107+
.help = specify the type explicitly
108+
hir_typeck_never_type_fallback_flowing_into_unsafe_path = never type fallback affects this `unsafe` function
109+
.help = specify the type explicitly
110+
hir_typeck_never_type_fallback_flowing_into_unsafe_union_field = never type fallback affects this union access
111+
.help = specify the type explicitly
112+
102113
hir_typeck_no_associated_item = no {$item_kind} named `{$item_name}` found for {$ty_prefix} `{$ty_str}`{$trait_missing_method ->
103114
[true] {""}
104115
*[other] {" "}in the current scope

compiler/rustc_hir_typeck/src/errors.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,25 @@ pub struct MissingParenthesesInRange {
164164
pub add_missing_parentheses: Option<AddMissingParenthesesInRange>,
165165
}
166166

167+
#[derive(LintDiagnostic)]
168+
pub enum NeverTypeFallbackFlowingIntoUnsafe {
169+
#[help]
170+
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_call)]
171+
Call,
172+
#[help]
173+
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_method)]
174+
Method,
175+
#[help]
176+
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_path)]
177+
Path,
178+
#[help]
179+
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_union_field)]
180+
UnionField,
181+
#[help]
182+
#[diag(hir_typeck_never_type_fallback_flowing_into_unsafe_deref)]
183+
Deref,
184+
}
185+
167186
#[derive(Subdiagnostic)]
168187
#[multipart_suggestion(
169188
hir_typeck_add_missing_parentheses_in_range,
@@ -632,7 +651,6 @@ pub enum SuggestBoxingForReturnImplTrait {
632651
ends: Vec<Span>,
633652
},
634653
}
635-
636654
#[derive(LintDiagnostic)]
637655
#[diag(hir_typeck_dereferencing_mut_binding)]
638656
pub struct DereferencingMutBinding {

compiler/rustc_hir_typeck/src/fallback.rs

+226-19
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1-
use crate::FnCtxt;
1+
use std::cell::OnceCell;
2+
3+
use crate::{errors, FnCtxt, TypeckRootCtxt};
24
use rustc_data_structures::{
35
graph::{self, iterate::DepthFirstSearch, vec_graph::VecGraph},
46
unord::{UnordBag, UnordMap, UnordSet},
57
};
8+
use rustc_hir as hir;
9+
use rustc_hir::intravisit::Visitor;
10+
use rustc_hir::HirId;
611
use rustc_infer::infer::{DefineOpaqueTypes, InferOk};
7-
use rustc_middle::ty::{self, Ty};
12+
use rustc_middle::ty::{self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable};
13+
use rustc_session::lint;
814
use rustc_span::DUMMY_SP;
15+
use rustc_span::{def_id::LocalDefId, Span};
916

1017
#[derive(Copy, Clone)]
1118
pub enum DivergingFallbackBehavior {
@@ -335,6 +342,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
335342
// reach a member of N. If so, it falls back to `()`. Else
336343
// `!`.
337344
let mut diverging_fallback = UnordMap::with_capacity(diverging_vids.len());
345+
let unsafe_infer_vars = OnceCell::new();
338346
for &diverging_vid in &diverging_vids {
339347
let diverging_ty = Ty::new_var(self.tcx, diverging_vid);
340348
let root_vid = self.root_var(diverging_vid);
@@ -354,11 +362,51 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
354362
output: infer_var_infos.items().any(|info| info.output),
355363
};
356364

365+
let mut fallback_to = |ty| {
366+
let unsafe_infer_vars = unsafe_infer_vars.get_or_init(|| {
367+
let unsafe_infer_vars = compute_unsafe_infer_vars(self.root_ctxt, self.body_id);
368+
debug!(?unsafe_infer_vars);
369+
unsafe_infer_vars
370+
});
371+
372+
let affected_unsafe_infer_vars =
373+
graph::depth_first_search_as_undirected(&coercion_graph, root_vid)
374+
.filter_map(|x| unsafe_infer_vars.get(&x).copied())
375+
.collect::<Vec<_>>();
376+
377+
for (hir_id, span, reason) in affected_unsafe_infer_vars {
378+
self.tcx.emit_node_span_lint(
379+
lint::builtin::NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE,
380+
hir_id,
381+
span,
382+
match reason {
383+
UnsafeUseReason::Call => {
384+
errors::NeverTypeFallbackFlowingIntoUnsafe::Call
385+
}
386+
UnsafeUseReason::Method => {
387+
errors::NeverTypeFallbackFlowingIntoUnsafe::Method
388+
}
389+
UnsafeUseReason::Path => {
390+
errors::NeverTypeFallbackFlowingIntoUnsafe::Path
391+
}
392+
UnsafeUseReason::UnionField => {
393+
errors::NeverTypeFallbackFlowingIntoUnsafe::UnionField
394+
}
395+
UnsafeUseReason::Deref => {
396+
errors::NeverTypeFallbackFlowingIntoUnsafe::Deref
397+
}
398+
},
399+
);
400+
}
401+
402+
diverging_fallback.insert(diverging_ty, ty);
403+
};
404+
357405
use DivergingFallbackBehavior::*;
358406
match behavior {
359407
FallbackToUnit => {
360408
debug!("fallback to () - legacy: {:?}", diverging_vid);
361-
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
409+
fallback_to(self.tcx.types.unit);
362410
}
363411
FallbackToNiko => {
364412
if found_infer_var_info.self_in_trait && found_infer_var_info.output {
@@ -387,21 +435,21 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
387435
// set, see the relationship finding module in
388436
// compiler/rustc_trait_selection/src/traits/relationships.rs.
389437
debug!("fallback to () - found trait and projection: {:?}", diverging_vid);
390-
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
438+
fallback_to(self.tcx.types.unit);
391439
} else if can_reach_non_diverging {
392440
debug!("fallback to () - reached non-diverging: {:?}", diverging_vid);
393-
diverging_fallback.insert(diverging_ty, self.tcx.types.unit);
441+
fallback_to(self.tcx.types.unit);
394442
} else {
395443
debug!("fallback to ! - all diverging: {:?}", diverging_vid);
396-
diverging_fallback.insert(diverging_ty, self.tcx.types.never);
444+
fallback_to(self.tcx.types.never);
397445
}
398446
}
399447
FallbackToNever => {
400448
debug!(
401449
"fallback to ! - `rustc_never_type_mode = \"fallback_to_never\")`: {:?}",
402450
diverging_vid
403451
);
404-
diverging_fallback.insert(diverging_ty, self.tcx.types.never);
452+
fallback_to(self.tcx.types.never);
405453
}
406454
NoFallback => {
407455
debug!(
@@ -417,7 +465,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
417465

418466
/// Returns a graph whose nodes are (unresolved) inference variables and where
419467
/// an edge `?A -> ?B` indicates that the variable `?A` is coerced to `?B`.
420-
fn create_coercion_graph(&self) -> VecGraph<ty::TyVid> {
468+
fn create_coercion_graph(&self) -> VecGraph<ty::TyVid, true> {
421469
let pending_obligations = self.fulfillment_cx.borrow_mut().pending_obligations();
422470
debug!("create_coercion_graph: pending_obligations={:?}", pending_obligations);
423471
let coercion_edges: Vec<(ty::TyVid, ty::TyVid)> = pending_obligations
@@ -436,17 +484,12 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
436484
//
437485
// In practice currently the two ways that this happens is
438486
// coercion and subtyping.
439-
let (a, b) = if let ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) = atom {
440-
(a, b)
441-
} else if let ty::PredicateKind::Subtype(ty::SubtypePredicate {
442-
a_is_expected: _,
443-
a,
444-
b,
445-
}) = atom
446-
{
447-
(a, b)
448-
} else {
449-
return None;
487+
let (a, b) = match atom {
488+
ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) => (a, b),
489+
ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: _, a, b }) => {
490+
(a, b)
491+
}
492+
_ => return None,
450493
};
451494

452495
let a_vid = self.root_vid(a)?;
@@ -456,6 +499,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
456499
.collect();
457500
debug!("create_coercion_graph: coercion_edges={:?}", coercion_edges);
458501
let num_ty_vars = self.num_ty_vars();
502+
459503
VecGraph::new(num_ty_vars, coercion_edges)
460504
}
461505

@@ -464,3 +508,166 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
464508
Some(self.root_var(self.shallow_resolve(ty).ty_vid()?))
465509
}
466510
}
511+
512+
#[derive(Debug, Copy, Clone)]
513+
pub(crate) enum UnsafeUseReason {
514+
Call,
515+
Method,
516+
Path,
517+
UnionField,
518+
Deref,
519+
}
520+
521+
/// Finds all type variables which are passed to an `unsafe` operation.
522+
///
523+
/// For example, for this function `f`:
524+
/// ```ignore (demonstrative)
525+
/// fn f() {
526+
/// unsafe {
527+
/// let x /* ?X */ = core::mem::zeroed();
528+
/// // ^^^^^^^^^^^^^^^^^^^ -- hir_id, span, reason
529+
///
530+
/// let y = core::mem::zeroed::<Option<_ /* ?Y */>>();
531+
/// // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- hir_id, span, reason
532+
/// }
533+
/// }
534+
/// ```
535+
///
536+
/// `compute_unsafe_infer_vars` will return `{ id(?X) -> (hir_id, span, Call) }`
537+
fn compute_unsafe_infer_vars<'a, 'tcx>(
538+
root_ctxt: &'a TypeckRootCtxt<'tcx>,
539+
body_id: LocalDefId,
540+
) -> UnordMap<ty::TyVid, (HirId, Span, UnsafeUseReason)> {
541+
let body_id =
542+
root_ctxt.tcx.hir().maybe_body_owned_by(body_id).expect("body id must have an owner");
543+
let body = root_ctxt.tcx.hir().body(body_id);
544+
let mut res = UnordMap::default();
545+
546+
struct UnsafeInferVarsVisitor<'a, 'tcx, 'r> {
547+
root_ctxt: &'a TypeckRootCtxt<'tcx>,
548+
res: &'r mut UnordMap<ty::TyVid, (HirId, Span, UnsafeUseReason)>,
549+
}
550+
551+
impl Visitor<'_> for UnsafeInferVarsVisitor<'_, '_, '_> {
552+
fn visit_expr(&mut self, ex: &'_ hir::Expr<'_>) {
553+
let typeck_results = self.root_ctxt.typeck_results.borrow();
554+
555+
match ex.kind {
556+
hir::ExprKind::MethodCall(..) => {
557+
if let Some(def_id) = typeck_results.type_dependent_def_id(ex.hir_id)
558+
&& let method_ty = self.root_ctxt.tcx.type_of(def_id).instantiate_identity()
559+
&& let sig = method_ty.fn_sig(self.root_ctxt.tcx)
560+
&& let hir::Unsafety::Unsafe = sig.unsafety()
561+
{
562+
let mut collector = InferVarCollector {
563+
value: (ex.hir_id, ex.span, UnsafeUseReason::Method),
564+
res: self.res,
565+
};
566+
567+
// Collect generic arguments (incl. `Self`) of the method
568+
typeck_results
569+
.node_args(ex.hir_id)
570+
.types()
571+
.for_each(|t| t.visit_with(&mut collector));
572+
}
573+
}
574+
575+
hir::ExprKind::Call(func, ..) => {
576+
let func_ty = typeck_results.expr_ty(func);
577+
578+
if func_ty.is_fn()
579+
&& let sig = func_ty.fn_sig(self.root_ctxt.tcx)
580+
&& let hir::Unsafety::Unsafe = sig.unsafety()
581+
{
582+
let mut collector = InferVarCollector {
583+
value: (ex.hir_id, ex.span, UnsafeUseReason::Call),
584+
res: self.res,
585+
};
586+
587+
// Try collecting generic arguments of the function.
588+
// Note that we do this below for any paths (that don't have to be called),
589+
// but there we do it with a different span/reason.
590+
// This takes priority.
591+
typeck_results
592+
.node_args(func.hir_id)
593+
.types()
594+
.for_each(|t| t.visit_with(&mut collector));
595+
596+
// Also check the return type, for cases like `returns_unsafe_fn_ptr()()`
597+
sig.output().visit_with(&mut collector);
598+
}
599+
}
600+
601+
// Check paths which refer to functions.
602+
// We do this, instead of only checking `Call` to make sure the lint can't be
603+
// avoided by storing unsafe function in a variable.
604+
hir::ExprKind::Path(_) => {
605+
let ty = typeck_results.expr_ty(ex);
606+
607+
// If this path refers to an unsafe function, collect inference variables which may affect it.
608+
// `is_fn` excludes closures, but those can't be unsafe.
609+
if ty.is_fn()
610+
&& let sig = ty.fn_sig(self.root_ctxt.tcx)
611+
&& let hir::Unsafety::Unsafe = sig.unsafety()
612+
{
613+
let mut collector = InferVarCollector {
614+
value: (ex.hir_id, ex.span, UnsafeUseReason::Path),
615+
res: self.res,
616+
};
617+
618+
// Collect generic arguments of the function
619+
typeck_results
620+
.node_args(ex.hir_id)
621+
.types()
622+
.for_each(|t| t.visit_with(&mut collector));
623+
}
624+
}
625+
626+
hir::ExprKind::Unary(hir::UnOp::Deref, pointer) => {
627+
if let ty::RawPtr(pointee, _) = typeck_results.expr_ty(pointer).kind() {
628+
pointee.visit_with(&mut InferVarCollector {
629+
value: (ex.hir_id, ex.span, UnsafeUseReason::Deref),
630+
res: self.res,
631+
});
632+
}
633+
}
634+
635+
hir::ExprKind::Field(base, _) => {
636+
let base_ty = typeck_results.expr_ty(base);
637+
638+
if base_ty.is_union() {
639+
typeck_results.expr_ty(ex).visit_with(&mut InferVarCollector {
640+
value: (ex.hir_id, ex.span, UnsafeUseReason::UnionField),
641+
res: self.res,
642+
});
643+
}
644+
}
645+
646+
_ => (),
647+
};
648+
649+
hir::intravisit::walk_expr(self, ex);
650+
}
651+
}
652+
653+
struct InferVarCollector<'r, V> {
654+
value: V,
655+
res: &'r mut UnordMap<ty::TyVid, V>,
656+
}
657+
658+
impl<'tcx, V: Copy> ty::TypeVisitor<TyCtxt<'tcx>> for InferVarCollector<'_, V> {
659+
fn visit_ty(&mut self, t: Ty<'tcx>) {
660+
if let Some(vid) = t.ty_vid() {
661+
_ = self.res.try_insert(vid, self.value);
662+
} else {
663+
t.super_visit_with(self)
664+
}
665+
}
666+
}
667+
668+
UnsafeInferVarsVisitor { root_ctxt, res: &mut res }.visit_expr(&body.value);
669+
670+
debug!(?res, "collected the following unsafe vars for {body_id:?}");
671+
672+
res
673+
}

0 commit comments

Comments
 (0)