Skip to content

Commit

Permalink
Generalize some inference functions for patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
lowr committed May 31, 2022
1 parent c1c8675 commit 62d6b5a
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 64 deletions.
25 changes: 25 additions & 0 deletions crates/hir-ty/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,31 @@ impl Default for BindingMode {
}
}

/// Used to generalize patterns and assignee expressions.
trait PatLike: Into<ExprOrPatId> + Copy {
type BindingMode: Copy;

fn infer(
this: &mut InferenceContext,
id: Self,
expected_ty: &Ty,
default_bm: Self::BindingMode,
) -> Ty;
}

impl PatLike for PatId {
type BindingMode = BindingMode;

fn infer(
this: &mut InferenceContext,
id: Self,
expected_ty: &Ty,
default_bm: Self::BindingMode,
) -> Ty {
this.infer_pat(id, expected_ty, default_bm)
}
}

#[derive(Debug)]
pub(crate) struct InferOk<T> {
value: T,
Expand Down
146 changes: 82 additions & 64 deletions crates/hir-ty/src/infer/pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::iter::repeat_with;

use chalk_ir::Mutability;
use hir_def::{
expr::{BindingAnnotation, Expr, Literal, Pat, PatId, RecordFieldPat},
expr::{BindingAnnotation, Expr, Literal, Pat, PatId},
path::Path,
type_ref::ConstScalar,
};
Expand All @@ -17,15 +17,20 @@ use crate::{
TyKind,
};

use super::PatLike;

impl<'a> InferenceContext<'a> {
fn infer_tuple_struct_pat(
/// Infers type for tuple struct pattern or its corresponding assignee expression.
///
/// Ellipses found in the original pattern or expression must be filtered out.
pub(super) fn infer_tuple_struct_pat_like<T: PatLike>(
&mut self,
path: Option<&Path>,
subpats: &[PatId],
expected: &Ty,
default_bm: BindingMode,
id: PatId,
default_bm: T::BindingMode,
id: T,
ellipsis: Option<usize>,
subs: &[T],
) -> Ty {
let (ty, def) = self.resolve_variant(path, true);
let var_data = def.map(|it| it.variant_data(self.db.upcast()));
Expand All @@ -39,8 +44,8 @@ impl<'a> InferenceContext<'a> {

let field_tys = def.map(|it| self.db.field_types(it)).unwrap_or_default();
let (pre, post) = match ellipsis {
Some(idx) => subpats.split_at(idx),
None => (subpats, &[][..]),
Some(idx) => subs.split_at(idx),
None => (subs, &[][..]),
};
let post_idx_offset = field_tys.iter().count().saturating_sub(post.len());

Expand All @@ -54,22 +59,22 @@ impl<'a> InferenceContext<'a> {
field_tys[field].clone().substitute(Interner, &substs)
});
let expected_ty = self.normalize_associated_types_in(expected_ty);
self.infer_pat(subpat, &expected_ty, default_bm);
T::infer(self, subpat, &expected_ty, default_bm);
}

ty
}

fn infer_record_pat(
/// Infers type for record pattern or its corresponding assignee expression.
pub(super) fn infer_record_pat_like<T: PatLike>(
&mut self,
path: Option<&Path>,
subpats: &[RecordFieldPat],
expected: &Ty,
default_bm: BindingMode,
id: PatId,
default_bm: T::BindingMode,
id: T,
subs: impl Iterator<Item = (Name, T)>,
) -> Ty {
let (ty, def) = self.resolve_variant(path, false);
let var_data = def.map(|it| it.variant_data(self.db.upcast()));
if let Some(variant) = def {
self.write_variant_resolution(id.into(), variant);
}
Expand All @@ -80,18 +85,64 @@ impl<'a> InferenceContext<'a> {
ty.as_adt().map(|(_, s)| s.clone()).unwrap_or_else(|| Substitution::empty(Interner));

let field_tys = def.map(|it| self.db.field_types(it)).unwrap_or_default();
for subpat in subpats {
let matching_field = var_data.as_ref().and_then(|it| it.field(&subpat.name));
let expected_ty = matching_field.map_or(self.err_ty(), |field| {
field_tys[field].clone().substitute(Interner, &substs)
});
let var_data = def.map(|it| it.variant_data(self.db.upcast()));

for (name, inner) in subs {
let expected_ty = var_data
.as_ref()
.and_then(|it| it.field(&name))
.map_or(self.err_ty(), |f| field_tys[f].clone().substitute(Interner, &substs));
let expected_ty = self.normalize_associated_types_in(expected_ty);
self.infer_pat(subpat.pat, &expected_ty, default_bm);

T::infer(self, inner, &expected_ty, default_bm);
}

ty
}

/// Infers type for tuple pattern or its corresponding assignee expression.
///
/// Ellipses found in the original pattern or expression must be filtered out.
pub(super) fn infer_tuple_pat_like<T: PatLike>(
&mut self,
expected: &Ty,
default_bm: T::BindingMode,
ellipsis: Option<usize>,
subs: &[T],
) -> Ty {
let expectations = match expected.as_tuple() {
Some(parameters) => &*parameters.as_slice(Interner),
_ => &[],
};

let ((pre, post), n_uncovered_patterns) = match ellipsis {
Some(idx) => (subs.split_at(idx), expectations.len().saturating_sub(subs.len())),
None => ((&subs[..], &[][..]), 0),
};
let mut expectations_iter = expectations
.iter()
.cloned()
.map(|a| a.assert_ty_ref(Interner).clone())
.chain(repeat_with(|| self.table.new_type_var()));

let mut inner_tys = Vec::with_capacity(n_uncovered_patterns + subs.len());

inner_tys.extend(expectations_iter.by_ref().take(n_uncovered_patterns + subs.len()));

// Process pre
for (ty, pat) in inner_tys.iter_mut().zip(pre) {
*ty = T::infer(self, *pat, ty, default_bm);
}

// Process post
for (ty, pat) in inner_tys.iter_mut().skip(pre.len() + n_uncovered_patterns).zip(post) {
*ty = T::infer(self, *pat, ty, default_bm);
}

TyKind::Tuple(inner_tys.len(), Substitution::from_iter(Interner, inner_tys))
.intern(Interner)
}

pub(super) fn infer_pat(
&mut self,
pat: PatId,
Expand Down Expand Up @@ -129,42 +180,7 @@ impl<'a> InferenceContext<'a> {

let ty = match &self.body[pat] {
Pat::Tuple { args, ellipsis } => {
let expectations = match expected.as_tuple() {
Some(parameters) => &*parameters.as_slice(Interner),
_ => &[],
};

let ((pre, post), n_uncovered_patterns) = match ellipsis {
Some(idx) => {
(args.split_at(*idx), expectations.len().saturating_sub(args.len()))
}
None => ((&args[..], &[][..]), 0),
};
let mut expectations_iter = expectations
.iter()
.cloned()
.map(|a| a.assert_ty_ref(Interner).clone())
.chain(repeat_with(|| self.table.new_type_var()));

let mut inner_tys = Vec::with_capacity(n_uncovered_patterns + args.len());

inner_tys
.extend(expectations_iter.by_ref().take(n_uncovered_patterns + args.len()));

// Process pre
for (ty, pat) in inner_tys.iter_mut().zip(pre) {
*ty = self.infer_pat(*pat, ty, default_bm);
}

// Process post
for (ty, pat) in
inner_tys.iter_mut().skip(pre.len() + n_uncovered_patterns).zip(post)
{
*ty = self.infer_pat(*pat, ty, default_bm);
}

TyKind::Tuple(inner_tys.len(), Substitution::from_iter(Interner, inner_tys))
.intern(Interner)
self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args)
}
Pat::Or(pats) => {
if let Some((first_pat, rest)) = pats.split_first() {
Expand All @@ -191,16 +207,18 @@ impl<'a> InferenceContext<'a> {
let subty = self.infer_pat(*pat, &expectation, default_bm);
TyKind::Ref(mutability, static_lifetime(), subty).intern(Interner)
}
Pat::TupleStruct { path: p, args: subpats, ellipsis } => self.infer_tuple_struct_pat(
p.as_deref(),
subpats,
&expected,
default_bm,
pat,
*ellipsis,
),
Pat::TupleStruct { path: p, args: subpats, ellipsis } => self
.infer_tuple_struct_pat_like(
p.as_deref(),
&expected,
default_bm,
pat,
*ellipsis,
subpats,
),
Pat::Record { path: p, args: fields, ellipsis: _ } => {
self.infer_record_pat(p.as_deref(), fields, &expected, default_bm, pat)
let subs = fields.iter().map(|f| (f.name.clone(), f.pat));
self.infer_record_pat_like(p.as_deref(), &expected, default_bm, pat.into(), subs)
}
Pat::Path(path) => {
// FIXME use correct resolver for the surrounding expression
Expand Down

0 comments on commit 62d6b5a

Please sign in to comment.