Skip to content

Commit

Permalink
feat(opt): Constant fold through simple matches
Browse files Browse the repository at this point in the history
  • Loading branch information
Marwes committed Aug 11, 2020
1 parent 669f959 commit 360c9d0
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 71 deletions.
81 changes: 59 additions & 22 deletions std/cmp.glu
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//@NO-IMPLICIT-PRELUDE
//! Functionality for ordering and comparison.

let { Bool, Ordering } = import! std.types
let { Bool, Ordering, Option } = import! std.types
let { Semigroup } = import! std.semigroup
let { Monoid } = import! std.monoid

Expand All @@ -24,42 +24,77 @@ let (/=) ?eq l r : [Eq a] -> a -> a -> Bool = if (eq.(==) l r) then False else T
type Ord a = {
eq : Eq a,
/// Compares two values and returns wheter the first is less than, equal or greater than the second.
compare : a -> a -> Ordering
compare : a -> a -> Ordering,
(<): a -> a -> Bool,
(<=): a -> a -> Bool,
(>=): a -> a -> Bool,
(>): a -> a -> Bool,
}

let compare ?ord : [Ord a] -> a -> a -> Ordering = ord.compare

#[infix(left, 4)]
let (=?) opt y : Option b -> b -> b =
match opt with
| Some x -> x
| None -> y

let mk_ord builder : _ -> Ord a =
let compare = builder.compare
let eq = builder.eq

#[infix(left, 4)]
let (<=) l r : a -> a -> Bool =
match compare l r with
| LT -> True
| EQ -> True
| GT -> False

#[infix(left, 4)]
let (<) l r : a -> a -> Bool =
match compare l r with
| LT -> True
| EQ -> False
| GT -> False

#[infix(left, 4)]
let (>) l r : a -> a -> Bool =
match compare l r with
| LT -> False
| EQ -> False
| GT -> True

#[infix(left, 4)]
let (>=) l r : a -> a -> Bool =
match compare l r with
| LT -> False
| EQ -> True
| GT -> True

{
eq,
compare,
(<) = builder.(<) =? (<),
(<=) = builder.(<=) =? (<=),
(>=) = builder.(>=) =? (>=),
(>) = builder.(>) =? (>),
}

/// Returns whether `l` is less than or equal to `r`.
#[infix(left, 4)]
let (<=) l r : [Ord a] -> a -> a -> Bool =
match compare l r with
| LT -> True
| EQ -> True
| GT -> False
let (<=) ?ord : [Ord a] -> a -> a -> Bool = ord.(<=)

/// Returns whether `l` is less than `r`.
#[infix(left, 4)]
let (<) l r : [Ord a] -> a -> a -> Bool =
match compare l r with
| LT -> True
| EQ -> False
| GT -> False
let (<) ?ord : [Ord a] -> a -> a -> Bool = ord.(<)

/// Returns whether `l` is greater than `r`.
#[infix(left, 4)]
let (>) l r : [Ord a] -> a -> a -> Bool =
match compare l r with
| LT -> False
| EQ -> False
| GT -> True
let (>) ?ord : [Ord a] -> a -> a -> Bool = ord.(>)

/// Returns whether `l` is greater than or equal to `r`.
#[infix(left, 4)]
let (>=) l r : [Ord a] -> a -> a -> Bool =
match compare l r with
| LT -> False
| EQ -> True
| GT -> True
let (>=) ?ord : [Ord a] -> a -> a -> Bool = ord.(>=)

