Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement pattern guards #1910

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions core/src/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ use crate::{
make as mk_term,
pattern::compile::Compile,
record::{Field, RecordData},
BinaryOp, BindingType, LetAttrs, MatchData, RecordOpKind, RichTerm, RuntimeContract,
StrChunk, Term, UnaryOp,
BinaryOp, BindingType, LetAttrs, MatchBranch, MatchData, RecordOpKind, RichTerm,
RuntimeContract, StrChunk, Term, UnaryOp,
},
};

Expand Down Expand Up @@ -1191,11 +1191,12 @@ pub fn subst<C: Cache>(
Term::Match(data) => {
let branches = data.branches
.into_iter()
.map(|(pat, branch)| {
(
pat,
subst(cache, branch, initial_env, env),
)
.map(|MatchBranch { pattern, guard, body} | {
MatchBranch {
pattern,
guard: guard.map(|cond| subst(cache, cond, initial_env, env)),
body: subst(cache, body, initial_env, env),
}
})
.collect();

Expand Down
12 changes: 7 additions & 5 deletions core/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,10 @@ Applicative: UniTerm = {
<op: BOpPre> <t1: AsTerm<Atom>> <t2: AsTerm<Atom>>
=> UniTerm::from(mk_term::op2(op, t1, t2)),
NOpPre<AsTerm<Atom>>,
"match" "{" <branches: (MatchCase ",")*> <last: MatchCase?> "}" => {
"match" "{" <branches: (MatchBranch ",")*> <last: MatchBranch?> "}" => {
let branches = branches
.into_iter()
.map(|(case, _comma)| case)
.map(|(branch, _comma)| branch)
.chain(last)
.collect();

Expand Down Expand Up @@ -876,9 +876,11 @@ UOp: UnaryOp = {
"enum_get_tag" => UnaryOp::EnumGetTag(),
}

MatchCase: (Pattern, RichTerm) = {
<pat: Pattern> "=>" <t: Term> => (pat, t),
};
PatternGuard: RichTerm = "if" <Term> => <>;

MatchBranch: MatchBranch =
<pattern: Pattern> <guard: PatternGuard?> "=>" <body: Term> =>
MatchBranch { pattern, guard, body};

// Infix operators by precedence levels. Lowest levels take precedence over
// highest ones.
Expand Down
44 changes: 30 additions & 14 deletions core/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -884,20 +884,12 @@ where
allocator,
allocator.line(),
allocator.intersperse(
data.branches
.iter()
.map(|(pat, t)| (pat.pretty(allocator), t))
.map(|(lhs, t)| docs![
allocator,
lhs,
allocator.space(),
"=>",
allocator.line(),
t,
","
]
.nest(2)),
allocator.line()
data.branches.iter().map(|branch| docs![
allocator,
branch.pretty(allocator),
","
]),
allocator.line(),
),
]
.nest(2)
Expand Down Expand Up @@ -1185,6 +1177,30 @@ where
}
}

impl<'a, D, A> Pretty<'a, D, A> for &MatchBranch
where
D: NickelAllocatorExt<'a, A>,
D::Doc: Clone,
A: Clone + 'a,
{
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
let guard = if let Some(guard) = &self.guard {
docs![allocator, allocator.line(), "if", allocator.space(), guard]
} else {
allocator.nil()
};

docs![
allocator,
&self.pattern,
guard,
allocator.space(),
"=>",
docs![allocator, allocator.line(), self.body.pretty(allocator),].nest(2),
]
}
}

/// Generate an implementation of `fmt::Display` for types that implement `Pretty`.
#[macro_export]
macro_rules! impl_display_from_pretty {
Expand Down
50 changes: 43 additions & 7 deletions core/src/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,12 +604,24 @@ impl fmt::Display for MergePriority {
}
}

/// A branch of a match expression.
#[derive(Debug, PartialEq, Clone)]
pub struct MatchBranch {
/// The pattern on the left hand side of `=>`.
pub pattern: Pattern,
/// A potential guard, which is an additional side-condition defined as `if cond`. The value
/// stored in this field is the boolean condition itself.
pub guard: Option<RichTerm>,
/// The body of the branch, on the right hand side of `=>`.
pub body: RichTerm,
}

/// Content of a match expression.
#[derive(Debug, PartialEq, Clone)]
pub struct MatchData {
/// Branches of the match expression, where the first component is the pattern on the left hand
/// side of `=>` and the second component is the body of the branch.
pub branches: Vec<(Pattern, RichTerm)>,
pub branches: Vec<MatchBranch>,
}

/// A type or a contract together with its corresponding label.
Expand Down Expand Up @@ -2024,11 +2036,26 @@ impl Traverse<RichTerm> for RichTerm {
Term::Match(data) => {
// The annotation on `map_res` use Result's corresponding trait to convert from
// Iterator<Result> to a Result<Iterator>
let branches: Result<Vec<(Pattern, RichTerm)>, E> = data
let branches: Result<Vec<MatchBranch>, E> = data
.branches
.into_iter()
// For the conversion to work, note that we need a Result<(Ident,RichTerm), E>
.map(|(pat, t)| t.traverse(f, order).map(|t_ok| (pat, t_ok)))
.map(
|MatchBranch {
pattern,
guard,
body,
}| {
let guard = guard.map(|cond| cond.traverse(f, order)).transpose()?;
let body = body.traverse(f, order)?;

Ok(MatchBranch {
pattern,
guard,
body,
})
},
)
.collect();

RichTerm::new(
Expand Down Expand Up @@ -2203,10 +2230,19 @@ impl Traverse<RichTerm> for RichTerm {
.or_else(|| field.traverse_ref(f, state))
})
}),
Term::Match(data) => data
.branches
.iter()
.find_map(|(_pat, t)| t.traverse_ref(f, state)),
Term::Match(data) => data.branches.iter().find_map(
|MatchBranch {
pattern: _,
guard,
body,
}| {
if let Some(cond) = guard.as_ref() {
cond.traverse_ref(f, state)?;
}

body.traverse_ref(f, state)
},
),
Term::Array(ts, _) => ts.iter().find_map(|t| t.traverse_ref(f, state)),
Term::OpN(_, ts) => ts.iter().find_map(|t| t.traverse_ref(f, state)),
Term::Annotated(annot, t) => t
Expand Down
86 changes: 61 additions & 25 deletions core/src/term/pattern/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ use super::*;
use crate::{
mk_app,
term::{
make, record::FieldMetadata, BinaryOp, MatchData, RecordExtKind, RecordOpKind, RichTerm,
Term, UnaryOp,
make, record::FieldMetadata, BinaryOp, MatchBranch, MatchData, RecordExtKind, RecordOpKind,
RichTerm, Term, UnaryOp,
},
};

Expand Down Expand Up @@ -668,23 +668,32 @@ impl Compile for MatchData {
// # this primop evaluates body with an environment extended with bindings_id
// %pattern_branch% body bindings_id
fn compile(mut self, value: RichTerm, pos: TermPos) -> RichTerm {
if self.branches.iter().all(|(pat, _)| {
yannham marked this conversation as resolved.
Show resolved Hide resolved
if self.branches.iter().all(|branch| {
// While we could get something working even with a guard, it's a bit more work and
// there's no current incentive to do so (a guard on a tags-only match is arguably less
// common, as such patterns don't bind any variable). For the time being, we just
// exclude guards from the tags-only optimization.
matches!(
pat.data,
branch.pattern.data,
PatternData::Enum(EnumPattern { pattern: None, .. }) | PatternData::Wildcard
)
) && branch.guard.is_none()
}) {
let wildcard_pat = self
.branches
.iter()
.enumerate()
.find_map(|(idx, (pat, body))| {
if let PatternData::Wildcard = pat.data {
let wildcard_pat = self.branches.iter().enumerate().find_map(
|(
idx,
MatchBranch {
pattern,
guard,
body,
},
)| {
if matches!((&pattern.data, guard), (PatternData::Wildcard, None)) {
Some((idx, body.clone()))
} else {
None
}
});
},
);

// If we find a wildcard pattern, we record its index in order to discard all the
// patterns coming after the wildcard, because they are unreachable.
Expand All @@ -698,13 +707,19 @@ impl Compile for MatchData {
let tags_only = self
.branches
.into_iter()
.filter_map(|(pat, body)| {
if let PatternData::Enum(EnumPattern { tag, .. }) = pat.data {
Some((tag, body))
} else {
None
}
})
.filter_map(
|MatchBranch {
pattern,
guard: _,
body,
}| {
if let PatternData::Enum(EnumPattern { tag, .. }) = pattern.data {
Some((tag, body))
} else {
None
}
},
)
.collect();

return TagsOnlyMatch {
Expand All @@ -726,14 +741,14 @@ impl Compile for MatchData {

// The fold block:
//
// <for (pattern, body) in branches.rev()
// <for branch in branches.rev()
// - cont is the accumulator
// - initial accumulator is the default branch (or error if not default branch)
// >
// let init_bindings_id = {} in
// let bindings_id = <pattern.compile_part(value_id, init_bindings)> in
//
// if bindings_id == null then
// if bindings_id == null || !<guard> then
// cont
// else
// # this primop evaluates body with an environment extended with bindings_id
Expand All @@ -742,10 +757,31 @@ impl Compile for MatchData {
.branches
.into_iter()
.rev()
.fold(error_case, |cont, (pat, body)| {
.fold(error_case, |cont, branch| {
let init_bindings_id = LocIdent::fresh();
let bindings_id = LocIdent::fresh();

// inner if condition:
// bindings_id == null || !<guard>
let inner_if_cond = make::op2(BinaryOp::Eq(), Term::Var(bindings_id), Term::Null);
let inner_if_cond = if let Some(guard) = branch.guard {
// the guard must be evaluated in the same environment as the body of the
// branch, as it might use bindings introduced by the pattern. Since `||` is
// lazy in Nickel, we know that `bindings_id` is not null if the guard
// condition is ever evaluated.
let guard_cond = mk_app!(
make::op1(UnaryOp::PatternBranch(), Term::Var(bindings_id)),
guard
);

mk_app!(
make::op1(UnaryOp::BoolOr(), inner_if_cond),
make::op1(UnaryOp::BoolNot(), guard_cond)
)
} else {
inner_if_cond
};

// inner if block:
//
// if bindings_id == null then
Expand All @@ -754,11 +790,11 @@ impl Compile for MatchData {
// # this primop evaluates body with an environment extended with bindings_id
// %pattern_branch% bindings_id body
let inner = make::if_then_else(
make::op2(BinaryOp::Eq(), Term::Var(bindings_id), Term::Null),
inner_if_cond,
cont,
mk_app!(
make::op1(UnaryOp::PatternBranch(), Term::Var(bindings_id),),
body
branch.body
),
);

Expand All @@ -772,7 +808,7 @@ impl Compile for MatchData {
Term::Record(RecordData::empty()),
make::let_in(
bindings_id,
pat.compile_part(value_id, init_bindings_id),
branch.pattern.compile_part(value_id, init_bindings_id),
inner,
),
)
Expand Down
12 changes: 8 additions & 4 deletions core/src/transform/desugar_destructuring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use crate::{
identifier::LocIdent,
match_sharedterm,
term::{pattern::*, MatchData, RichTerm, Term},
term::{pattern::*, MatchBranch, MatchData, RichTerm, Term},
};

/// Entry point of the destructuring desugaring transformation.
Expand Down Expand Up @@ -48,16 +48,20 @@ pub fn desugar_fun(mut pat: Pattern, body: RichTerm) -> Term {
/// Desugar a destructuring let-binding.
///
/// A let-binding `let <pat> = bound in body` is desugared to `<bound> |> match { <pat> => body }`.
pub fn desugar_let(pat: Pattern, bound: RichTerm, body: RichTerm) -> Term {
pub fn desugar_let(pattern: Pattern, bound: RichTerm, body: RichTerm) -> Term {
// the position of the match expression is used during error reporting, so we try to provide a
// sensible one.
let match_expr_pos = pat.pos.fuse(bound.pos);
let match_expr_pos = pattern.pos.fuse(bound.pos);

// `(match { <pat> => <body> }) <bound>`
Term::App(
RichTerm::new(
Term::Match(MatchData {
branches: vec![(pat, body)],
branches: vec![MatchBranch {
pattern,
guard: None,
body,
}],
}),
match_expr_pos,
),
Expand Down
Loading
Loading