Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] Remove primitive attribute from composite function #5014

Merged
merged 1 commit into from
Mar 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/relay/pass/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ class MergeCompositeWrapper : public ExprMutator {
// make the composite function
auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1));
// find the expressions associated with the free vars using the args_map
// this tells us which expressions should be given as inputs to the composite function
Array<Expr> args;
Expand Down
16 changes: 0 additions & 16 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ def expected():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))

# merged function
Expand Down Expand Up @@ -230,8 +229,6 @@ def expected():
sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node)
add_sub_mul = add_sub_mul.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_sub_mul = add_sub_mul.set_attribute("Composite",
tir.StringImm("add_sub_mul"))

Expand All @@ -242,8 +239,6 @@ def expected():
sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite",
tir.StringImm("add_sub_mul"))

Expand Down Expand Up @@ -304,8 +299,6 @@ def expected():
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
add_add_add = add_add_add.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_add_add = add_add_add.set_attribute("Composite",
tir.StringImm("add_add_add"))

Expand Down Expand Up @@ -390,7 +383,6 @@ def expected():
bias_node = relay.nn.bias_add(conv_node, in_3)
r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite",
tir.StringImm("conv2d_bias_relu"))

Expand All @@ -400,7 +392,6 @@ def expected():
add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r)
add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))

# merged function
Expand Down Expand Up @@ -470,7 +461,6 @@ def after_A_priority(composite_name):
out = relay.abs(out)
out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out)
merged_func = merged_func.set_attribute('Primitive', tir.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite',
tir.StringImm(composite_name))
ret = relay.Call(merged_func, [input_1, input_2])
Expand Down Expand Up @@ -537,14 +527,12 @@ def after():
y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1)
func_1 = func_1.set_attribute('Primitive', tir.IntImm('int32', 1))
func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul"))
call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1')
y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2)
func_2 = func_2.set_attribute('Primitive', tir.IntImm('int32', 1))
func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul"))
call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2)
Expand Down Expand Up @@ -624,15 +612,13 @@ def after_A():
add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1)
add_relu_1 = add_relu_1.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1')
y1 = relay.var('y1')
add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2)
add_relu_2 = add_relu_2.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2')
Expand All @@ -641,7 +627,6 @@ def after_A():
sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul)
add_sub_mul = add_sub_mul.set_attribute('Primitive', tir.IntImm('int32', 1))
add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul'))
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call)
Expand All @@ -655,7 +640,6 @@ def after_B():
add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu)
add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call)
Expand Down