let min l r : [Ord a] -> a -> a -> a =
if l <= r then l
Expand Down Expand Up @@ -97,6 +132,8 @@ let monoid : Monoid Ordering = {
min,
max,

mk_ord,

Ordering,

semigroup,
Expand Down
9 changes: 7 additions & 2 deletions std/int.glu
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//@NO-IMPLICIT-PRELUDE
//! The signed 64-bit integer type.

let { Option } = import! std.types
let { Semigroup } = import! std.semigroup
let { Monoid } = import! std.monoid
let { Group } = import! std.group
let { Eq, Ord, Ordering } = import! std.cmp
let { Eq, Ord, Ordering, mk_ord } = import! std.cmp
let { Num } = import! std.num
let { Show } = import! std.show

Expand Down Expand Up @@ -41,9 +42,13 @@ let eq : Eq Int = {
(==) = \l r -> l #Int== r,
}

let ord : Ord Int = {
let ord : Ord Int = mk_ord {
eq = eq,
compare = \l r -> if l #Int< r then LT else if l #Int== r then EQ else GT,
(<) = Some (\l r -> l #Int< r),
(<=) = None,
(>=) = None,
(>) = None,
}

let num : Num Int = {
Expand Down
60 changes: 60 additions & 0 deletions tests/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,63 @@ mod.(+) (no_inline 1) 2
"#;
check_expr_eq(core_expr.value.expr(), expected_str);
}

#[test]
fn inline_match() {
let _ = env_logger::try_init();

let thread = make_vm();
thread.get_database_mut().set_implicit_prelude(false);

thread
.load_script(
"test",
r#"
type Test = | A | B
match A with
| A -> 1
| B -> 2
"#,
)
.unwrap_or_else(|err| panic!("{}", err));

let db = thread.get_database();
let core_expr = db
.core_expr("test".into())
.unwrap_or_else(|err| panic!("{}", err));
let expected_str = r#"
1
"#;
check_expr_eq(core_expr.value.expr(), expected_str);
}

#[test]
fn inline_cmp() {
let _ = env_logger::try_init();

let thread = make_vm();
thread.get_database_mut().set_implicit_prelude(false);

thread
.load_script(
"test",
r#"
let mod @ { Option } = import! tests.optimize.cmp
let m = mod.mk_ord {
(<) = Some (\l r -> l #Int< r),
}
m.(<)
"#,
)
.unwrap_or_else(|err| panic!("{}", err));

let db = thread.get_database();
let core_expr = db
.core_expr("test".into())
.unwrap_or_else(|err| panic!("{}", err));
let expected_str = r#"
rec let lt l r = (#Int<) l r
in lt
"#;
check_expr_eq(core_expr.value.expr(), expected_str);
}
17 changes: 17 additions & 0 deletions tests/optimize/cmp.glu
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
let { Option, Ordering, Bool } = import! std.types

#[infix(left, 4)]
let (=?) opt y : Option b -> b -> b =
match opt with
| Some x -> x
| None -> y

let mk_ord builder =
#[infix(left, 4)]
let (<) l r = True

{
(<) = builder.(<) =? (<),
}

{ Option, mk_ord }
66 changes: 44 additions & 22 deletions vm/src/core/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,7 @@ impl<'a, 'e> Compiler<'a, 'e> {
if *cost <= 10 && !self.contains_unbound_variables(e.as_ref()) {
true
} else {
trace!("Unable to optimize: {}", e);
trace!("Unable to optimize {}: {}", cost, e);
false
}
})
Expand Down Expand Up @@ -1406,29 +1406,51 @@ impl<'a, 'e> Compiler<'a, 'e> {
})
})
}
Expr::Match(match_expr, alts) if alts.len() == 1 => {
let alt = &alts[0];
match (&alt.pattern, &alt.expr) {
(Pattern::Record(fields), Expr::Ident(id, _))
if fields.len() == 1
&& id.name
== *fields[0].1.as_ref().unwrap_or(&fields[0].0.name) =>
{
let field = &fields[0];
let match_expr = self.peek_reduced_expr(resolver.wrap(match_expr));
let projected =
self.project_reduced(&field.0.name, match_expr.clone())?;
Some(CostBinding {
cost: 0,
bind: Binding::Expr(projected),
})
Expr::Match(match_expr, alts) => {
let match_expr = self.peek_reduced_expr(resolver.wrap(match_expr));

if alts.len() == 1 {
let alt = &alts[0];
match (&alt.pattern, &alt.expr) {
(Pattern::Record(fields), Expr::Ident(id, _))
if fields.len() == 1
&& id.name
== *fields[0]
.1
.as_ref()
.unwrap_or(&fields[0].0.name) =>
{
let field = &fields[0];
let projected =
self.project_reduced(&field.0.name, match_expr.clone())?;
return Some(CostBinding {
cost: 0,
bind: Binding::Expr(projected),
});
}
(Pattern::Record(_), _) | (Pattern::Ident(_), _) => {
return Some(CostBinding {
cost: 0,
bind: Binding::Expr(resolver.wrap(alt.expr)),
})
}
_ => (),
}
(Pattern::Record(_), _) | (Pattern::Ident(_), _) => Some(CostBinding {
cost: 0,
bind: Binding::Expr(resolver.wrap(alt.expr)),
}),
_ => None,
}
match_expr.with(self.allocator, |_, match_expr| match match_expr {
Expr::Data(id, ..) => alts
.iter()
.find(|alt| match &alt.pattern {
Pattern::Constructor(ctor_id, _) => ctor_id.name == id.name,
Pattern::Ident(_) => true,
_ => false,
})
.map(|alt| CostBinding {
cost: 0,
bind: Binding::Expr(resolver.wrap(alt.expr)),
}),
_ => None,
})
}
peek if !ptr::eq::<Expr>(peek, expr) => Some(CostBinding {
cost: 0,
Expand Down
31 changes: 6 additions & 25 deletions vm/src/core/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,12 @@ impl<'a> Expr<'a> {
"match ",
expr.pretty(arena, Prec::Top).nest(INDENT),
" with",
arena.newline(),
arena.hardline(),
chain![arena, "| ", alt.pattern.pretty(arena), arena.space(), "->"]
.group(),
arena.newline(),
arena.hardline(),
alt.expr.pretty(arena, Prec::Top).group(),
arena.newline(),
arena.hardline(),
"end"
]
.group();
Expand All @@ -222,37 +222,18 @@ impl<'a> Expr<'a> {
"match ",
expr.pretty(arena, Prec::Top).nest(INDENT),
" with",
<<<<<<< HEAD
arena.hardline(),
arena.concat(
alts.iter()
.map(|alt| {
chain![
arena,
"| ",
alt.pattern.pretty(arena),
" ->",
arena.space(),
alt.expr.pretty(arena, Prec::Top).nest(INDENT).group()
]
.nest(INDENT)
})
.intersperse(arena.hardline())
)
=======
arena.newline(),
arena.concat(alts.iter().map(|alt| {
chain![arena;
chain![arena,
"| ",
alt.pattern.pretty(arena),
" ->",
arena.space(),
alt.expr.pretty(arena, Prec::Top).nest(INDENT).group()
].nest(INDENT)
}).intersperse(arena.newline())),
arena.newline(),
}).intersperse(arena.hardline())),
arena.hardline(),
"end"
>>>>>>> a
]
.group();
prec.enclose(arena, doc)
Expand Down

0 comments on commit 360c9d0

Please sign in to comment.