-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BYOC-DNNL] add post_sum pattern #12151
Conversation
e8d8a67
to
b9ed19d
Compare
fe68b77
to
63a44ba
Compare
tests/python/contrib/test_dnnl.py
Outdated
@@ -192,6 +192,7 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te | |||
if use_dnnl: | |||
processed_mod = partition_for_dnnl(processed_mod, params, alter_layout) | |||
check_dnnl_used(processed_mod) | |||
print(processed_mod) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
63a44ba
to
d89db44
Compare
python/tvm/relay/op/contrib/dnnl.py
Outdated
dnnl_patterns.append( | ||
("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker)) | ||
), | ||
dnnl_patterns.append( | ||
("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker)) | ||
), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dnnl_patterns.append( | |
("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker)) | |
), | |
dnnl_patterns.append( | |
("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker)) | |
), | |
dnnl_patterns.append( | |
( | |
"dnnl.conv2d_bias_sum_relu", | |
make_conv_bias_sum_relu_pattern("nn.conv2d"), | |
make_predicate(add_checker), | |
) | |
) | |
dnnl_patterns.append( | |
( | |
"dnnl.conv2d_bias_sum", | |
make_conv_bias_sum_relu_pattern("nn.conv2d", False), | |
make_predicate(add_checker), | |
) | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
@@ -361,6 +367,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase { | |||
|
|||
// TODO(@apeskov): Simulation of inplace primitive. just as PoC. | |||
auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout); | |||
if (op_name.find("_sum") != std::string::npos) { | |||
sum_in_tr = GetInput(nid, node.GetInputs().size()-1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sum_in_tr = GetInput(nid, node.GetInputs().size()-1); | |
sum_in_tr = GetInput(nid, node.GetInputs().size() - 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
tests/python/contrib/test_dnnl.py
Outdated
param_lst += ["data1"] | ||
return relay.nn.relu(out), dic, param_lst | ||
|
||
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype) | |
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu( | |
x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
tests/python/contrib/test_dnnl.py
Outdated
config = conv2d_bn_sum_relu, dic, param_lst | ||
run_and_verify_func(config, run_module=run_module, dtype=dtype) | ||
|
||
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype) | |
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu( | |
x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
tests/python/contrib/test_dnnl.py
Outdated
# tvm.testing.main() | ||
test_conv2d_bias_sum_relu(True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# tvm.testing.main() | |
test_conv2d_bias_sum_relu(True) | |
tvm.testing.main() | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
654bab6
to
b2dca37
Compare
b2dca37
to
001898f
Compare
@masahi Could you please review this PR? This PR adds |
* add post_sum pattern * add checkers for sum pattern * fix lint * fix error in test_pass_partition_graph * fix lint error
This PR add
conv2d-add-sum-relu
pattern, and the corresponding test case is added.