Skip to content

Commit

Permalink
[RELAY] Remove primitive attribute from composite function (apache#5014)
Browse files Browse the repository at this point in the history
* A composite function should not be primitive since we still may need to perform passes on it.

Change-Id: If62d06d265234861a6ec0df7749dc1c339c1055c
  • Loading branch information
lhutton1 authored and zhiics committed Apr 17, 2020
1 parent 12f9a0a commit 1ae0d80
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 17 deletions.
1 change: 0 additions & 1 deletion src/relay/pass/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ class MergeCompositeWrapper : public ExprMutator {
// make the composite function
auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
f = FunctionSetAttr(f, attr::kPrimitive, tvm::Integer(1));
// find the expressions associated with the free vars using the args_map
// this tells us which expressions should be given as inputs to the composite function
Array<Expr> args;
Expand Down
16 changes: 0 additions & 16 deletions tests/python/relay/test_pass_merge_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ def expected():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))

# merged function
Expand Down Expand Up @@ -230,8 +229,6 @@ def expected():
sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node)
add_sub_mul = add_sub_mul.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_sub_mul = add_sub_mul.set_attribute("Composite",
tir.StringImm("add_sub_mul"))

Expand All @@ -242,8 +239,6 @@ def expected():
sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite",
tir.StringImm("add_sub_mul"))

Expand Down Expand Up @@ -304,8 +299,6 @@ def expected():
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
add_add_add = add_add_add.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_add_add = add_add_add.set_attribute("Composite",
tir.StringImm("add_add_add"))

Expand Down Expand Up @@ -390,7 +383,6 @@ def expected():
bias_node = relay.nn.bias_add(conv_node, in_3)
r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite",
tir.StringImm("conv2d_bias_relu"))

Expand All @@ -400,7 +392,6 @@ def expected():
add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r)
add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))

# merged function
Expand Down Expand Up @@ -470,7 +461,6 @@ def after_A_priority(composite_name):
out = relay.abs(out)
out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out)
merged_func = merged_func.set_attribute('Primitive', tir.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite',
tir.StringImm(composite_name))
ret = relay.Call(merged_func, [input_1, input_2])
Expand Down Expand Up @@ -537,14 +527,12 @@ def after():
y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1)
func_1 = func_1.set_attribute('Primitive', tir.IntImm('int32', 1))
func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul"))
call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1')
y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2)
func_2 = func_2.set_attribute('Primitive', tir.IntImm('int32', 1))
func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul"))
call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2)
Expand Down Expand Up @@ -624,15 +612,13 @@ def after_A():
add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1)
add_relu_1 = add_relu_1.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1')
y1 = relay.var('y1')
add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2)
add_relu_2 = add_relu_2.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2')
Expand All @@ -641,7 +627,6 @@ def after_A():
sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul)
add_sub_mul = add_sub_mul.set_attribute('Primitive', tir.IntImm('int32', 1))
add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul'))
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call)
Expand All @@ -655,7 +640,6 @@ def after_B():
add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu)
add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call)
Expand Down

0 comments on commit 1ae0d80

Please sign in to comment.