Skip to content

Commit

Permalink
Provide suggestion to dereference closure tail if appropriate
Browse files Browse the repository at this point in the history
When encoutnering a case like

```rust
//@ run-rustfix
use std::collections::HashMap;

fn main() {
    let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3];

    let mut counts = HashMap::new();
    for num in vs {
        let count = counts.entry(num).or_insert(0);
        *count += 1;
    }

    let _ = counts.iter().max_by_key(|(_, v)| v);
```
produce the following suggestion
```
error: lifetime may not live long enough
  --> $DIR/return-value-lifetime-error.rs:13:47
   |
LL |     let _ = counts.iter().max_by_key(|(_, v)| v);
   |                                       ------- ^ returning this value requires that `'1` must outlive `'2`
   |                                       |     |
   |                                       |     return type of closure is &'2 &i32
   |                                       has type `&'1 (&i32, &i32)`
   |
help: dereference the return value
   |
LL |     let _ = counts.iter().max_by_key(|(_, v)| **v);
   |                                               ++
```

Fix #50195.
  • Loading branch information
estebank committed Mar 8, 2024
1 parent 735f758 commit f06c0a8
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3565,6 +3565,7 @@ dependencies = [
"rustc_fluent_macro",
"rustc_graphviz",
"rustc_hir",
"rustc_hir_typeck",
"rustc_index",
"rustc_infer",
"rustc_lexer",
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_borrowck/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rustc_errors = { path = "../rustc_errors" }
rustc_fluent_macro = { path = "../rustc_fluent_macro" }
rustc_graphviz = { path = "../rustc_graphviz" }
rustc_hir = { path = "../rustc_hir" }
rustc_hir_typeck = { path = "../rustc_hir_typeck" }
rustc_index = { path = "../rustc_index" }
rustc_infer = { path = "../rustc_infer" }
rustc_lexer = { path = "../rustc_lexer" }
Expand Down
189 changes: 189 additions & 0 deletions compiler/rustc_borrowck/src/diagnostics/region_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use rustc_hir::GenericBound::Trait;
use rustc_hir::QPath::Resolved;
use rustc_hir::WherePredicate::BoundPredicate;
use rustc_hir::{PolyTraitRef, TyKind, WhereBoundPredicate};
use rustc_hir_typeck::{FnCtxt, Inherited};
use rustc_infer::infer::{
error_reporting::nice_region_error::{
self, find_anon_type, find_param_with_region, suggest_adding_lifetime_params,
Expand All @@ -20,12 +21,17 @@ use rustc_infer::infer::{
};
use rustc_middle::hir::place::PlaceBase;
use rustc_middle::mir::{ConstraintCategory, ReturnConstraint};
use rustc_middle::traits::ObligationCause;
use rustc_middle::ty::GenericArgs;
use rustc_middle::ty::TypeVisitor;
use rustc_middle::ty::{self, RegionVid, Ty};
use rustc_middle::ty::{Region, TyCtxt};
use rustc_span::symbol::{kw, Ident};
use rustc_span::Span;
use rustc_trait_selection::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use rustc_trait_selection::infer::InferCtxtExt;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
use rustc_trait_selection::traits::Obligation;

use crate::borrowck_errors;
use crate::session_diagnostics::{
Expand Down Expand Up @@ -810,6 +816,7 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
self.add_static_impl_trait_suggestion(&mut diag, *fr, fr_name, *outlived_fr);
self.suggest_adding_lifetime_params(&mut diag, *fr, *outlived_fr);
self.suggest_move_on_borrowing_closure(&mut diag);
self.suggest_deref_closure_value(&mut diag);

diag
}
Expand Down Expand Up @@ -1039,6 +1046,188 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
suggest_adding_lifetime_params(self.infcx.tcx, sub, ty_sup, ty_sub, diag);
}

#[allow(rustc::diagnostic_outside_of_impl)]
#[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable
/// When encountering a lifetime error caused by the return type of a closure, check the
/// corresponding trait bound and see if dereferencing the closure return value would satisfy
/// them. If so, we produce a structured suggestion.
fn suggest_deref_closure_value(&self, diag: &mut Diag<'_>) {
let tcx = self.infcx.tcx;
let map = tcx.hir();

// Get the closure return value and type.
let body_id = map.body_owned_by(self.mir_def_id());
let body = &map.body(body_id);
let value = &body.value.peel_blocks();
let hir::Node::Expr(closure_expr) = tcx.hir_node_by_def_id(self.mir_def_id()) else {
return;
};
let fn_call_id = tcx.parent_hir_id(self.mir_hir_id());
let hir::Node::Expr(expr) = tcx.hir_node(fn_call_id) else { return };
let def_id = map.enclosing_body_owner(fn_call_id);
let tables = tcx.typeck(def_id);
let Some(return_value_ty) = tables.node_type_opt(value.hir_id) else { return };
let return_value_ty = self.infcx.resolve_vars_if_possible(return_value_ty);

// We don't use `ty.peel_refs()` to get the number of `*`s needed to get the root type.
let mut ty = return_value_ty;
let mut count = 0;
while let ty::Ref(_, t, _) = ty.kind() {
ty = *t;
count += 1;
}
if !self.infcx.type_is_copy_modulo_regions(self.param_env, ty) {
return;
}

// Build a new closure where the return type is an owned value, instead of a ref.
let Some(ty::Closure(did, args)) =
tables.node_type_opt(closure_expr.hir_id).as_ref().map(|ty| ty.kind())
else {
return;
};
let sig = args.as_closure().sig();
let closure_sig_as_fn_ptr_ty = Ty::new_fn_ptr(
tcx,
sig.map_bound(|s| {
let unsafety = hir::Unsafety::Normal;
use rustc_target::spec::abi;
tcx.mk_fn_sig(
[s.inputs()[0]],
s.output().peel_refs(),
s.c_variadic,
unsafety,
abi::Abi::Rust,
)
}),
);
let parent_args = GenericArgs::identity_for_item(
tcx,
tcx.typeck_root_def_id(self.mir_def_id().to_def_id()),
);
let closure_kind = args.as_closure().kind();
let closure_kind_ty = Ty::from_closure_kind(tcx, closure_kind);
let tupled_upvars_ty = self.infcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::ClosureSynthetic,
span: closure_expr.span,
});
let closure_args = ty::ClosureArgs::new(
tcx,
ty::ClosureArgsParts {
parent_args,
closure_kind_ty,
closure_sig_as_fn_ptr_ty,
tupled_upvars_ty,
},
);
let closure_ty = Ty::new_closure(tcx, *did, closure_args.args);
let closure_ty = tcx.erase_regions(closure_ty);

let hir::ExprKind::MethodCall(segment, rcvr, args, _) = expr.kind else { return };
let Some(pos) = args
.iter()
.enumerate()
.find(|(_, arg)| arg.hir_id == closure_expr.hir_id)
.map(|(i, _)| i)
else {
return;
};
// The found `Self` type of the method call.
let Some(possible_rcvr_ty) = tables.node_type_opt(rcvr.hir_id) else { return };

// The `MethodCall` expression is `Res::Err`, so we search for the method ion the `rcvr_ty`.
let inh = Inherited::new(tcx, self.mir_def_id());
let fn_ctxt = FnCtxt::new(&inh, self.param_env, self.mir_def_id());
let Ok(method) =
fn_ctxt.lookup_method_for_diagnostic(possible_rcvr_ty, segment, expr.span, expr, rcvr)
else {
return;
};

// Get the arguments for the found method, only specifying that `Self` is the receiver type.
let args = GenericArgs::for_item(tcx, method.def_id, |param, _| {
if param.index == 0 {
possible_rcvr_ty.into()
} else {
self.infcx.var_for_def(expr.span, param)
}
});

let preds = tcx.predicates_of(method.def_id).instantiate(tcx, args);
// Get the type for the parameter corresponding to the argument the closure with the
// lifetime error we had.
let Some(input) = tcx
.fn_sig(method.def_id)
.instantiate_identity()
.inputs()
.skip_binder()
// Methods have a `self` arg, so `pos` is actually `+ 1` to match the method call arg.
.get(pos + 1)
else {
return;
};

let cause = ObligationCause::misc(expr.span, self.mir_def_id());

enum CanSuggest {
Yes,
No,
Maybe,
}
let mut can_suggest = CanSuggest::Maybe;
for pred in preds.predicates {
match tcx.liberate_late_bound_regions(self.mir_def_id().into(), pred.kind()) {
ty::ClauseKind::Trait(pred)
if self.infcx.can_eq(self.param_env, pred.self_ty(), *input)
&& [
tcx.lang_items().fn_trait(),
tcx.lang_items().fn_mut_trait(),
tcx.lang_items().fn_once_trait(),
]
.contains(&Some(pred.def_id())) =>
{
// This predicate is an `Fn*` trait and corresponds to the argument with the
// closure that failed the lifetime check. We verify that the arguments will
// continue to match (which didn't change, so they should, and this be a no-op).
let pred = pred.with_self_ty(tcx, closure_ty);
let o = Obligation::new(tcx, cause.clone(), self.param_env, pred);
if !self.infcx.predicate_may_hold(&o) {
// The closure we have doesn't have the right arguments for the trait bound
can_suggest = CanSuggest::No;
} else if let CanSuggest::Maybe = can_suggest {
// The closure has the right arguments
can_suggest = CanSuggest::Yes;
}
}
ty::ClauseKind::Projection(proj)
if self.infcx.can_eq(self.param_env, proj.projection_ty.self_ty(), *input)
&& tcx.lang_items().fn_once_output() == Some(proj.projection_ty.def_id) =>
{
// Verify that `<[closure@...] as FnOnce>::Output` matches the expected
// `Output` from the trait bound.
let proj = proj.with_self_ty(tcx, closure_ty);
let o = Obligation::new(tcx, cause.clone(), self.param_env, proj);
if !self.infcx.predicate_may_hold(&o) {
// Return type doesn't match.
can_suggest = CanSuggest::No;
} else if let CanSuggest::Maybe = can_suggest {
// Return type matches, we can suggest dereferencing the closure's value.
can_suggest = CanSuggest::Yes;
}
}
_ => {}
}
}
if let CanSuggest::Yes = can_suggest {
diag.span_suggestion_verbose(
value.span.shrink_to_lo(),
"dereference the return value",
"*".repeat(count),
Applicability::MachineApplicable,
);
}
}

#[allow(rustc::diagnostic_outside_of_impl)]
#[allow(rustc::untranslatable_diagnostic)] // FIXME: make this translatable
fn suggest_move_on_borrowing_closure(&self, diag: &mut Diag<'_>) {
Expand Down
16 changes: 16 additions & 0 deletions tests/ui/closures/return-value-lifetime-error.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//@ run-rustfix
use std::collections::HashMap;

fn main() {
let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3];

let mut counts = HashMap::new();
for num in vs {
let count = counts.entry(num).or_insert(0);
*count += 1;
}

let _ = counts.iter().max_by_key(|(_, v)| **v);
//~^ ERROR lifetime may not live long enough
//~| HELP dereference the return value
}
16 changes: 16 additions & 0 deletions tests/ui/closures/return-value-lifetime-error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//@ run-rustfix
use std::collections::HashMap;

fn main() {
let vs = vec![0, 0, 1, 1, 3, 4, 5, 6, 3, 3, 3];

let mut counts = HashMap::new();
for num in vs {
let count = counts.entry(num).or_insert(0);
*count += 1;
}

let _ = counts.iter().max_by_key(|(_, v)| v);
//~^ ERROR lifetime may not live long enough
//~| HELP dereference the return value
}
16 changes: 16 additions & 0 deletions tests/ui/closures/return-value-lifetime-error.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
error: lifetime may not live long enough
--> $DIR/return-value-lifetime-error.rs:13:47
|
LL | let _ = counts.iter().max_by_key(|(_, v)| v);
| ------- ^ returning this value requires that `'1` must outlive `'2`
| | |
| | return type of closure is &'2 &i32
| has type `&'1 (&i32, &i32)`
|
help: dereference the return value
|
LL | let _ = counts.iter().max_by_key(|(_, v)| **v);
| ++

error: aborting due to 1 previous error

0 comments on commit f06c0a8

Please sign in to comment.