Skip to content

Commit

Permalink
Auto merge of #17940 - ChayimFriedman2:closure-to-fn, r=Veykril
Browse files Browse the repository at this point in the history
feat: Create an assist to convert closure to freestanding fn

The assist converts all captures to parameters.

Closes #17920.

This was more work than I though, since it has to handle a bunch of edge cases...

Based on #17941. Needs to merge it first.
  • Loading branch information
bors committed Aug 29, 2024
2 parents 4068232 + 0e4f4d3 commit 07a66c4
Show file tree
Hide file tree
Showing 6 changed files with 1,511 additions and 13 deletions.
113 changes: 101 additions & 12 deletions crates/hir-ty/src/infer/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use hir_expand::name::Name;
use intern::sym;
use rustc_hash::FxHashMap;
use smallvec::{smallvec, SmallVec};
use stdx::never;
use stdx::{format_to, never};
use syntax::utils::is_raw_identifier;

use crate::{
db::{HirDatabase, InternedClosure},
Expand Down Expand Up @@ -251,6 +252,11 @@ impl CapturedItem {
self.place.local
}

/// Returns whether this place has any field (aka. non-deref) projections.
pub fn has_field_projections(&self) -> bool {
self.place.projections.iter().any(|it| !matches!(it, ProjectionElem::Deref))
}

pub fn ty(&self, subst: &Substitution) -> Ty {
self.ty.clone().substitute(Interner, utils::ClosureSubst(subst).parent_subst())
}
Expand All @@ -263,6 +269,97 @@ impl CapturedItem {
self.span_stacks.iter().map(|stack| *stack.last().expect("empty span stack")).collect()
}

/// Converts the place to a name that can be inserted into source code.
pub fn place_to_name(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
let body = db.body(owner);
let mut result = body[self.place.local].name.unescaped().display(db.upcast()).to_string();
for proj in &self.place.projections {
match proj {
ProjectionElem::Deref => {}
ProjectionElem::Field(Either::Left(f)) => {
match &*f.parent.variant_data(db.upcast()) {
VariantData::Record(fields) => {
result.push('_');
result.push_str(fields[f.local_id].name.as_str())
}
VariantData::Tuple(fields) => {
let index = fields.iter().position(|it| it.0 == f.local_id);
if let Some(index) = index {
format_to!(result, "_{index}");
}
}
VariantData::Unit => {}
}
}
ProjectionElem::Field(Either::Right(f)) => format_to!(result, "_{}", f.index),
&ProjectionElem::ClosureField(field) => format_to!(result, "_{field}"),
ProjectionElem::Index(_)
| ProjectionElem::ConstantIndex { .. }
| ProjectionElem::Subslice { .. }
| ProjectionElem::OpaqueCast(_) => {
never!("Not happen in closure capture");
continue;
}
}
}
if is_raw_identifier(&result, db.crate_graph()[owner.module(db.upcast()).krate()].edition) {
result.insert_str(0, "r#");
}
result
}

pub fn display_place_source_code(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
let body = db.body(owner);
let krate = owner.krate(db.upcast());
let edition = db.crate_graph()[krate].edition;
let mut result = body[self.place.local].name.display(db.upcast(), edition).to_string();
for proj in &self.place.projections {
match proj {
// In source code autoderef kicks in.
ProjectionElem::Deref => {}
ProjectionElem::Field(Either::Left(f)) => {
let variant_data = f.parent.variant_data(db.upcast());
match &*variant_data {
VariantData::Record(fields) => format_to!(
result,
".{}",
fields[f.local_id].name.display(db.upcast(), edition)
),
VariantData::Tuple(fields) => format_to!(
result,
".{}",
fields.iter().position(|it| it.0 == f.local_id).unwrap_or_default()
),
VariantData::Unit => {}
}
}
ProjectionElem::Field(Either::Right(f)) => {
let field = f.index;
format_to!(result, ".{field}");
}
&ProjectionElem::ClosureField(field) => {
format_to!(result, ".{field}");
}
ProjectionElem::Index(_)
| ProjectionElem::ConstantIndex { .. }
| ProjectionElem::Subslice { .. }
| ProjectionElem::OpaqueCast(_) => {
never!("Not happen in closure capture");
continue;
}
}
}
let final_derefs_count = self
.place
.projections
.iter()
.rev()
.take_while(|proj| matches!(proj, ProjectionElem::Deref))
.count();
result.insert_str(0, &"*".repeat(final_derefs_count));
result
}

pub fn display_place(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
let body = db.body(owner);
let krate = owner.krate(db.upcast());
Expand Down Expand Up @@ -451,14 +548,6 @@ impl InferenceContext<'_> {
});
}

fn is_ref_span(&self, span: MirSpan) -> bool {
match span {
MirSpan::ExprId(expr) => matches!(self.body[expr], Expr::Ref { .. }),
MirSpan::BindingId(_) => true,
MirSpan::PatId(_) | MirSpan::SelfParam | MirSpan::Unknown => false,
}
}

