[BYOC-DNNL] add post_sum pattern (#12151)
* add post_sum pattern

* add checkers for sum pattern

* fix lint

* fix error in test_pass_partition_graph

* fix lint error
crazydemo authored Aug 1, 2022
1 parent a842449 commit c07d77f
Showing 4 changed files with 164 additions and 4 deletions.
106 changes: 105 additions & 1 deletion python/tvm/relay/op/contrib/dnnl.py
@@ -33,8 +33,10 @@
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
import logging
from functools import reduce

import tvm.ir
from tvm.ir import Op
from tvm import relay
from tvm.relay import transform
from tvm.relay.expr import GlobalVar
@@ -44,7 +46,7 @@
from tvm.relay.analysis import analysis as _analysis
from tvm.relay import expr as _expr


from tvm.relay.expr import Call, TupleGetItem
from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table
@@ -167,6 +169,94 @@ def make_conv_bias_sum_relu_pattern(conv_type, has_relu=True):
    return append_eltwise_ops(conv_out, with_eltwise)


def make_conv_bias_sum_relu_pattern(conv_type, has_relu=True):
    """Create patterns with sum op.

    Parameters
    ----------
    conv_type : str
        Should be nn.conv1d / nn.conv2d / nn.conv3d.
    has_relu : bool
        Whether to attach relu.

    Returns
    -------
    out : CallPattern
        Call node sequence.
    """
    data1 = wildcard()
    weight = wildcard()
    bias = wildcard()
    data2 = wildcard()
    out = is_op(conv_type)(data1, weight)
    out = is_op("add")(out, bias)
    out = is_op("add")(out, data2)
    if has_relu:
        out = is_op("nn.relu")(out)
    return out
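
For orientation, a minimal Relay snippet (hypothetical shapes, not part of the patch) that the conv2d variant of this pattern is meant to match looks roughly like the following: conv2d, a bias written as a plain add, a second add consuming the extra input, and an optional relu.

import numpy as np
from tvm import relay

data = relay.var("data", shape=(1, 32, 8, 8))
weight = relay.const(np.random.rand(16, 32, 3, 3).astype("float32"))
bias = relay.const(np.random.rand(16, 1, 1).astype("float32"))
sum_in = relay.var("sum_in", shape=(1, 16, 6, 6))  # same shape as the conv output

out = relay.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3))
out = relay.add(out, bias)     # bias expressed as a plain add, as the pattern expects
out = relay.add(out, sum_in)   # the post-op sum operand
out = relay.nn.relu(out)       # dropped when has_relu=False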


def get_op_name(expr):
    """Get the operator name from an expression."""
    if isinstance(expr, Op):
        return expr.name
    if isinstance(expr, Call):
        return get_op_name(expr.op)
    if isinstance(expr, TupleGetItem):
        return get_op_name(expr.tuple_value)
    if isinstance(expr, relay.Tuple):
        return get_op_name(expr.fields[0])
    return ""


def get_args(expr):
    """Get the arguments from an expression."""
    if isinstance(expr, Call):
        return expr.args
    if isinstance(expr, TupleGetItem):
        return get_args(expr.tuple_value)
    if isinstance(expr, relay.Tuple):
        return [arg for args in map(get_args, expr.fields) for arg in args]
    return []


def get_attrs(expr):
    """Get the attributes from an expression."""
    if isinstance(expr, Call):
        return expr.attrs
    if isinstance(expr, TupleGetItem):
        return get_attrs(expr.tuple_value)
    return {}


def make_predicate(checker):
    """Check whether the conv_bias_add_sum pattern is as expected."""

    def predicate(expr):
        if get_op_name(expr) == "nn.relu":
            expr = expr.args[0]
        for e, op_name in zip([expr, expr.args[0]], ["sum", "bias_add"]):
            args = get_args(e)
            attrs = get_attrs(e.args[0])
            if not checker(attrs, args, op_name):
                return False
        return True

    return predicate


def add_checker(attrs, args, op_name):
    """Check if add is supported by DNNL."""
    if op_name == "sum":
        if tuple(get_shape(args[0])) != tuple(get_shape(args[1])):
            return False
    if op_name == "bias_add":
        channel = dict(attrs)["channels"]
        const_shape = get_shape(args[1])
        if channel != reduce(lambda x, y: x * y, const_shape):
            return False
    return True
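
To make the intent of these checks concrete, here is a small standalone sketch (plain shape tuples stand in for what get_shape would return; the helper names are illustrative only): the sum operand has to match the conv output shape exactly, and the bias constant has to carry exactly `channels` elements.

from functools import reduce

def sum_shapes_ok(lhs_shape, rhs_shape):
    # "sum" post-op: both addends must have identical shapes (no broadcasting).
    return tuple(lhs_shape) == tuple(rhs_shape)

def bias_shape_ok(channels, const_shape):
    # "bias_add": the constant must have exactly `channels` elements in total.
    return channels == reduce(lambda x, y: x * y, const_shape)

assert sum_shapes_ok((1, 16, 6, 6), (1, 16, 6, 6))       # eligible for a DNNL sum post-op
assert not sum_shapes_ok((1, 16, 6, 6), (1, 16, 1, 1))   # broadcast add, rejected
assert bias_shape_ok(16, (16, 1, 1))                     # valid per-channel bias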


def make_dense_pattern(with_bias=True, with_eltwise=None):
"""Create patterns related to nn.dense.
Expand Down Expand Up @@ -306,6 +396,20 @@ def pattern_table():
    dnnl_patterns = list()
    dnnl_patterns.append(make_qnn_conv2d_pattern())
    dnnl_patterns.append(make_qnn_dense_pattern())
    dnnl_patterns.append(
        (
            "dnnl.conv2d_bias_sum_relu",
            make_conv_bias_sum_relu_pattern("nn.conv2d"),
            make_predicate(add_checker),
        )
    )
    dnnl_patterns.append(
        (
            "dnnl.conv2d_bias_sum",
            make_conv_bias_sum_relu_pattern("nn.conv2d", False),
            make_predicate(add_checker),
        )
    )

    elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
    for with_bias in [True, False]:
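
A rough sketch of how these (name, pattern, predicate) entries are consumed downstream (simplified; in practice partition_for_dnnl in this module assembles the pass pipeline, and `mod` here is assumed to be an IRModule containing a conv2d + bias + sum + relu chain):

import tvm
from tvm.relay import transform
from tvm.relay.op.contrib import get_pattern_table

seq = tvm.transform.Sequential(
    [
        transform.MergeComposite(get_pattern_table("dnnl")),
        transform.AnnotateTarget("dnnl"),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ]
)
partitioned = seq(mod)  # matched chains become "dnnl.conv2d_bias_sum_relu" composite functions
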
14 changes: 12 additions & 2 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -182,8 +182,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

    if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;

    // parsing of name to extract attributes
    auto op_name = nodes_[nid].GetOpName();
    // Define RegExp.
    std::regex bias_add_pat(".*_bias.*");
    std::regex relu_pat(".*_relu.*");
@@ -192,9 +190,16 @@
    std::regex clip_pat(".*_clip.*");
    std::regex gelu_pat(".*_gelu.*");
    std::regex swish_pat(".*_swish.*");
    std::regex sum_pat(".*_sum.*");

    // parsing of name to extract attributes
    auto op_name = nodes_[nid].GetOpName();

    // Parsing post-ops.
    dnnl::post_ops ops;
    if (std::regex_match(op_name, sum_pat)) {
      ops.append_sum(1.f);
    }
    if (std::regex_match(op_name, relu_pat)) {
      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
    }
@@ -280,6 +285,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

  void Convolution(const size_t& nid) {
    auto node = nodes_[nid];
    auto op_name = nodes_[nid].GetOpName();

    // Setup attributes.
    auto src_tr = GetInput(nid, 0);
@@ -361,6 +367,10 @@

    // TODO(@apeskov): Simulation of inplace primitive. just as PoC.
    auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout);
    if (op_name.find("_sum") != std::string::npos) {
      sum_in_tr = GetInput(nid, node.GetInputs().size() - 1);
      sum_in_tr = sum_in_tr.TreatAs(dst_layout);
    }

    Submit(dnnl::convolution_forward(conv_prim_desc),
           {{DNNL_ARG_SRC, src_tr},
42 changes: 42 additions & 0 deletions tests/python/contrib/test_dnnl.py
@@ -788,6 +788,42 @@ def test_conv2d_pattern(run_module, dtype="float32"):
    run_and_verify_func(config, run_module=run_module, dtype=dtype)


def test_conv2d_bias_sum_relu(run_module, dtype="float32"):
    x_shape = (1, 32, 8, 8)
    k_shape = (16, 32, 3, 3)

    def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
        out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
        beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
        gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
        moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype))
        moving_var = relay.const(np.ones(k_shape[0]).astype(dtype))
        out, _, _ = relay.nn.batch_norm(
            out,
            gamma=gamma,
            beta=beta,
            moving_mean=moving_mean,
            moving_var=moving_var,
            axis=1,
            center=True,
            scale=True,
            epsilon=1e-5,
        )
        sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
        out = relay.add(out, sum_data)
        dic["data1"] = sum_shape
        param_lst += ["data1"]
        return relay.nn.relu(out), dic, param_lst

    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
        x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
    )
    conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
    config = conv2d_bn_sum_relu, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)

    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
        x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
    )
    conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
    config = conv2d_bn_sum_relu, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)
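
To exercise just this case locally, something along these lines should work, assuming TVM was built with the DNNL (oneDNN) codegen enabled:

pytest tests/python/contrib/test_dnnl.py::test_conv2d_bias_sum_relu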


def test_conv2d_transpose(run_module, dtype="float32"):
    x_shape = (1, 32, 8, 8)
    for k_shape, groups in [((32, 16, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 4, 3, 3), 16)]:
6 changes: 5 additions & 1 deletion tests/python/relay/test_pass_partition_graph.py
@@ -919,7 +919,11 @@ def expected():

def test_dnnl_fuse():
    dnnl_patterns = get_pattern_table("dnnl")
    # Entries registered with a predicate are (name, pattern, predicate) 3-tuples;
    # keep only the plain (name, pattern) pairs so they can form a dict below.
    valid_pats = list()
    for pattern in dnnl_patterns:
        if len(pattern) == 2:
            valid_pats.append(pattern)
    dnnl_pat_dic = dict(valid_pats)
    (
        conv2d_bias_relu_pat,
        conv2d_bias_sigmoid_pat,
