Skip to content

Commit

Permalink
Don't worry about uncaptured contravariant lifetimes if they outlive …
Browse files Browse the repository at this point in the history
…a captured lifetime
  • Loading branch information
compiler-errors committed Sep 5, 2024
1 parent eb33b43 commit f8f4d50
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 21 deletions.
247 changes: 226 additions & 21 deletions compiler/rustc_lint/src/impl_trait_overcaptures.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
use rustc_data_structures::fx::FxIndexSet;
use std::cell::LazyCell;

use rustc_data_structures::fx::{FxHashMap, FxIndexMap, FxIndexSet};
use rustc_data_structures::unord::UnordSet;
use rustc_errors::{Applicability, LintDiagnostic};
use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_infer::infer::outlives::env::OutlivesEnvironment;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_macros::LintDiagnostic;
use rustc_middle::bug;
use rustc_middle::middle::resolve_bound_vars::ResolvedArg;
use rustc_middle::ty::relate::{
structurally_relate_consts, structurally_relate_tys, Relate, RelateResult, TypeRelation,
};
use rustc_middle::ty::{
self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor,
};
use rustc_middle::{bug, span_bug};
use rustc_session::lint::FutureIncompatibilityReason;
use rustc_session::{declare_lint, declare_lint_pass};
use rustc_span::edition::Edition;
use rustc_span::Span;
use rustc_span::{Span, Symbol};
use rustc_trait_selection::traits::outlives_bounds::InferCtxtExt;
use rustc_trait_selection::traits::ObligationCtxt;

use crate::{fluent_generated as fluent, LateContext, LateLintPass};

Expand Down Expand Up @@ -119,38 +128,86 @@ impl<'tcx> LateLintPass<'tcx> for ImplTraitOvercaptures {
}
}

#[derive(PartialEq, Eq, Hash, Debug, Copy, Clone)]
enum ParamKind {
// Early-bound var.
Early(Symbol, u32),
// Late-bound var on function, not within a binder. We can capture these.
Free(DefId, Symbol),
// Late-bound var in a binder. We can't capture these yet.
Late,
}

fn check_fn(tcx: TyCtxt<'_>, parent_def_id: LocalDefId) {
let sig = tcx.fn_sig(parent_def_id).instantiate_identity();

let mut in_scope_parameters = FxIndexSet::default();
let mut in_scope_parameters = FxIndexMap::default();
// Populate the in_scope_parameters list first with all of the generics in scope
let mut current_def_id = Some(parent_def_id.to_def_id());
while let Some(def_id) = current_def_id {
let generics = tcx.generics_of(def_id);
for param in &generics.own_params {
in_scope_parameters.insert(param.def_id);
in_scope_parameters.insert(param.def_id, ParamKind::Early(param.name, param.index));
}
current_def_id = generics.parent;
}

for bound_var in sig.bound_vars() {
let ty::BoundVariableKind::Region(ty::BoundRegionKind::BrNamed(def_id, name)) = bound_var
else {
span_bug!(tcx.def_span(parent_def_id), "unexpected non-lifetime binder on fn sig");
};

in_scope_parameters.insert(def_id, ParamKind::Free(def_id, name));
}

let sig = tcx.liberate_late_bound_regions(parent_def_id.to_def_id(), sig);

// Then visit the signature to walk through all the binders (incl. the late-bound
// vars on the function itself, which we need to count too).
sig.visit_with(&mut VisitOpaqueTypes {
tcx,
parent_def_id,
in_scope_parameters,
seen: Default::default(),
// Lazily compute these two, since they're likely a bit expensive.
variances: LazyCell::new(|| {
let mut functional_variances = FunctionalVariances {
tcx: tcx,
variances: FxHashMap::default(),
ambient_variance: ty::Covariant,
generics: tcx.generics_of(parent_def_id),
};
let _ = functional_variances.relate(sig, sig);
functional_variances.variances
}),
outlives_env: LazyCell::new(|| {
let param_env = tcx.param_env(parent_def_id);
let infcx = tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);
let assumed_wf_tys = ocx.assumed_wf_types(param_env, parent_def_id).unwrap_or_default();
let implied_bounds =
infcx.implied_bounds_tys_compat(param_env, parent_def_id, &assumed_wf_tys, false);
OutlivesEnvironment::with_bounds(param_env, implied_bounds)
}),
});
}

struct VisitOpaqueTypes<'tcx> {
struct VisitOpaqueTypes<'tcx, VarFn, OutlivesFn> {
tcx: TyCtxt<'tcx>,
parent_def_id: LocalDefId,
in_scope_parameters: FxIndexSet<DefId>,
in_scope_parameters: FxIndexMap<DefId, ParamKind>,
variances: LazyCell<FxHashMap<DefId, ty::Variance>, VarFn>,
outlives_env: LazyCell<OutlivesEnvironment<'tcx>, OutlivesFn>,
seen: FxIndexSet<LocalDefId>,
}

impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
impl<'tcx, VarFn, OutlivesFn> TypeVisitor<TyCtxt<'tcx>>
for VisitOpaqueTypes<'tcx, VarFn, OutlivesFn>
where
VarFn: FnOnce() -> FxHashMap<DefId, ty::Variance>,
OutlivesFn: FnOnce() -> OutlivesEnvironment<'tcx>,
{
fn visit_binder<T: TypeVisitable<TyCtxt<'tcx>>>(
&mut self,
t: &ty::Binder<'tcx, T>,
Expand All @@ -163,8 +220,8 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
ty::BoundVariableKind::Region(ty::BoundRegionKind::BrNamed(def_id, ..))
| ty::BoundVariableKind::Ty(ty::BoundTyKind::Param(def_id, _)) => {
added.push(def_id);
let unique = self.in_scope_parameters.insert(def_id);
assert!(unique);
let unique = self.in_scope_parameters.insert(def_id, ParamKind::Late);
assert_eq!(unique, None);
}
_ => {
self.tcx.dcx().span_delayed_bug(
Expand Down Expand Up @@ -209,6 +266,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
{
// Compute the set of args that are captured by the opaque...
let mut captured = FxIndexSet::default();
let mut captured_regions = FxIndexSet::default();
let variances = self.tcx.variances_of(opaque_def_id);
let mut current_def_id = Some(opaque_def_id.to_def_id());
while let Some(def_id) = current_def_id {
Expand All @@ -218,25 +276,60 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
if variances[param.index as usize] != ty::Invariant {
continue;
}

let arg = opaque_ty.args[param.index as usize];
// We need to turn all `ty::Param`/`ConstKind::Param` and
// `ReEarlyParam`/`ReBound` into def ids.
captured.insert(extract_def_id_from_arg(
self.tcx,
generics,
opaque_ty.args[param.index as usize],
));
captured.insert(extract_def_id_from_arg(self.tcx, generics, arg));

captured_regions.extend(arg.as_region());
}
current_def_id = generics.parent;
}

// Compute the set of in scope params that are not captured. Get their spans,
// since that's all we really care about them for emitting the diagnostic.
let uncaptured_spans: Vec<_> = self
let mut uncaptured_args: FxIndexSet<_> = self
.in_scope_parameters
.iter()
.filter(|def_id| !captured.contains(*def_id))
.map(|def_id| self.tcx.def_span(def_id))
.filter(|&(def_id, _)| !captured.contains(def_id))
.collect();

// These are args that we know are likely fine to "overcapture", since they can be
// contravariantly shortened to one of the already-captured lifetimes that they
// outlive.
let covariant_long_args: FxIndexSet<_> = uncaptured_args
.iter()
.copied()
.filter(|&(def_id, kind)| {
let Some(ty::Bivariant | ty::Contravariant) = self.variances.get(def_id) else {
return false;
};
let DefKind::LifetimeParam = self.tcx.def_kind(def_id) else {
return false;
};
let uncaptured = match *kind {
ParamKind::Early(name, index) => ty::Region::new_early_param(
self.tcx,
ty::EarlyParamRegion { name, index },
),
ParamKind::Free(def_id, name) => ty::Region::new_late_param(
self.tcx,
self.parent_def_id.to_def_id(),
ty::BoundRegionKind::BrNamed(def_id, name),
),
ParamKind::Late => return false,
};
// Does this region outlive any captured region?
captured_regions.iter().any(|r| {
self.outlives_env
.free_region_map()
.sub_free_regions(self.tcx, *r, uncaptured)
})
})
.collect();
// We don't care to warn on these args.
uncaptured_args.retain(|arg| !covariant_long_args.contains(arg));

let opaque_span = self.tcx.def_span(opaque_def_id);
let new_capture_rules =
Expand All @@ -246,7 +339,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
// `use<>` syntax on it, and we're < edition 2024, then warn the user.
if !new_capture_rules
&& !opaque.bounds.iter().any(|bound| matches!(bound, hir::GenericBound::Use(..)))
&& !uncaptured_spans.is_empty()
&& !uncaptured_args.is_empty()
{
let suggestion = if let Ok(snippet) =
self.tcx.sess.source_map().span_to_snippet(opaque_span)
Expand Down Expand Up @@ -274,6 +367,11 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
None
};

let uncaptured_spans: Vec<_> = uncaptured_args
.into_iter()
.map(|(def_id, _)| self.tcx.def_span(def_id))
.collect();

self.tcx.emit_node_span_lint(
IMPL_TRAIT_OVERCAPTURES,
self.tcx.local_def_id_to_hir_id(opaque_def_id),
Expand Down Expand Up @@ -327,7 +425,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
if self
.in_scope_parameters
.iter()
.all(|def_id| explicitly_captured.contains(def_id))
.all(|(def_id, _)| explicitly_captured.contains(def_id))
{
self.tcx.emit_node_span_lint(
IMPL_TRAIT_REDUNDANT_CAPTURES,
Expand Down Expand Up @@ -396,7 +494,11 @@ fn extract_def_id_from_arg<'tcx>(
ty::ReBound(
_,
ty::BoundRegion { kind: ty::BoundRegionKind::BrNamed(def_id, ..), .. },
) => def_id,
)
| ty::ReLateParam(ty::LateParamRegion {
scope: _,
bound_region: ty::BoundRegionKind::BrNamed(def_id, ..),
}) => def_id,
_ => unreachable!(),
},
ty::GenericArgKind::Type(ty) => {
Expand All @@ -413,3 +515,106 @@ fn extract_def_id_from_arg<'tcx>(
}
}
}

/// Computes the variances of regions that appear in the type, but considering
/// late-bound regions too, which don't have their variance computed usually.
///
/// Like generalization, this is a unary operation implemented on top of the binary
/// relation infrastructure, mostly because it's much easier to have the relation
/// track the variance for you, rather than having to do it yourself.
struct FunctionalVariances<'tcx> {
tcx: TyCtxt<'tcx>,
variances: FxHashMap<DefId, ty::Variance>,
ambient_variance: ty::Variance,
generics: &'tcx ty::Generics,
}

impl<'tcx> TypeRelation<TyCtxt<'tcx>> for FunctionalVariances<'tcx> {
fn cx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn relate_with_variance<T: ty::relate::Relate<TyCtxt<'tcx>>>(
&mut self,
variance: rustc_type_ir::Variance,
_: ty::VarianceDiagInfo<TyCtxt<'tcx>>,
a: T,
b: T,
) -> RelateResult<'tcx, T> {
let old_variance = self.ambient_variance;
self.ambient_variance = self.ambient_variance.xform(variance);
self.relate(a, b)?;
self.ambient_variance = old_variance;
Ok(a)
}

fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
structurally_relate_tys(self, a, b)?;
Ok(a)
}

fn regions(
&mut self,
a: ty::Region<'tcx>,
_: ty::Region<'tcx>,
) -> RelateResult<'tcx, ty::Region<'tcx>> {
let def_id = match *a {
ty::ReEarlyParam(ebr) => self.generics.region_param(ebr, self.tcx).def_id,
ty::ReBound(
_,
ty::BoundRegion { kind: ty::BoundRegionKind::BrNamed(def_id, ..), .. },
)
| ty::ReLateParam(ty::LateParamRegion {
scope: _,
bound_region: ty::BoundRegionKind::BrNamed(def_id, ..),
}) => def_id,
_ => {
return Ok(a);
}
};

if let Some(variance) = self.variances.get_mut(&def_id) {
*variance = unify(*variance, self.ambient_variance);
} else {
self.variances.insert(def_id, self.ambient_variance);
}

Ok(a)
}

fn consts(
&mut self,
a: ty::Const<'tcx>,
b: ty::Const<'tcx>,
) -> RelateResult<'tcx, ty::Const<'tcx>> {
structurally_relate_consts(self, a, b)?;
Ok(a)
}

fn binders<T>(
&mut self,
a: ty::Binder<'tcx, T>,
b: ty::Binder<'tcx, T>,
) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
where
T: Relate<TyCtxt<'tcx>>,
{
self.relate(a.skip_binder(), b.skip_binder())?;
Ok(a)
}
}

/// What is the variance that satisfies the two variances?
fn unify(a: ty::Variance, b: ty::Variance) -> ty::Variance {
match (a, b) {
// Bivariance is lattice bottom.
(ty::Bivariant, other) | (other, ty::Bivariant) => other,
// Invariant is lattice top.
(ty::Invariant, _) | (_, ty::Invariant) => ty::Invariant,
// If type is required to be covariant and contravariant, then it's invariant.
(ty::Contravariant, ty::Covariant) | (ty::Covariant, ty::Contravariant) => ty::Invariant,
// Otherwise, co + co = co, contra + contra = contra.
(ty::Contravariant, ty::Contravariant) => ty::Contravariant,
(ty::Covariant, ty::Covariant) => ty::Covariant,
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//@ check-pass

#![deny(impl_trait_overcaptures)]

struct Ctxt<'tcx>(&'tcx ());

// In `compute`, we don't care that we're "overcapturing" `'tcx`
// in edition 2024, because it can be shortened at the call site
// and we know it outlives `'_`.

impl<'tcx> Ctxt<'tcx> {
fn compute(&self) -> impl Sized + '_ {}
}

fn main() {}

0 comments on commit f8f4d50

Please sign in to comment.