fn truncate_capture_spans(&self, capture: &mut CapturedItemWithoutTy, mut truncate_to: usize) {
// The first span is the identifier, and it must always remain.
truncate_to += 1;
Expand All @@ -467,15 +556,15 @@ impl InferenceContext<'_> {
let mut actual_truncate_to = 0;
for &span in &*span_stack {
actual_truncate_to += 1;
if !self.is_ref_span(span) {
if !span.is_ref_span(self.body) {
remained -= 1;
if remained == 0 {
break;
}
}
}
if actual_truncate_to < span_stack.len()
&& self.is_ref_span(span_stack[actual_truncate_to])
&& span_stack[actual_truncate_to].is_ref_span(self.body)
{
// Include the ref operator if there is one, we will fix it later (in `strip_captures_ref_span()`) if it's incorrect.
actual_truncate_to += 1;
Expand Down Expand Up @@ -1147,7 +1236,7 @@ impl InferenceContext<'_> {
for capture in &mut captures {
if matches!(capture.kind, CaptureKind::ByValue) {
for span_stack in &mut capture.span_stacks {
if self.is_ref_span(span_stack[span_stack.len() - 1]) {
if span_stack[span_stack.len() - 1].is_ref_span(self.body) {
span_stack.truncate(span_stack.len() - 1);
}
}
Expand Down
17 changes: 16 additions & 1 deletion crates/hir-ty/src/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use base_db::CrateId;
use chalk_ir::Mutability;
use either::Either;
use hir_def::{
hir::{BindingId, Expr, ExprId, Ordering, PatId},
body::Body,
hir::{BindingAnnotation, BindingId, Expr, ExprId, Ordering, PatId},
DefWithBodyId, FieldId, StaticId, TupleFieldId, UnionId, VariantId,
};
use la_arena::{Arena, ArenaMap, Idx, RawIdx};
Expand Down Expand Up @@ -1174,6 +1175,20 @@ pub enum MirSpan {
Unknown,
}

impl MirSpan {
pub fn is_ref_span(&self, body: &Body) -> bool {
match *self {
MirSpan::ExprId(expr) => matches!(body[expr], Expr::Ref { .. }),
// FIXME: Figure out if this is correct wrt. match ergonomics.
MirSpan::BindingId(binding) => matches!(
body.bindings[binding].mode,
BindingAnnotation::Ref | BindingAnnotation::RefMut
),
MirSpan::PatId(_) | MirSpan::SelfParam | MirSpan::Unknown => false,
}
}
}

impl_from!(ExprId, PatId for MirSpan);

impl From<&ExprId> for MirSpan {
Expand Down
87 changes: 87 additions & 0 deletions crates/hir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ use hir_ty::{
use itertools::Itertools;
use nameres::diagnostics::DefDiagnosticKind;
use rustc_hash::FxHashSet;
use smallvec::SmallVec;
use span::{Edition, EditionedFileId, FileId, MacroCallId, SyntaxContextId};
use stdx::{impl_from, never};
use syntax::{
Expand Down Expand Up @@ -4113,6 +4114,15 @@ impl ClosureCapture {
Local { parent: self.owner, binding_id: self.capture.local() }
}

/// Returns whether this place has any field (aka. non-deref) projections.
pub fn has_field_projections(&self) -> bool {
self.capture.has_field_projections()
}

pub fn usages(&self) -> CaptureUsages {
CaptureUsages { parent: self.owner, spans: self.capture.spans() }
}

pub fn kind(&self) -> CaptureKind {
match self.capture.kind() {
hir_ty::CaptureKind::ByRef(
Expand All @@ -4128,6 +4138,15 @@ impl ClosureCapture {
}
}

/// Converts the place to a name that can be inserted into source code.
pub fn place_to_name(&self, db: &dyn HirDatabase) -> String {
self.capture.place_to_name(self.owner, db)
}

pub fn display_place_source_code(&self, db: &dyn HirDatabase) -> String {
self.capture.display_place_source_code(self.owner, db)
}

pub fn display_place(&self, db: &dyn HirDatabase) -> String {
self.capture.display_place(self.owner, db)
}
Expand All @@ -4141,6 +4160,74 @@ pub enum CaptureKind {
Move,
}

#[derive(Debug, Clone)]
pub struct CaptureUsages {
parent: DefWithBodyId,
spans: SmallVec<[mir::MirSpan; 3]>,
}

impl CaptureUsages {
pub fn sources(&self, db: &dyn HirDatabase) -> Vec<CaptureUsageSource> {
let (body, source_map) = db.body_with_source_map(self.parent);
let mut result = Vec::with_capacity(self.spans.len());
for &span in self.spans.iter() {
let is_ref = span.is_ref_span(&body);
match span {
mir::MirSpan::ExprId(expr) => {
if let Ok(expr) = source_map.expr_syntax(expr) {
result.push(CaptureUsageSource {
is_ref,
source: expr.map(AstPtr::wrap_left),
})
}
}
mir::MirSpan::PatId(pat) => {
if let Ok(pat) = source_map.pat_syntax(pat) {
result.push(CaptureUsageSource {
is_ref,
source: pat.map(AstPtr::wrap_right),
});
}
}
mir::MirSpan::BindingId(binding) => result.extend(
source_map
.patterns_for_binding(binding)
.iter()
.filter_map(|&pat| source_map.pat_syntax(pat).ok())
.map(|pat| CaptureUsageSource {
is_ref,
source: pat.map(AstPtr::wrap_right),
}),
),
mir::MirSpan::SelfParam | mir::MirSpan::Unknown => {
unreachable!("invalid capture usage span")
}
}
}
result
}
}

#[derive(Debug)]
pub struct CaptureUsageSource {
is_ref: bool,
source: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>,
}

impl CaptureUsageSource {
pub fn source(&self) -> AstPtr<Either<ast::Expr, ast::Pat>> {
self.source.value
}

pub fn file_id(&self) -> HirFileId {
self.source.file_id
}

pub fn is_ref(&self) -> bool {
self.is_ref
}
}

#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct Type {
env: Arc<TraitEnvironment>,
Expand Down
Loading

0 comments on commit 07a66c4

Please sign in to comment.