Skip to content

Commit

Permalink
Add option ADT to prelude
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Jan 25, 2019
1 parent 4029e1e commit 99f6591
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ def define_list_foldr(self):
self.mod[self.foldr] = Function([f, bv, av],
Match(av, [nil_case, cons_case]), None, [a, b])

def define_option_adt(self):
"""Defines an option ADT, which can either contain some other
type or nothing at all."""
self.option = GlobalTypeVar("option")
a = TypeVar("a")
self.some = Constructor("some", [a], self.option)
self.none = Constructor("none", [], self.option)
self.mod[self.option] = TypeData(self.option, [a], [self.some, self.none])

def define_nat_adt(self):
"""Defines a Peano (unary) natural number ADT.
Zero is represented by z(). s(n) adds 1 to a nat n."""
Expand Down Expand Up @@ -177,6 +186,8 @@ def __init__(self, mod):
self.define_list_foldl()
self.define_list_foldr()

self.define_option_adt()

self.define_nat_adt()
self.define_nat_double()
self.define_nat_add()
Expand Down
24 changes: 24 additions & 0 deletions tests/python/relay/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
double = p.double
add = p.add

some = p.some
none = p.none

nil = p.nil
cons = p.cons
l = p.l
Expand Down Expand Up @@ -172,6 +175,27 @@ def test_sum():
assert count(res) == 3


def test_option_matching():
x = relay.Var('x')
y = relay.Var('y')
v = relay.Var('v')
condense = relay.Function(
[x, y],
relay.Match(x, [
relay.Clause(relay.PatternConstructor(some, [relay.PatternVar(v)]), cons(v, y)),
relay.Clause(relay.PatternConstructor(none), y)
]))

res = intrp.evaluate(foldr(condense, nil(), cons(
some(build_nat(3)),
cons(none(), cons(some(build_nat(1)), nil())))))

reduced = to_list(res)
assert len(reduced) == 2
assert count(reduced[0]) == 3
assert count(reduced[1]) == 1


def test_tmap():
a = relay.TypeVar("a")
b = relay.TypeVar("b")
Expand Down

0 comments on commit 99f6591

Please sign in to comment.