Skip to content

Commit

Permalink
add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
domin1985 committed Jan 28, 2021
1 parent c484dd8 commit a495e64
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tests/python/relay/test_ir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,45 @@ def test_load_prelude():
tvm.parser.parse(mod.astext())


def test_call_attrs():
def get_func(shape, dtype):
x0 = relay.var("data", shape=shape, dtype=dtype)
w0 = relay.var("weight", shape=shape, dtype=dtype)
a = relay.nn.dense(x0, w0)
b = relay.nn.relu(a)
d = relay.add(b, relay.const(1.0, dtype=dtype))
return relay.Function([x0, w0], d)

# build relay graph
shape = (2, 4)
dtype = "float32"
sub_func = get_func(shape, dtype)
p0 = relay.var("p0", shape=shape, dtype=dtype)
p1 = relay.var("p1", shape=shape, dtype=dtype)
attr = tvm.ir.make_node("attrs.TestAttrs", name="func_call_attrs")
call = relay.Call(sub_func, [p0, p1], attrs=attr)
func = relay.Function([p0, p1], call)

# build relay module
mod = tvm.IRModule()
mod["main"] = func
mod = tvm.relay.transform.InferType()(mod)

# assert equal
program = """
def @main(%p0: Tensor[(2, 4), float32], %p1: Tensor[(2, 4), float32]) {
%2 = fn (%data: Tensor[(2, 4), float32], %weight: Tensor[(2, 4), float32]) {
%0 = nn.dense(%data, %weight, units=None);
%1 = nn.relu(%0);
add(%1, 1f)
};
%2(%p0, %p1, name="func_call_attrs", attrs_type_key="attrs.TestAttrs")
}
"""
parsed = parse_module(program)
assert_graph_equal(parsed, mod)


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit a495e64

Please sign in to comment.