Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC-DNNL] add post_sum pattern #12151

Merged
merged 5 commits into from
Aug 1, 2022
Merged

Conversation

crazydemo
Copy link
Contributor

This PR add conv2d-add-sum-relu pattern, and the corresponding test case is added.

@crazydemo crazydemo force-pushed the upstream-sum_pattern branch 3 times, most recently from fe68b77 to 63a44ba Compare July 26, 2022 02:37
@@ -192,6 +192,7 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te
if use_dnnl:
processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
check_dnnl_used(processed_mod)
print(processed_mod)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Comment on lines 398 to 403
dnnl_patterns.append(
("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker))
),
dnnl_patterns.append(
("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker))
),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dnnl_patterns.append(
("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker))
),
dnnl_patterns.append(
("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker))
),
dnnl_patterns.append(
(
"dnnl.conv2d_bias_sum_relu",
make_conv_bias_sum_relu_pattern("nn.conv2d"),
make_predicate(add_checker),
)
)
dnnl_patterns.append(
(
"dnnl.conv2d_bias_sum",
make_conv_bias_sum_relu_pattern("nn.conv2d", False),
make_predicate(add_checker),
)
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@@ -361,6 +367,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

// TODO(@apeskov): Simulation of inplace primitive. just as PoC.
auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout);
if (op_name.find("_sum") != std::string::npos) {
sum_in_tr = GetInput(nid, node.GetInputs().size()-1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sum_in_tr = GetInput(nid, node.GetInputs().size()-1);
sum_in_tr = GetInput(nid, node.GetInputs().size() - 1);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

param_lst += ["data1"]
return relay.nn.relu(out), dic, param_lst

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype)
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

config = conv2d_bn_sum_relu, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype)
conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Comment on lines 1774 to 1775
# tvm.testing.main()
test_conv2d_bias_sum_relu(True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# tvm.testing.main()
test_conv2d_bias_sum_relu(True)
tvm.testing.main()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@crazydemo
Copy link
Contributor Author

@masahi Could you please review this PR? This PR adds conv2d-add-sum-relu pattern with required checks, and the corresponding test case is added.

@masahi masahi merged commit c07d77f into apache:main Aug 1, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
* add post_sum pattern

* add checkers for sum pattern

* fix lint

* fix error in test_pass_partition_graph

* fix lint error
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants