diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 162271756557c..7412bb261367f 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -910,6 +910,45 @@ def test_load_prelude(): tvm.parser.parse(mod.astext()) +def test_call_attrs(): + def get_func(shape, dtype): + x0 = relay.var("data", shape=shape, dtype=dtype) + w0 = relay.var("weight", shape=shape, dtype=dtype) + a = relay.nn.dense(x0, w0) + b = relay.nn.relu(a) + d = relay.add(b, relay.const(1.0, dtype=dtype)) + return relay.Function([x0, w0], d) + + # build relay graph + shape = (2, 4) + dtype = "float32" + sub_func = get_func(shape, dtype) + p0 = relay.var("p0", shape=shape, dtype=dtype) + p1 = relay.var("p1", shape=shape, dtype=dtype) + attr = tvm.ir.make_node("attrs.TestAttrs", name="func_call_attrs") + call = relay.Call(sub_func, [p0, p1], attrs=attr) + func = relay.Function([p0, p1], call) + + # build relay module + mod = tvm.IRModule() + mod["main"] = func + mod = tvm.relay.transform.InferType()(mod) + + # assert equal + program = """ + def @main(%p0: Tensor[(2, 4), float32], %p1: Tensor[(2, 4), float32]) { + %2 = fn (%data: Tensor[(2, 4), float32], %weight: Tensor[(2, 4), float32]) { + %0 = nn.dense(%data, %weight, units=None); + %1 = nn.relu(%0); + add(%1, 1f) + }; + %2(%p0, %p1, name="func_call_attrs", attrs_type_key="attrs.TestAttrs") + } + """ + parsed = parse_module(program) + assert_graph_equal(parsed, mod) + + if __name__ == "__main__": import sys