Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: MIR episode 2 #14232

Merged
merged 9 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions crates/hir-def/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use syntax::{ast, AstPtr, SyntaxNode, SyntaxNodePtr};
use crate::{
attr::Attrs,
db::DefDatabase,
expr::{dummy_expr_id, Expr, ExprId, Label, LabelId, Pat, PatId},
expr::{dummy_expr_id, Binding, BindingId, Expr, ExprId, Label, LabelId, Pat, PatId},
item_scope::BuiltinShadowMode,
macro_id_to_def_id,
nameres::DefMap,
Expand Down Expand Up @@ -270,6 +270,7 @@ pub struct Mark {
pub struct Body {
pub exprs: Arena<Expr>,
pub pats: Arena<Pat>,
pub bindings: Arena<Binding>,
pub or_pats: FxHashMap<PatId, Arc<[PatId]>>,
pub labels: Arena<Label>,
/// The patterns for the function's parameters. While the parameter types are
Expand Down Expand Up @@ -435,13 +436,24 @@ impl Body {
}

fn shrink_to_fit(&mut self) {
let Self { _c: _, body_expr: _, block_scopes, or_pats, exprs, labels, params, pats } = self;
let Self {
_c: _,
body_expr: _,
block_scopes,
or_pats,
exprs,
labels,
params,
pats,
bindings,
} = self;
block_scopes.shrink_to_fit();
or_pats.shrink_to_fit();
exprs.shrink_to_fit();
labels.shrink_to_fit();
params.shrink_to_fit();
pats.shrink_to_fit();
bindings.shrink_to_fit();
}
}

Expand All @@ -451,6 +463,7 @@ impl Default for Body {
body_expr: dummy_expr_id(),
exprs: Default::default(),
pats: Default::default(),
bindings: Default::default(),
or_pats: Default::default(),
labels: Default::default(),
params: Default::default(),
Expand Down Expand Up @@ -484,6 +497,14 @@ impl Index<LabelId> for Body {
}
}

impl Index<BindingId> for Body {
type Output = Binding;

fn index(&self, b: BindingId) -> &Binding {
&self.bindings[b]
}
}

// FIXME: Change `node_` prefix to something more reasonable.
// Perhaps `expr_syntax` and `expr_id`?
impl BodySourceMap {
Expand Down
117 changes: 80 additions & 37 deletions crates/hir-def/src/body/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use la_arena::Arena;
use once_cell::unsync::OnceCell;
use profile::Count;
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use syntax::{
ast::{
self, ArrayExprKind, AstChildren, HasArgList, HasLoopBody, HasName, LiteralKind,
Expand All @@ -30,9 +31,9 @@ use crate::{
builtin_type::{BuiltinFloat, BuiltinInt, BuiltinUint},
db::DefDatabase,
expr::{
dummy_expr_id, Array, BindingAnnotation, ClosureKind, Expr, ExprId, FloatTypeWrapper,
Label, LabelId, Literal, MatchArm, Movability, Pat, PatId, RecordFieldPat, RecordLitField,
Statement,
dummy_expr_id, Array, Binding, BindingAnnotation, BindingId, ClosureKind, Expr, ExprId,
FloatTypeWrapper, Label, LabelId, Literal, MatchArm, Movability, Pat, PatId,
RecordFieldPat, RecordLitField, Statement,
},
item_scope::BuiltinShadowMode,
path::{GenericArgs, Path},
Expand Down Expand Up @@ -87,6 +88,7 @@ pub(super) fn lower(
body: Body {
exprs: Arena::default(),
pats: Arena::default(),
bindings: Arena::default(),
labels: Arena::default(),
params: Vec::new(),
body_expr: dummy_expr_id(),
Expand Down Expand Up @@ -116,6 +118,22 @@ struct ExprCollector<'a> {
is_lowering_generator: bool,
}

#[derive(Debug, Default)]
struct BindingList {
map: FxHashMap<Name, BindingId>,
}

impl BindingList {
fn find(
&mut self,
ec: &mut ExprCollector<'_>,
name: Name,
mode: BindingAnnotation,
) -> BindingId {
*self.map.entry(name).or_insert_with_key(|n| ec.alloc_binding(n.clone(), mode))
}
}

impl ExprCollector<'_> {
fn collect(
mut self,
Expand All @@ -127,17 +145,16 @@ impl ExprCollector<'_> {
param_list.self_param().filter(|_| attr_enabled.next().unwrap_or(false))
{
let ptr = AstPtr::new(&self_param);
let param_pat = self.alloc_pat(
Pat::Bind {
name: name![self],
mode: BindingAnnotation::new(
self_param.mut_token().is_some() && self_param.amp_token().is_none(),
false,
),
subpat: None,
},
Either::Right(ptr),
let binding_id = self.alloc_binding(
name![self],
BindingAnnotation::new(
self_param.mut_token().is_some() && self_param.amp_token().is_none(),
false,
),
);
let param_pat =
self.alloc_pat(Pat::Bind { id: binding_id, subpat: None }, Either::Right(ptr));
self.add_definition_to_binding(binding_id, param_pat);
self.body.params.push(param_pat);
}

Expand Down Expand Up @@ -179,6 +196,9 @@ impl ExprCollector<'_> {
id
}

fn alloc_binding(&mut self, name: Name, mode: BindingAnnotation) -> BindingId {
self.body.bindings.alloc(Binding { name, mode, definitions: SmallVec::new() })
}
fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId {
let src = self.expander.to_source(ptr);
let id = self.make_pat(pat, src.clone());
Expand Down Expand Up @@ -804,7 +824,7 @@ impl ExprCollector<'_> {
}

fn collect_pat(&mut self, pat: ast::Pat) -> PatId {
let pat_id = self.collect_pat_(pat);
let pat_id = self.collect_pat_(pat, &mut BindingList::default());
for (_, pats) in self.name_to_pat_grouping.drain() {
let pats = Arc::<[_]>::from(pats);
self.body.or_pats.extend(pats.iter().map(|&pat| (pat, pats.clone())));
Expand All @@ -820,16 +840,18 @@ impl ExprCollector<'_> {
}
}

fn collect_pat_(&mut self, pat: ast::Pat) -> PatId {
fn collect_pat_(&mut self, pat: ast::Pat, binding_list: &mut BindingList) -> PatId {
let pattern = match &pat {
ast::Pat::IdentPat(bp) => {
let name = bp.name().map(|nr| nr.as_name()).unwrap_or_else(Name::missing);

let key = self.is_lowering_inside_or_pat.then(|| name.clone());
let annotation =
BindingAnnotation::new(bp.mut_token().is_some(), bp.ref_token().is_some());
let subpat = bp.pat().map(|subpat| self.collect_pat_(subpat));
let pattern = if annotation == BindingAnnotation::Unannotated && subpat.is_none() {
let subpat = bp.pat().map(|subpat| self.collect_pat_(subpat, binding_list));
let (binding, pattern) = if annotation == BindingAnnotation::Unannotated
&& subpat.is_none()
{
// This could also be a single-segment path pattern. To
// decide that, we need to try resolving the name.
let (resolved, _) = self.expander.def_map.resolve_path(
Expand All @@ -839,30 +861,37 @@ impl ExprCollector<'_> {
BuiltinShadowMode::Other,
);
match resolved.take_values() {
Some(ModuleDefId::ConstId(_)) => Pat::Path(name.into()),
Some(ModuleDefId::ConstId(_)) => (None, Pat::Path(name.into())),
Some(ModuleDefId::EnumVariantId(_)) => {
// this is only really valid for unit variants, but
// shadowing other enum variants with a pattern is
// an error anyway
Pat::Path(name.into())
(None, Pat::Path(name.into()))
}
Some(ModuleDefId::AdtId(AdtId::StructId(s)))
if self.db.struct_data(s).variant_data.kind() != StructKind::Record =>
{
// Funnily enough, record structs *can* be shadowed
// by pattern bindings (but unit or tuple structs
// can't).
Pat::Path(name.into())
(None, Pat::Path(name.into()))
}
// shadowing statics is an error as well, so we just ignore that case here
_ => Pat::Bind { name, mode: annotation, subpat },
_ => {
let id = binding_list.find(self, name, annotation);
(Some(id), Pat::Bind { id, subpat })
}
}
} else {
Pat::Bind { name, mode: annotation, subpat }
let id = binding_list.find(self, name, annotation);
(Some(id), Pat::Bind { id, subpat })
};

let ptr = AstPtr::new(&pat);
let pat = self.alloc_pat(pattern, Either::Left(ptr));
if let Some(binding_id) = binding {
self.add_definition_to_binding(binding_id, pat);
}
if let Some(key) = key {
self.name_to_pat_grouping.entry(key).or_default().push(pat);
}
Expand All @@ -871,11 +900,11 @@ impl ExprCollector<'_> {
ast::Pat::TupleStructPat(p) => {
let path =
p.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new);
let (args, ellipsis) = self.collect_tuple_pat(p.fields());
let (args, ellipsis) = self.collect_tuple_pat(p.fields(), binding_list);
Pat::TupleStruct { path, args, ellipsis }
}
ast::Pat::RefPat(p) => {
let pat = self.collect_pat_opt(p.pat());
let pat = self.collect_pat_opt_(p.pat(), binding_list);
let mutability = Mutability::from_mutable(p.mut_token().is_some());
Pat::Ref { pat, mutability }
}
Expand All @@ -886,12 +915,12 @@ impl ExprCollector<'_> {
}
ast::Pat::OrPat(p) => {
self.is_lowering_inside_or_pat = true;
let pats = p.pats().map(|p| self.collect_pat_(p)).collect();
let pats = p.pats().map(|p| self.collect_pat_(p, binding_list)).collect();
Pat::Or(pats)
}
ast::Pat::ParenPat(p) => return self.collect_pat_opt_(p.pat()),
ast::Pat::ParenPat(p) => return self.collect_pat_opt_(p.pat(), binding_list),
ast::Pat::TuplePat(p) => {
let (args, ellipsis) = self.collect_tuple_pat(p.fields());
let (args, ellipsis) = self.collect_tuple_pat(p.fields(), binding_list);
Pat::Tuple { args, ellipsis }
}
ast::Pat::WildcardPat(_) => Pat::Wild,
Expand All @@ -904,7 +933,7 @@ impl ExprCollector<'_> {
.fields()
.filter_map(|f| {
let ast_pat = f.pat()?;
let pat = self.collect_pat_(ast_pat);
let pat = self.collect_pat_(ast_pat, binding_list);
let name = f.field_name()?.as_name();
Some(RecordFieldPat { name, pat })
})
Expand All @@ -923,9 +952,15 @@ impl ExprCollector<'_> {

// FIXME properly handle `RestPat`
Pat::Slice {
prefix: prefix.into_iter().map(|p| self.collect_pat_(p)).collect(),
slice: slice.map(|p| self.collect_pat_(p)),
suffix: suffix.into_iter().map(|p| self.collect_pat_(p)).collect(),
prefix: prefix
.into_iter()
.map(|p| self.collect_pat_(p, binding_list))
.collect(),
slice: slice.map(|p| self.collect_pat_(p, binding_list)),
suffix: suffix
.into_iter()
.map(|p| self.collect_pat_(p, binding_list))
.collect(),
}
}
ast::Pat::LiteralPat(lit) => {
Expand All @@ -948,7 +983,7 @@ impl ExprCollector<'_> {
Pat::Missing
}
ast::Pat::BoxPat(boxpat) => {
let inner = self.collect_pat_opt_(boxpat.pat());
let inner = self.collect_pat_opt_(boxpat.pat(), binding_list);
Pat::Box { inner }
}
ast::Pat::ConstBlockPat(const_block_pat) => {
Expand All @@ -965,7 +1000,7 @@ impl ExprCollector<'_> {
let src = self.expander.to_source(Either::Left(AstPtr::new(&pat)));
let pat =
self.collect_macro_call(call, macro_ptr, true, |this, expanded_pat| {
this.collect_pat_opt_(expanded_pat)
this.collect_pat_opt_(expanded_pat, binding_list)
});
self.source_map.pat_map.insert(src, pat);
return pat;
Expand All @@ -979,21 +1014,25 @@ impl ExprCollector<'_> {
self.alloc_pat(pattern, Either::Left(ptr))
}

fn collect_pat_opt_(&mut self, pat: Option<ast::Pat>) -> PatId {
fn collect_pat_opt_(&mut self, pat: Option<ast::Pat>, binding_list: &mut BindingList) -> PatId {
match pat {
Some(pat) => self.collect_pat_(pat),
Some(pat) => self.collect_pat_(pat, binding_list),
None => self.missing_pat(),
}
}

fn collect_tuple_pat(&mut self, args: AstChildren<ast::Pat>) -> (Box<[PatId]>, Option<usize>) {
fn collect_tuple_pat(
&mut self,
args: AstChildren<ast::Pat>,
binding_list: &mut BindingList,
) -> (Box<[PatId]>, Option<usize>) {
// Find the location of the `..`, if there is one. Note that we do not
// consider the possibility of there being multiple `..` here.
let ellipsis = args.clone().position(|p| matches!(p, ast::Pat::RestPat(_)));
// We want to skip the `..` pattern here, since we account for it above.
let args = args
.filter(|p| !matches!(p, ast::Pat::RestPat(_)))
.map(|p| self.collect_pat_(p))
.map(|p| self.collect_pat_(p, binding_list))
.collect();

(args, ellipsis)
Expand Down Expand Up @@ -1022,6 +1061,10 @@ impl ExprCollector<'_> {
None => Some(()),
}
}

fn add_definition_to_binding(&mut self, binding_id: BindingId, pat_id: PatId) {
self.body.bindings[binding_id].definitions.push(pat_id);
}
}

impl From<ast::LiteralKind> for Literal {
Expand Down
23 changes: 14 additions & 9 deletions crates/hir-def/src/body/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::fmt::{self, Write};
use syntax::ast::HasName;

use crate::{
expr::{Array, BindingAnnotation, ClosureKind, Literal, Movability, Statement},
expr::{Array, BindingAnnotation, BindingId, ClosureKind, Literal, Movability, Statement},
pretty::{print_generic_args, print_path, print_type_ref},
type_ref::TypeRef,
};
Expand Down Expand Up @@ -524,14 +524,8 @@ impl<'a> Printer<'a> {
}
Pat::Path(path) => self.print_path(path),
Pat::Lit(expr) => self.print_expr(*expr),
Pat::Bind { mode, name, subpat } => {
let mode = match mode {
BindingAnnotation::Unannotated => "",
BindingAnnotation::Mutable => "mut ",
BindingAnnotation::Ref => "ref ",
BindingAnnotation::RefMut => "ref mut ",
};
w!(self, "{}{}", mode, name);
Pat::Bind { id, subpat } => {
self.print_binding(*id);
if let Some(pat) = subpat {
self.whitespace();
self.print_pat(*pat);
Expand Down Expand Up @@ -635,4 +629,15 @@ impl<'a> Printer<'a> {
fn print_path(&mut self, path: &Path) {
print_path(path, self).unwrap();
}

fn print_binding(&mut self, id: BindingId) {
let Binding { name, mode, .. } = &self.body.bindings[id];
let mode = match mode {
BindingAnnotation::Unannotated => "",
BindingAnnotation::Mutable => "mut ",
BindingAnnotation::Ref => "ref ",
BindingAnnotation::RefMut => "ref mut ",
};
w!(self, "{}{}", mode, name);
}
}
Loading