Skip to content

Commit

Permalink
[Relay] Add hd,tl,nth for list in Prelude (apache#2771)
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic committed Mar 20, 2019
1 parent 1294c82 commit b447360
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
47 changes: 47 additions & 0 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,50 @@ def define_list_adt(self):
self.cons = Constructor("cons", [a, self.l(a)], self.l)
self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons])

def define_list_hd(self):
"""Defines a function to get the head of a list. Assume the list has at least one
element.
hd(l) : list[a] -> a
"""
self.hd = GlobalVar("hd")
a = TypeVar("a")
x = Var("x", self.l(a))
y = Var("y")
z = Var("z")
# Don't match nil() since it will break type checking
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a])

def define_list_tl(self):
"""Defines a function to get the tail of a list.
tl(l) : list[a] -> list[a]
"""
self.tl = GlobalVar("tl")
a = TypeVar("a")
x = Var("x", self.l(a))
y = Var("y")
z = Var("z")
nil_case = Clause(PatternConstructor(self.nil, []), self.nil())
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z)
self.mod[self.tl] = Function([x], Match(x, [nil_case, cons_case]), self.l(a), [a])

def define_list_nth(self):
"""Defines a function to get the nth element of a list.
nth(l) : list[a] -> a
"""
self.nth = GlobalVar("nth")
a = TypeVar("a")
x = Var("x", self.l(a))
n = Var("n", self.nat())

y = Var("y")
z_case = Clause(PatternConstructor(self.z), self.hd(x))
s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])

def define_list_map(self):
"""Defines a function for mapping a function over a list's
elements. That is, map(f, l) returns a new list where
Expand Down Expand Up @@ -405,6 +449,8 @@ def define_iterate(self):
def __init__(self, mod):
self.mod = mod
self.define_list_adt()
self.define_list_hd()
self.define_list_tl()
self.define_list_map()
self.define_list_foldl()
self.define_list_foldr()
Expand All @@ -423,6 +469,7 @@ def __init__(self, mod):
self.define_nat_double()
self.define_nat_add()
self.define_list_length()
self.define_list_nth()
self.define_list_sum()

self.define_tree_adt()
Expand Down
27 changes: 27 additions & 0 deletions tests/python/relay/test_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
nil = p.nil
cons = p.cons
l = p.l
hd = p.hd
tl = p.tl
nth = p.nth
length = p.length
map = p.map
foldl = p.foldl
Expand Down Expand Up @@ -120,6 +123,30 @@ def test_list_constructor():
a = relay.TypeVar("a")
assert relay.ir_pass.infer_type(cons(z(), nil()), mod).checked_type == l(nat())

def test_hd_tl():
expected = list(range(10))
l = nil()
for i in reversed(expected):
l = cons(build_nat(i), l)

got = []
for i in range(len(expected)):
got.append(count(intrp.evaluate(hd(l))))
l = tl(l)

assert got == expected

def test_nth():
expected = list(range(10))
l = nil()
for i in reversed(expected):
l = cons(build_nat(i), l)

got = []
for i in range(len(expected)):
got.append(count(intrp.evaluate(nth(l, build_nat(i)))))

assert got == expected

def test_length():
a = relay.TypeVar("a")
Expand Down

0 comments on commit b447360

Please sign in to comment.