Skip to content

Commit

Permalink
Create an assist to convert closure to freestanding fn
Browse files Browse the repository at this point in the history
The assist converts all captures to parameters.
  • Loading branch information
ChayimFriedman2 committed Aug 22, 2024
1 parent a28eca6 commit 0f05961
Show file tree
Hide file tree
Showing 6 changed files with 1,475 additions and 12 deletions.
117 changes: 106 additions & 11 deletions crates/hir-ty/src/infer/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use intern::sym;
use rustc_hash::FxHashMap;
use smallvec::{smallvec, SmallVec};
use stdx::never;
use syntax::utils::is_raw_identifier;

use crate::{
db::{HirDatabase, InternedClosure},
Expand Down Expand Up @@ -242,6 +243,11 @@ impl CapturedItem {
self.place.local
}

/// Returns whether this place has any field (aka. non-deref) projections.
pub fn has_field_projections(&self) -> bool {
self.place.projections.iter().any(|it| !matches!(it, ProjectionElem::Deref))
}

pub fn ty(&self, subst: &Substitution) -> Ty {
self.ty.clone().substitute(Interner, utils::ClosureSubst(subst).parent_subst())
}
Expand All @@ -254,6 +260,103 @@ impl CapturedItem {
self.span_stacks.iter().map(|stack| *stack.last().expect("empty span stack")).collect()
}

/// Converts the place to a name that can be inserted into source code.
pub fn place_to_name(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
use std::fmt::Write;

let body = db.body(owner);
let mut result = body[self.place.local].name.unescaped().display(db.upcast()).to_string();
for proj in &self.place.projections {
match proj {
ProjectionElem::Deref => {}
ProjectionElem::Field(Either::Left(f)) => {
match &*f.parent.variant_data(db.upcast()) {
VariantData::Record(fields) => {
result.push('_');
result.push_str(fields[f.local_id].name.as_str())
}
VariantData::Tuple(fields) => {
let index = fields.iter().position(|it| it.0 == f.local_id);
if let Some(index) = index {
write!(result, "_{index}").unwrap();
}
}
VariantData::Unit => {}
}
}
ProjectionElem::Field(Either::Right(f)) => write!(result, "_{}", f.index).unwrap(),
&ProjectionElem::ClosureField(field) => write!(result, "_{field}").unwrap(),
ProjectionElem::Index(_)
| ProjectionElem::ConstantIndex { .. }
| ProjectionElem::Subslice { .. }
| ProjectionElem::OpaqueCast(_) => {
never!("Not happen in closure capture");
continue;
}
}
}
if is_raw_identifier(&result, db.crate_graph()[owner.module(db.upcast()).krate()].edition) {
result.insert_str(0, "r#");
}
result
}

pub fn display_place_source_code(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
use std::fmt::Write;

let body = db.body(owner);
let krate = owner.krate(db.upcast());
let edition = db.crate_graph()[krate].edition;
let mut result = body[self.place.local].name.display(db.upcast(), edition).to_string();
for proj in &self.place.projections {
match proj {
// In source code autoderef kicks in.
ProjectionElem::Deref => {}
ProjectionElem::Field(Either::Left(f)) => {
let variant_data = f.parent.variant_data(db.upcast());
match &*variant_data {
VariantData::Record(fields) => write!(
result,
".{}",
fields[f.local_id].name.display(db.upcast(), edition)
)
.unwrap(),
VariantData::Tuple(fields) => write!(
result,
".{}",
fields.iter().position(|it| it.0 == f.local_id).unwrap_or_default()
)
.unwrap(),
VariantData::Unit => {}
}
}
ProjectionElem::Field(Either::Right(f)) => {
let field = f.index;
write!(result, ".{field}").unwrap();
}
&ProjectionElem::ClosureField(field) => {
write!(result, ".{field}").unwrap();
}
ProjectionElem::Index(_)
| ProjectionElem::ConstantIndex { .. }
| ProjectionElem::Subslice { .. }
| ProjectionElem::OpaqueCast(_) => {
never!("Not happen in closure capture");
continue;
}
}
}
let final_derefs_count = self
.place
.projections
.iter()
.rev()
.take_while(|proj| matches!(proj, ProjectionElem::Deref))
.count();
result.insert_str(0, &"*".repeat(final_derefs_count));
result
}

pub fn display_place(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
let body = db.body(owner);
let krate = owner.krate(db.upcast());
Expand Down Expand Up @@ -442,14 +545,6 @@ impl InferenceContext<'_> {
});
}

fn is_ref_span(&self, span: MirSpan) -> bool {
match span {
MirSpan::ExprId(expr) => matches!(self.body[expr], Expr::Ref { .. }),
MirSpan::BindingId(_) => true,
MirSpan::PatId(_) | MirSpan::SelfParam | MirSpan::Unknown => false,
}
}

