diff --git a/src/relay/pass/merge_composite.cc b/src/relay/pass/merge_composite.cc index 4e1094b617e9..162bf3a2bba6 100644 --- a/src/relay/pass/merge_composite.cc +++ b/src/relay/pass/merge_composite.cc @@ -168,7 +168,6 @@ class MergeCompositeWrapper : public ExprMutator { // make the composite function auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs()); f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_)); - f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1)); // find the expressions associated with the free vars using the args_map // this tells us which expressions should be given as inputs to the composite function Array args; diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index b96a89b1f483..bcf61a01f1db 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -164,7 +164,6 @@ def expected(): add_node = relay.add(in_1, in_2) relu_node = relay.nn.relu(add_node) add_relu = relay.Function([in_1, in_2], relu_node) - add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1)) add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu")) # merged function @@ -230,8 +229,6 @@ def expected(): sub_node = relay.subtract(in_1, in_2) mul_node = relay.multiply(add_node, sub_node) add_sub_mul = relay.Function([in_1, in_2], mul_node) - add_sub_mul = add_sub_mul.set_attribute("Primitive", - tir.IntImm("int32", 1)) add_sub_mul = add_sub_mul.set_attribute("Composite", tir.StringImm("add_sub_mul")) @@ -242,8 +239,6 @@ def expected(): sub_node_1 = relay.subtract(in_3, in_4) mul_node_1 = relay.multiply(add_node_1, sub_node_1) add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1) - add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive", - tir.IntImm("int32", 1)) add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite", tir.StringImm("add_sub_mul")) @@ -304,8 +299,6 @@ def expected(): add_node_1 = relay.add(in_1, add_node) add_node_2 = relay.add(add_node_1, add_node) add_add_add = relay.Function([in_1, in_2], add_node_2) - add_add_add = add_add_add.set_attribute("Primitive", - tir.IntImm("int32", 1)) add_add_add = add_add_add.set_attribute("Composite", tir.StringImm("add_add_add")) @@ -390,7 +383,6 @@ def expected(): bias_node = relay.nn.bias_add(conv_node, in_3) r = relay.nn.relu(bias_node) conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) - conv_bias_add_relu = conv_bias_add_relu.set_attribute("Primitive", tir.IntImm("int32", 1)) conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite", tir.StringImm("conv2d_bias_relu")) @@ -400,7 +392,6 @@ def expected(): add_node = relay.add(in_4, in_5) r = relay.nn.relu(add_node) add_relu = relay.Function([in_4, in_5], r) - add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1)) add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu")) # merged function @@ -470,7 +461,6 @@ def after_A_priority(composite_name): out = relay.abs(out) out = relay.nn.relu(out) merged_func = relay.Function([x, y], out) - merged_func = merged_func.set_attribute('Primitive', tir.IntImm('int32', 1)) merged_func = merged_func.set_attribute('Composite', tir.StringImm(composite_name)) ret = relay.Call(merged_func, [input_1, input_2]) @@ -537,14 +527,12 @@ def after(): y = relay.var('y') branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) func_1 = relay.Function([x, y], branch_1) - func_1 = func_1.set_attribute('Primitive', tir.IntImm('int32', 1)) func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul")) call_1 = relay.Call(func_1, [input_1, input_2]) x1 = relay.var('x1') y1 = relay.var('y1') branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) func_2 = relay.Function([x1, y1], branch_2) - func_2 = func_2.set_attribute('Primitive', tir.IntImm('int32', 1)) func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul")) call_2 = relay.Call(func_2, [input_1, input_2]) out = relay.multiply(call_1, call_2) @@ -624,7 +612,6 @@ def after_A(): add_relu_1 = relay.add(x, y) add_relu_1 = relay.nn.relu(add_relu_1) add_relu_1 = relay.Function([x, y], add_relu_1) - add_relu_1 = add_relu_1.set_attribute('Primitive', tir.IntImm('int32', 1)) add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu')) add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) x1 = relay.var('x1') @@ -632,7 +619,6 @@ def after_A(): add_relu_2 = relay.add(x1, y1) add_relu_2 = relay.nn.relu(add_relu_2) add_relu_2 = relay.Function([x1, y1], add_relu_2) - add_relu_2 = add_relu_2.set_attribute('Primitive', tir.IntImm('int32', 1)) add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu')) add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) x2 = relay.var('x2') @@ -641,7 +627,6 @@ def after_A(): sub = relay.subtract(x2, y2) add_sub_mul = relay.multiply(add, sub) add_sub_mul = relay.Function([x2, y2], add_sub_mul) - add_sub_mul = add_sub_mul.set_attribute('Primitive', tir.IntImm('int32', 1)) add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul')) add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) return relay.Function(inputs, add_sub_mul_call) @@ -655,7 +640,6 @@ def after_B(): add_relu = relay.add(x, y) add_relu = relay.nn.relu(add_relu) add_relu = relay.Function([x, y], add_relu) - add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 1)) add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu')) add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]]) add_relu_calls.append(add_relu_call)