From 512a51030559f044855fb5253102f1d11ed0d247 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Sun, 10 Mar 2019 21:11:40 -0700 Subject: [PATCH] [Relay] Add hd,tl,nth for list in Prelude --- python/tvm/relay/prelude.py | 47 ++++++++++++++++++++++++++++++++++ tests/python/relay/test_adt.py | 27 +++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 41d1be284f8e..26f00c5c5e6d 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -18,6 +18,50 @@ def define_list_adt(self): self.cons = Constructor("cons", [a, self.l(a)], self.l) self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) + def define_list_hd(self): + """Defines a function to get the head of a list. Assume the list has at least one + element. + + hd(l) : list[a] -> a + """ + self.hd = GlobalVar("hd") + a = TypeVar("a") + x = Var("x", self.l(a)) + y = Var("y") + z = Var("z") + # Don't match nil() since it will break type checking + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y) + self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a]) + + def define_list_tl(self): + """Defines a function to get the tail of a list. + + tl(l) : list[a] -> list[a] + """ + self.tl = GlobalVar("tl") + a = TypeVar("a") + x = Var("x", self.l(a)) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(self.nil, []), self.nil()) + cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z) + self.mod[self.tl] = Function([x], Match(x, [nil_case, cons_case]), self.l(a), [a]) + + def define_list_nth(self): + """Defines a function to get the nth element of a list. + + nth(l) : list[a] -> a + """ + self.nth = GlobalVar("nth") + a = TypeVar("a") + x = Var("x", self.l(a)) + n = Var("n", self.nat()) + + y = Var("y") + z_case = Clause(PatternConstructor(self.z), self.hd(x)) + s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y)) + self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a]) + def define_list_map(self): """Defines a function for mapping a function over a list's elements. That is, map(f, l) returns a new list where @@ -405,6 +449,8 @@ def define_iterate(self): def __init__(self, mod): self.mod = mod self.define_list_adt() + self.define_list_hd() + self.define_list_tl() self.define_list_map() self.define_list_foldl() self.define_list_foldr() @@ -423,6 +469,7 @@ def __init__(self, mod): self.define_nat_double() self.define_nat_add() self.define_list_length() + self.define_list_nth() self.define_list_sum() self.define_tree_adt() diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index a67ee7542e8a..e176194fede6 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -23,6 +23,9 @@ nil = p.nil cons = p.cons l = p.l +hd = p.hd +tl = p.tl +nth = p.nth length = p.length map = p.map foldl = p.foldl @@ -120,6 +123,30 @@ def test_list_constructor(): a = relay.TypeVar("a") assert relay.ir_pass.infer_type(cons(z(), nil()), mod).checked_type == l(nat()) +def test_hd_tl(): + expected = list(range(10)) + l = nil() + for i in reversed(expected): + l = cons(build_nat(i), l) + + got = [] + for i in range(len(expected)): + got.append(count(intrp.evaluate(hd(l)))) + l = tl(l) + + assert got == expected + +def test_nth(): + expected = list(range(10)) + l = nil() + for i in reversed(expected): + l = cons(build_nat(i), l) + + got = [] + for i in range(len(expected)): + got.append(count(intrp.evaluate(nth(l, build_nat(i))))) + + assert got == expected def test_length(): a = relay.TypeVar("a")