fn truncate_capture_spans(&self, capture: &mut CapturedItemWithoutTy, mut truncate_to: usize) {
// The first span is the identifier, and it must always remain.
truncate_to += 1;
Expand All @@ -458,15 +553,15 @@ impl InferenceContext<'_> {
let mut actual_truncate_to = 0;
for &span in &*span_stack {
actual_truncate_to += 1;
if !self.is_ref_span(span) {
if !span.is_ref_span(self.body) {
remained -= 1;
if remained == 0 {
break;
}
}
}
if actual_truncate_to < span_stack.len()
&& self.is_ref_span(span_stack[actual_truncate_to])
&& span_stack[actual_truncate_to].is_ref_span(self.body)
{
// Include the ref operator if there is one, we will fix it later (in `strip_captures_ref_span()`) if it's incorrect.
actual_truncate_to += 1;
Expand Down Expand Up @@ -1140,7 +1235,7 @@ impl InferenceContext<'_> {
for capture in &mut captures {
if matches!(capture.kind, CaptureKind::ByValue) {
for span_stack in &mut capture.span_stacks {
if self.is_ref_span(span_stack[span_stack.len() - 1]) {
if span_stack[span_stack.len() - 1].is_ref_span(self.body) {
span_stack.truncate(span_stack.len() - 1);
}
}
Expand Down
17 changes: 16 additions & 1 deletion crates/hir-ty/src/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use base_db::CrateId;
use chalk_ir::Mutability;
use either::Either;
use hir_def::{
hir::{BindingId, Expr, ExprId, Ordering, PatId},
body::Body,
hir::{BindingAnnotation, BindingId, Expr, ExprId, Ordering, PatId},
DefWithBodyId, FieldId, StaticId, TupleFieldId, UnionId, VariantId,
};
use la_arena::{Arena, ArenaMap, Idx, RawIdx};
Expand Down Expand Up @@ -1174,6 +1175,20 @@ pub enum MirSpan {
Unknown,
}

impl MirSpan {
pub fn is_ref_span(&self, body: &Body) -> bool {
match *self {
MirSpan::ExprId(expr) => matches!(body[expr], Expr::Ref { .. }),
// FIXME: Figure out if this is correct wrt. match ergonomics.
MirSpan::BindingId(binding) => matches!(
body.bindings[binding].mode,
BindingAnnotation::Ref | BindingAnnotation::RefMut
),
MirSpan::PatId(_) | MirSpan::SelfParam | MirSpan::Unknown => false,
}
}
}

impl_from!(ExprId, PatId for MirSpan);

impl From<&ExprId> for MirSpan {
Expand Down
87 changes: 87 additions & 0 deletions crates/hir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ use hir_ty::{
use itertools::Itertools;
use nameres::diagnostics::DefDiagnosticKind;
use rustc_hash::FxHashSet;
use smallvec::SmallVec;
use span::{Edition, EditionedFileId, FileId, MacroCallId, SyntaxContextId};
use stdx::{impl_from, never};
use syntax::{
Expand Down Expand Up @@ -4114,6 +4115,15 @@ impl ClosureCapture {
Local { parent: self.owner, binding_id: self.capture.local() }
}

/// Returns whether this place has any field (aka. non-deref) projections.
pub fn has_field_projections(&self) -> bool {
self.capture.has_field_projections()
}

pub fn usages(&self) -> CaptureUsages {
CaptureUsages { parent: self.owner, spans: self.capture.spans() }
}

pub fn kind(&self) -> CaptureKind {
match self.capture.kind() {
hir_ty::CaptureKind::ByRef(
Expand All @@ -4129,6 +4139,15 @@ impl ClosureCapture {
}
}

/// Converts the place to a name that can be inserted into source code.
pub fn place_to_name(&self, db: &dyn HirDatabase) -> String {
self.capture.place_to_name(self.owner, db)
}

pub fn display_place_source_code(&self, db: &dyn HirDatabase) -> String {
self.capture.display_place_source_code(self.owner, db)
}

pub fn display_place(&self, db: &dyn HirDatabase) -> String {
self.capture.display_place(self.owner, db)
}
Expand All @@ -4142,6 +4161,74 @@ pub enum CaptureKind {
Move,
}

#[derive(Debug, Clone)]
pub struct CaptureUsages {
parent: DefWithBodyId,
spans: SmallVec<[mir::MirSpan; 3]>,
}

impl CaptureUsages {
pub fn sources(&self, db: &dyn HirDatabase) -> Vec<CaptureUsageSource> {
let (body, source_map) = db.body_with_source_map(self.parent);
let mut result = Vec::with_capacity(self.spans.len());
for &span in self.spans.iter() {
let is_ref = span.is_ref_span(&body);
match span {
mir::MirSpan::ExprId(expr) => {
if let Ok(expr) = source_map.expr_syntax(expr) {
result.push(CaptureUsageSource {
is_ref,
source: expr.map(AstPtr::wrap_left),
})
}
}
mir::MirSpan::PatId(pat) => {
if let Ok(pat) = source_map.pat_syntax(pat) {
result.push(CaptureUsageSource {
is_ref,
source: pat.map(AstPtr::wrap_right),
});
}
}
mir::MirSpan::BindingId(binding) => result.extend(
source_map
.patterns_for_binding(binding)
.iter()
.filter_map(|&pat| source_map.pat_syntax(pat).ok())
.map(|pat| CaptureUsageSource {
is_ref,
source: pat.map(AstPtr::wrap_right),
}),
),
mir::MirSpan::SelfParam | mir::MirSpan::Unknown => {
unreachable!("invalid capture usage span")
}
}
}
result
}
}

#[derive(Debug)]
pub struct CaptureUsageSource {
is_ref: bool,
source: InFile<AstPtr<Either<ast::Expr, ast::Pat>>>,
}

impl CaptureUsageSource {
pub fn source(&self) -> AstPtr<Either<ast::Expr, ast::Pat>> {
self.source.value
}

pub fn file_id(&self) -> HirFileId {
self.source.file_id
}

pub fn is_ref(&self) -> bool {
self.is_ref
}
}

#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct Type {
env: Arc<TraitEnvironment>,
Expand Down
Loading

0 comments on commit 0f05961

Please sign in to comment.