diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 48634bafa744..78688d7dc730 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -92,7 +92,7 @@ class AlphaEqualHandler: auto compute = [&]() { if (&lhs == &rhs) return true; if (auto lhsd = lhs.as()) { - auto rhsd = lhs.as(); + auto rhsd = rhs.as(); if (!rhsd) return false; if (lhsd->dict.size() != rhsd->dict.size()) return false; for (const auto& k : lhsd->dict) { diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index ad1525576d08..bdda72ca8702 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """ test ir""" +import pytest import tvm from tvm import relay from tvm.tir.expr import * @@ -174,6 +175,7 @@ def test_function(): str(fn) check_json_roundtrip(fn) +@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.") def test_function_attrs(): param_names = ['a', 'b', 'c', 'd'] params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names]) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 5985273ce6de..0319d0b1a371 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -18,6 +18,7 @@ import tvm from tvm import relay from tvm.relay import analysis +from tvm.relay.testing import run_opt_pass def alpha_equal(x, y): """ @@ -313,7 +314,7 @@ def test_tuple_get_item_alpha_equal(): assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) -def test_multi_node_subgraph(): +def test_function_attr(): x0 = relay.var('x0', shape=(10, 10)) w00 = relay.var('w00', shape=(10, 10)) w01 = relay.var('w01', shape=(10, 10)) @@ -607,6 +608,7 @@ def test_graph_equal(): z3 = relay.add(relay.add(x, x), relay.add(x, x)) + assert alpha_equal(z0, z1) assert alpha_equal(z0, z1) # z3's dataflow format is different from z0 @@ -649,6 +651,26 @@ def test_tuple_match(): assert analysis.structural_hash(x) == analysis.structural_hash(y) +def test_fn_attribute(): + # create function that performs add + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + add = relay.add(a, b) + add_fn = relay.Function([a, b], add) + add_fn = run_opt_pass(add_fn, relay.transform.InferType()) + + # create function that performs add with test attribute + c = relay.var('c', shape=(10, 10)) + d = relay.var('d', shape=(10, 10)) + add_1 = relay.add(c, d) + add_1_fn = relay.Function([c, d], add_1) + add_1_fn = add_1_fn.set_attribute("TestAttribute", tvm.tir.StringImm("test")) + add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType()) + + assert not relay.analysis.alpha_equal(add_1_fn, add_fn) + assert not relay.analysis.alpha_equal(add_fn, add_1_fn) + + if __name__ == "__main__": test_tensor_type_alpha_equal() test_incomplete_type_alpha_equal() @@ -672,3 +694,4 @@ def test_tuple_match(): test_var_alpha_equal() test_graph_equal() test_hash_unequal() + test_fn_attribute() diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 18916f758a6c..e11b6aeb0a2c 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -35,6 +35,7 @@ def expected(): z = relay.exp(y) w = relay.squeeze(z) f1 = relay.Function([x], w) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=(10, 20)) y = relay.Call(f1, [x]) return relay.Function([x], y) @@ -76,6 +77,8 @@ def expected(dshape): x = relay.var("p0", shape=dshape) y = relay.add(x, relay.const(1, "float32")) f0 = relay.Function([x], y) + f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + # segment 1 x = relay.var("p0", shape=dshape) w = relay.var("p1") @@ -86,6 +89,8 @@ def expected(dshape): y1 = relay.add(relay.const(1, "float32"), y) y = relay.add(y, y1) f1 = relay.Function([x, w], y) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + # segment 2 x = relay.var("p0", shape=dshape) w = relay.var("p1") @@ -94,6 +99,8 @@ def expected(dshape): padding=(1,1), channels=16) f2 = relay.Function([x, w], z2) + f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + # segment 3 x = relay.var("p0", shape=dshape) w = relay.var("p1") @@ -104,6 +111,8 @@ def expected(dshape): channels=16) z3 = relay.add(z3, offset) f3 = relay.Function([x, w, offset], z3) + f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + # compose x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) @@ -135,6 +144,7 @@ def expected(dshape): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) f0 = relay.Function([x], pooled) + f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) p1 = relay.var("p1", shape=dshape) @@ -142,6 +152,7 @@ def expected(dshape): concat = relay.concatenate((upsampled, p1), axis=1) out = relay.add(concat, relay.const(1, "float32")) f1 = relay.Function([p0, p1], out) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) @@ -172,10 +183,12 @@ def expected(dshape): x = relay.var("x", shape=dshape) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) f0 = relay.Function([x], pooled) + f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW") f1 = relay.Function([p0], upsampled) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) @@ -205,10 +218,12 @@ def expected(dshape): x = relay.var("p0", shape=dshape) y = relay.add(x, relay.const(1, "float32")) f1 = relay.Function([x], y) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("p01", shape=dshape) y = relay.exp(x) f2 = relay.Function([x], y) + f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) y = relay.Call(f1, [x]) @@ -242,6 +257,7 @@ def expected(dshape, dtype): p2 = relay.var('p2', shape=dshape, dtype=dtype) fused_gt = relay.Function([p1, p2], relay.op.greater(p1, p2)) + fused_gt = fused_gt.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) with sb.if_scope(fused_gt(x, y)): sb.ret(relay.Function([], x)) with sb.else_scope(): @@ -271,11 +287,13 @@ def expected(dim): p1 = relay.var("p1", shape=(3 * dim, dim)) matmul = relay.nn.dense(p0, p1) f0 = relay.Function([p0, p1], matmul) + f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p01 = relay.var("p01", shape=(1, 3 * dim)) splitted = relay.split(p01, indices_or_sections=3, axis=1) out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2]) f1 = relay.Function([p01], out) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) X = relay.var("X", shape=(1, dim)) W = relay.var("W", shape=(3 * dim, dim)) @@ -306,11 +324,13 @@ def expected(dim): splitted = relay.split(p0, indices_or_sections=3, axis=1) out = splitted[0] f0 = relay.Function([p0], out) + f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p01 = relay.var("p01", shape=(1, dim)) p1 = relay.var("p1", shape=(dim, dim)) out = relay.nn.dense(p01, p1) f1 = relay.Function([p01, p1], out) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) X = relay.var("X", shape=(1, 3 * dim)) W = relay.var("W", shape=(dim, dim)) @@ -346,8 +366,9 @@ def before(x): def expected(p0): f0 = before(p0) + f1 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) - y = relay.Call(f0, [x]) + y = relay.Call(f1, [x]) return relay.Function([x], y) dshape = (1, 16, 64, 64) @@ -388,15 +409,18 @@ def expected(dshape): p0 = relay.var("p0", shape=dshape) concat = gen_consecutive_tuple(p0) f0 = relay.Function([p0], concat) + f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3])) pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) out = relay.add(pooled, relay.const(1, "float32")) f1 = relay.Function([p01], out) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2)) out = relay.add(p02, relay.const(1, "float32")) f2 = relay.Function([p02], out) + f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) y = relay.Call(f0, [x]) @@ -438,30 +462,36 @@ def expected(dshape): p0 = relay.var("p0", shape=dshape) c = conv(p0) f0 = relay.Function(relay.analysis.free_vars(c), c) + f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p01 = relay.var("p01", shape=dshape) c = conv(p01) f1 = relay.Function(relay.analysis.free_vars(c), c) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p02 = relay.var("p02", shape=dshape) p12 = relay.var("p12", shape=dshape) concat1 = relay.concatenate((p02, p12), axis=1) f_concat1 = relay.Function([p02, p12], concat1) + f_concat1 = f_concat1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3]) p03 = relay.var("p03", shape=dshape2) c = conv(p03) f2 = relay.Function(relay.analysis.free_vars(c), c) + f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p04 = relay.var("p04", shape=dshape2) c = conv(p04) f3 = relay.Function(relay.analysis.free_vars(c), c) + f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) p05 = relay.var("p05", shape=dshape) p15 = relay.var("p15", shape=dshape) concat2 = relay.concatenate((p05, p15), axis=1) f_concat2 = relay.Function([p05, p15], concat2) + f_concat2 = f_concat2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=dshape) c1 = relay.Call(f0, [x, relay.var("w1")]) @@ -499,6 +529,7 @@ def expected(): u = relay.transpose(y, axes=[0, 1]) w = relay.left_shift(z, u) f1 = relay.Function([x], w) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=(10, 20)) y = relay.Call(f1, [x]) return relay.Function([x], y) @@ -529,6 +560,7 @@ def expected(): z = relay.exp(y) w = relay.squeeze(z) f1 = relay.Function([x], w) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=(10, 20)) y = relay.Call(f1, [x]) mod = tvm.IRModule() @@ -570,6 +602,7 @@ def expected(): for i in range(max_fused_ops): y = relay.exp(y) f1 = relay.Function([x], y) + f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) x = relay.var("x", shape=(10, 20)) z = relay.Call(f1, [x]) xx = relay.var("pp", shape=(10, 20)) @@ -577,6 +610,7 @@ def expected(): for i in range(n-max_fused_ops): yy = relay.exp(yy) f2 = relay.Function([xx], yy) + f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) zz = relay.Call(f2, [z]) return relay.Function([x], zz) diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 4f785d7c915e..4f5acc707a52 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """Unit tests for merge composite.""" -from tvm import expr from tvm import relay +from tvm import tir from tvm.relay.testing import run_opt_pass """ @@ -144,6 +144,8 @@ def expected(): add_node = relay.add(in_1, in_2) relu_node = relay.nn.relu(add_node) add_relu = relay.Function([in_1, in_2], relu_node) + add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1)) + add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu")) # merged function r = relay.Call(add_relu, [a, b]) @@ -208,11 +210,27 @@ def expected(): sub_node = relay.subtract(in_1, in_2) mul_node = relay.multiply(add_node, sub_node) add_sub_mul = relay.Function([in_1, in_2], mul_node) + add_sub_mul = add_sub_mul.set_attribute("Primitive", + tir.IntImm("int32", 1)) + add_sub_mul = add_sub_mul.set_attribute("Composite", + tir.StringImm("add_sub_mul")) + + # add_sub_mul1 function + in_3 = relay.var('in_3', shape=(10, 10)) + in_4 = relay.var('in_4', shape=(10, 10)) + add_node_1 = relay.add(in_3, in_4) + sub_node_1 = relay.subtract(in_3, in_4) + mul_node_1 = relay.multiply(add_node_1, sub_node_1) + add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1) + add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive", + tir.IntImm("int32", 1)) + add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite", + tir.StringImm("add_sub_mul")) # merged function - add_sub_mul_1 = relay.Call(add_sub_mul, [a, b]) - add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1]) - r = relay.nn.relu(add_sub_mul_2) + m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b]) + m_add_sub_mul_2 = relay.Call(add_sub_mul_1, [c, m_add_sub_mul_1]) + r = relay.nn.relu(m_add_sub_mul_2) return relay.Function([a, b, c], r) result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) @@ -291,6 +309,9 @@ def expected(): bias_node = relay.nn.bias_add(conv_node, in_3) r = relay.nn.relu(bias_node) conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) + conv_bias_add_relu = conv_bias_add_relu.set_attribute("Primitive", tir.IntImm("int32", 1)) + conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite", + tir.StringImm("conv2d_bias_relu")) # add_relu function in_4 = relay.var('in_4', shape=(1, 256, 28, 28)) @@ -298,6 +319,8 @@ def expected(): add_node = relay.add(in_4, in_5) r = relay.nn.relu(add_node) add_relu = relay.Function([in_4, in_5], r) + add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1)) + add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu")) # merged function conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias]) @@ -357,7 +380,7 @@ def before(): out = relay.nn.relu(out) return relay.Function([input_1, input_2], out) - def after_A_priority(): + def after_A_priority(composite_name): input_1 = relay.var('input_1', shape=(10, 10)) input_2 = relay.var('input_2', shape=(10, 10)) x = relay.var('x') @@ -366,38 +389,12 @@ def after_A_priority(): out = relay.abs(out) out = relay.nn.relu(out) merged_func = relay.Function([x, y], out) - merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) - merged_func = merged_func.set_attribute('Composite', expr.StringImm('A')) + merged_func = merged_func.set_attribute('Primitive', tir.IntImm('int32', 1)) + merged_func = merged_func.set_attribute('Composite', + tir.StringImm(composite_name)) ret = relay.Call(merged_func, [input_1, input_2]) return relay.Function([input_1, input_2], ret) - def after_B_priority(): - input_1 = relay.var('input_1', shape=(10, 10)) - input_2 = relay.var('input_2', shape=(10, 10)) - x = relay.var('x') - y = relay.var('y') - out = relay.add(x, y) - out = relay.abs(out) - merged_func = relay.Function([x, y], out) - merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) - merged_func = merged_func.set_attribute('Composite', expr.StringImm('B')) - merged_call = relay.Call(merged_func, [input_1, input_2]) - ret = relay.nn.relu(merged_call) - return relay.Function([input_1, input_2], ret) - - def after_C_priority(): - input_1 = relay.var('input_1', shape=(10, 10)) - input_2 = relay.var('input_2', shape=(10, 10)) - add = relay.add(input_1, input_2) - x = relay.var('x') - out = relay.abs(x) - out = relay.nn.relu(out) - merged_func = relay.Function([x], out) - merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) - merged_func = merged_func.set_attribute('Composite', expr.StringImm('C')) - ret = relay.Call(merged_func, [add]) - return relay.Function([input_1, input_2], ret) - # check A highest priority pattern_table = [ ("A", pattern_A()), @@ -406,7 +403,7 @@ def after_C_priority(): ] result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) + expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType()) assert relay.analysis.alpha_equal(result, expected) # check B highest priority @@ -417,7 +414,7 @@ def after_C_priority(): ] result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) + expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType()) assert relay.analysis.alpha_equal(result, expected) # check C highest priority @@ -428,7 +425,7 @@ def after_C_priority(): ] result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) + expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType()) assert relay.analysis.alpha_equal(result, expected) @@ -459,11 +456,15 @@ def after(): y = relay.var('y') branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) func_1 = relay.Function([x, y], branch_1) + func_1 = func_1.set_attribute('Primitive', tir.IntImm('int32', 1)) + func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul")) call_1 = relay.Call(func_1, [input_1, input_2]) x1 = relay.var('x1') y1 = relay.var('y1') branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) func_2 = relay.Function([x1, y1], branch_2) + func_2 = func_2.set_attribute('Primitive', tir.IntImm('int32', 1)) + func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul")) call_2 = relay.Call(func_2, [input_1, input_2]) out = relay.multiply(call_1, call_2) return relay.Function([input_1, input_2], out) @@ -542,16 +543,16 @@ def after_A(): add_relu_1 = relay.add(x, y) add_relu_1 = relay.nn.relu(add_relu_1) add_relu_1 = relay.Function([x, y], add_relu_1) - add_relu_1 = add_relu_1.set_attribute('Primitive', expr.IntImm('int32', 1)) - add_relu_1 = add_relu_1.set_attribute('Composite', expr.StringImm('add_relu')) + add_relu_1 = add_relu_1.set_attribute('Primitive', tir.IntImm('int32', 1)) + add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu')) add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) x1 = relay.var('x1') y1 = relay.var('y1') add_relu_2 = relay.add(x1, y1) add_relu_2 = relay.nn.relu(add_relu_2) add_relu_2 = relay.Function([x1, y1], add_relu_2) - add_relu_2 = add_relu_2.set_attribute('Primitive', expr.IntImm('int32', 1)) - add_relu_2 = add_relu_2.set_attribute('Composite', expr.StringImm('add_relu')) + add_relu_2 = add_relu_2.set_attribute('Primitive', tir.IntImm('int32', 1)) + add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu')) add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) x2 = relay.var('x2') y2 = relay.var('y2') @@ -559,8 +560,8 @@ def after_A(): sub = relay.subtract(x2, y2) add_sub_mul = relay.multiply(add, sub) add_sub_mul = relay.Function([x2, y2], add_sub_mul) - add_sub_mul = add_sub_mul.set_attribute('Primitive', expr.IntImm('int32', 1)) - add_sub_mul = add_sub_mul.set_attribute('Composite', expr.StringImm('add_sub_mul')) + add_sub_mul = add_sub_mul.set_attribute('Primitive', tir.IntImm('int32', 1)) + add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul')) add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) return relay.Function(inputs, add_sub_mul_call) @@ -573,8 +574,8 @@ def after_B(): add_relu = relay.add(x, y) add_relu = relay.nn.relu(add_relu) add_relu = relay.Function([x, y], add_relu) - add_relu = add_relu.set_attribute('Primitive', expr.IntImm('int32', 1)) - add_relu = add_relu.set_attribute('Composite', expr.StringImm('add_relu')) + add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 1)) + add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu')) add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]]) add_relu_calls.append(add_relu_call) @@ -606,4 +607,4 @@ def after_B(): test_multiple_patterns() test_merge_order() test_parallel_merge() - test_multiple_input_subgraphs() \ No newline at end of file + test_multiple_input_subgraphs()