DNNL-BYOC enhancement (#9797)
* add unit test for byoc-dnnl

* add byoc-dnnl patterns and their test cases
crazydemo authored Dec 29, 2021
1 parent d56ca35 commit 75cd670
Showing 5 changed files with 622 additions and 52 deletions.
140 changes: 135 additions & 5 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -32,10 +32,17 @@
- The other way is to implement the function by themselves to
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
import logging

import tvm.ir
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import wildcard, is_op
from .register import register_pattern_table

logger = logging.getLogger("DNNL")


def _register_external_op_helper(op_name, supported=True):
"""The helper function to indicate that a given operator can be supported
@@ -63,11 +70,26 @@ def _func_wrapper(expr):
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("tanh")
_register_external_op_helper("sigmoid")
_register_external_op_helper("add")
_register_external_op_helper("multiply")


def make_pattern(with_bias=True):
def make_conv_pattern(with_bias=True, with_eltwise=None):
"""Create patterns related to nn.conv2d.
Parameters
----------
with_bias : bool
Whether to attach `bias_add` to `nn.conv2d`.
with_eltwise : str
The attached elementwise post-op name.
Returns
-------
conv_out : CallPattern
Call node sequence.
"""
data = wildcard()
weight = wildcard()
bias = wildcard()
@@ -76,12 +98,120 @@ def make_pattern(with_bias=True):
conv_out = is_op("add")(conv, bias)
else:
conv_out = conv
return is_op("nn.relu")(conv_out)
if with_eltwise:
return is_op(with_eltwise)(conv_out)
return conv_out
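
As a quick sanity check (not part of the diff), the conv pattern should match a conv2d -> add -> relu chain built directly in Relay. A rough sketch, with arbitrary illustrative shapes:

from tvm import relay

data = relay.var("data", shape=(1, 32, 14, 14))
weight = relay.var("weight", shape=(16, 32, 3, 3))
bias = relay.var("bias", shape=(16, 1, 1))
out = relay.nn.relu(relay.add(relay.nn.conv2d(data, weight, kernel_size=(3, 3)), bias))

# A bias plus relu post-op pattern should match this expression.
pat = make_conv_pattern(with_bias=True, with_eltwise="nn.relu")
assert pat.match(out)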


def make_dense_pattern(with_bias=True, with_eltwise=None):
"""Create patterns related to nn.dense.
Parameters
----------
with_bias : bool
Whether to attach `bias_add` to `nn.dense`.
with_eltwise : str
The attached elementwise post-op name.
Returns
-------
dense_out : CallPattern
Call node sequence.
"""
data = wildcard()
weight = wildcard()
bias = wildcard()
dense = is_op("nn.dense")(data, weight)
if with_bias:
dense_out = is_op("add")(dense, bias)
else:
dense_out = dense
if with_eltwise:
dense_out = is_op(with_eltwise)(dense_out)
return dense_out


def make_dnnl_pattern(op, with_bias, with_eltwise):
"""Create dnnl patterns.
Parameters
----------
op : str
The op name of the pattern's root call node, either "conv2d" or "dense".
with_bias : bool
Whether to attach `bias_add` to the root op (`nn.conv2d` or `nn.dense`).
with_eltwise : str
The attached elementwise post-op name.
Returns
-------
pattern : Tuple(pattern_name, CallPattern)
Created pattern name, along with its CallPattern.
"""
pat_name = "dnnl." + op
pat_name += "_bias" if with_bias else ""
pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
if op == "conv2d":
dnnl_pattern = (pat_name, make_conv_pattern(with_bias, with_eltwise))
elif op == "dense":
dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
else:
logger.warning("Currently, only conv2d and dense op are supported, but got %s.", op)
dnnl_pattern = ()
return dnnl_pattern
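
The pattern name built here is what the JSON codegen later dispatches on, so the naming convention matters. Two illustrative checks (not part of the commit):

assert make_dnnl_pattern("conv2d", with_bias=True, with_eltwise="sigmoid")[0] == "dnnl.conv2d_bias_sigmoid"
assert make_dnnl_pattern("dense", with_bias=False, with_eltwise="nn.relu")[0] == "dnnl.dense_relu"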


@register_pattern_table("dnnl")
def pattern_table():
conv2d_bias_relu_pat = ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True))
conv2d_relu_pat = ("dnnl.conv2d_relu", make_pattern(with_bias=False))
dnnl_patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
"""Create dnnl patterns.
Returns
-------
dnnl_patterns : List[dnnl_pattern]
Created patterns.
"""
elt_list = ["nn.relu", "tanh", "sigmoid", None]
dnnl_patterns = []
for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
return dnnl_patterns
dnnl_patterns.append(make_dnnl_pattern("conv2d", with_bias, elt))
dnnl_patterns.append(make_dnnl_pattern("dense", with_bias, elt))
return dnnl_patterns
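
Because of the early return above, the bias-free, post-op-free combination is skipped; plain conv2d and dense are already registered individually. One way to inspect what ends up in the table, sketched with the standard helper from tvm.relay.op.contrib:

from tvm.relay.op.contrib import get_pattern_table

# Expected entries include dnnl.conv2d_bias_relu, dnnl.conv2d_bias_tanh,
# dnnl.conv2d_bias_sigmoid, dnnl.conv2d_bias, dnnl.conv2d_relu, ... and the
# matching dnnl.dense_* variants.
for entry in get_pattern_table("dnnl"):
    print(entry[0])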


def partition_for_dnnl(mod, params=None):
"""Partition the graph greedily offloading supported operators to DNNL.
Parameters
----------
mod : Module
The module to run passes on.
params : Optional[Dict[str, NDArray]]
Constant input parameters.
Returns
-------
mod : Module
Annotated and partitioned module.
"""

if params:
mod["main"] = bind_params_by_name(mod["main"], params)
seq = tvm.transform.Sequential(
[
transform.CanonicalizeOps(),
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
# fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
transform.SimplifyExpr(),
transform.FoldConstant(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("dnnl"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
]
)
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
return mod
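
A minimal end-to-end sketch of how this entry point is meant to be used; the shapes and names are illustrative, and compiling or running the partitioned module additionally requires a TVM build with the DNNL codegen enabled:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl

data = relay.var("data", shape=(1, 32, 14, 14), dtype="float32")
weight = relay.var("weight", shape=(16, 32, 3, 3), dtype="float32")
bias = relay.var("bias", shape=(16, 1, 1), dtype="float32")
out = relay.nn.relu(relay.add(relay.nn.conv2d(data, weight, kernel_size=(3, 3)), bias))
mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))

params = {
    "weight": np.random.rand(16, 32, 3, 3).astype("float32"),
    "bias": np.random.rand(16, 1, 1).astype("float32"),
}

# conv2d + add + relu should be fused into a dnnl.conv2d_bias_relu composite
# and offloaded to the "dnnl" external codegen.
mod = partition_for_dnnl(mod, params)
print(mod)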
18 changes: 18 additions & 0 deletions src/relay/backend/contrib/dnnl/codegen.cc
@@ -455,9 +455,27 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {

if (name == "dnnl.conv2d_bias_relu") {
call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "nn.relu"});
} else if (name == "dnnl.conv2d_bias_tanh") {
call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "tanh"});
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name == "dnnl.conv2d_bias_sigmoid") {
call = GetRootCall(fn->body.as<CallNode>(), 2, {"nn.conv2d", "add", "sigmoid"});
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name == "dnnl.conv2d_bias") {
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "add"});
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name == "dnnl.conv2d_relu") {
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "nn.relu"});
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name == "dnnl.conv2d_tanh") {
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "tanh"});
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name == "dnnl.conv2d_sigmoid") {
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.conv2d", "sigmoid"});
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name == "dnnl.dense_bias") {
call = GetRootCall(fn->body.as<CallNode>(), 1, {"nn.dense", "add"});
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else {
LOG(FATAL) << "Unrecognized DNNL pattern: " << name;
}
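
Each branch above keys on the Composite attribute that MergeComposite stamps onto the fused functions, so these strings must stay in sync with the Python pattern names. A quick, illustrative way to see which composites appear in a partitioned module (assuming mod was produced by partition_for_dnnl as in the Python sketch earlier):

text = mod.astext(show_meta_data=False)
for name in ["dnnl.conv2d_bias_relu", "dnnl.conv2d_bias_tanh", "dnnl.conv2d_bias_sigmoid",
             "dnnl.conv2d_bias", "dnnl.conv2d_relu", "dnnl.conv2d_tanh",
             "dnnl.conv2d_sigmoid", "dnnl.dense_bias"]:
    if name in text:
        print("found composite:", name)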
99 changes: 66 additions & 33 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -103,15 +103,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if ("nn.conv2d" == op_name) {
Conv2d(nid);
} else if ("dnnl.conv2d_relu" == op_name) {
Conv2d(nid, true, false);
Conv2d(nid, true, false, dnnl::algorithm::eltwise_relu);
} else if ("dnnl.conv2d_tanh" == op_name) {
Conv2d(nid, true, false, dnnl::algorithm::eltwise_tanh);
} else if ("dnnl.conv2d_sigmoid" == op_name) {
Conv2d(nid, true, false, dnnl::algorithm::eltwise_logistic);
} else if ("dnnl.conv2d_bias" == op_name) {
Conv2d(nid, false, true);
} else if ("dnnl.conv2d_bias_relu" == op_name) {
Conv2d(nid, true, true);
Conv2d(nid, true, true, dnnl::algorithm::eltwise_relu);
} else if ("dnnl.conv2d_bias_tanh" == op_name) {
Conv2d(nid, true, true, dnnl::algorithm::eltwise_tanh);
} else if ("dnnl.conv2d_bias_sigmoid" == op_name) {
Conv2d(nid, true, true, dnnl::algorithm::eltwise_logistic);
} else if ("nn.dense" == op_name) {
Dense(nid);
} else if ("dnnl.dense_bias" == op_name) {
Dense(nid, true);
} else if ("nn.batch_norm" == op_name) {
BatchNorm(nid);
} else if ("nn.relu" == op_name) {
Relu(nid);
Eltwise(nid, dnnl::algorithm::eltwise_relu);
} else if ("tanh" == op_name) {
Eltwise(nid, dnnl::algorithm::eltwise_tanh);
} else if ("sigmoid" == op_name) {
Eltwise(nid, dnnl::algorithm::eltwise_logistic);
} else if ("add" == op_name) {
Binary(nid, dnnl::algorithm::binary_add);
} else if ("multiply" == op_name) {
@@ -150,7 +166,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
return entry_out_mem_[eid].first;
}

void Conv2d(const size_t& nid, const bool has_relu = false, const bool has_bias = false) {
void Conv2d(const size_t& nid, const bool has_elt = false, const bool has_bias = false,
dnnl::algorithm algo = dnnl::algorithm::eltwise_relu) {
auto node = nodes_[nid];

// Setup attributes.
@@ -159,24 +176,29 @@
dnnl::memory::dims input_shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
dnnl::memory::dims weight_shape = nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
std::vector<std::string> str_strides = node.GetAttr<std::vector<std::string>>("strides");
std::vector<std::string> str_dilates = node.GetAttr<std::vector<std::string>>("dilation");
std::vector<std::string> str_padding = node.GetAttr<std::vector<std::string>>("padding");
dnnl::memory::dim groups = std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);

dnnl::memory::dim N = input_shape[0], // batch size
IC = input_shape[1], // input channels
IH = input_shape[2], // input height
IW = input_shape[3], // input width
OC = weight_shape[0], // output channels
KH = weight_shape[2], // weight height
KW = weight_shape[3], // weight width
PW_L = std::stoi(str_padding[1]), // width padding: left
PW_R = std::stoi(str_padding[3]), // width padding: right
PH_L = std::stoi(str_padding[0]), // height padding: top
PH_R = std::stoi(str_padding[2]), // height padding: bottom
SH = std::stoi(str_strides[0]), // height-wise stride
SW = std::stoi(str_strides[1]), // weight-wise stride
OH = (IH - KH + PH_L + PH_R) / SH + 1, // output height
OW = (IW - KW + PW_L + PW_R) / SW + 1; // output width
dnnl::memory::dim N = input_shape[0], // batch size
IC = input_shape[1], // input channels
IH = input_shape[2], // input height
IW = input_shape[3], // input width
OC = weight_shape[0], // output channels
KH = weight_shape[2], // weight height
KW = weight_shape[3], // weight width
PW_L = std::stoi(str_padding[1]), // width padding: left
PW_R = std::stoi(str_padding[3]), // width padding: right
PH_L = std::stoi(str_padding[0]), // height padding: top
PH_R = std::stoi(str_padding[2]), // height padding: bottom
SH = std::stoi(str_strides[0]), // height-wise stride
SW = std::stoi(str_strides[1]), // width-wise stride
DH = std::stoi(str_dilates[0]) - 1, // height-wise dilation
DW = std::stoi(str_dilates[1]) - 1, // width-wise dilation
DKH = 1 + (KH - 1) * (DH + 1), // dilated weight height
DKW = 1 + (KW - 1) * (DW + 1), // dilated weight width
OH = (IH - DKH + PH_L + PH_R) / SH + 1, // output height
OW = (IW - DKW + PW_L + PW_R) / SW + 1; // output width

// Memory shapes.
dnnl::memory::dims src_dims = {N, IC, IH, IW};
@@ -187,6 +209,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
dnnl::memory::dims bias_dims = {OC};
dnnl::memory::dims dst_dims = {N, OC, OH, OW};
dnnl::memory::dims strides_dims = {SH, SW};
dnnl::memory::dims dilates_dims = {DH, DW};
dnnl::memory::dims padding_dims_l = {PH_L, PW_L};
dnnl::memory::dims padding_dims_r = {PH_R, PW_R};

@@ -199,13 +222,14 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Conv2d description.
auto conv_desc = dnnl::convolution_forward::desc(
dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md,
conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, padding_dims_l, padding_dims_r);
conv_weights_md, conv_bias_md, conv_dst_md, strides_dims, dilates_dims, padding_dims_l,
padding_dims_r);

// Enable ReLU
// Enable elementwise post-ops
dnnl::primitive_attr attr;
if (has_relu) {
if (has_elt) {
dnnl::post_ops ops;
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
ops.append_eltwise(1.f, algo, 0.f, 0.f);
attr.set_post_ops(ops);
}

@@ -245,7 +269,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{DNNL_ARG_DST, conv2d_dst_memory}});
}
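
For reference, the dilation handling above stores dilation - 1 in DH/DW (the convention oneDNN expects for its dilates argument), so the effective kernel extent is 1 + (K - 1) * dilation and the output extent is (I - DK + pad_l + pad_r) / stride + 1. A small numeric check in Python, with illustrative values:

IH, KH, pad_t, pad_b, stride, dilation = 14, 3, 1, 1, 1, 2

DKH = 1 + (KH - 1) * dilation              # dilated kernel height
OH = (IH - DKH + pad_t + pad_b) // stride + 1

print(DKH, OH)  # 5, 12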

void Dense(const size_t& nid) {
void Dense(const size_t& nid, const bool has_bias = false) {
auto node = nodes_[nid];

// Setup attributes.
@@ -281,9 +305,18 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
// Memories.
auto data_memory = BindDNNLMemory(data_entry, data_md);
auto weight_memory = BindDNNLMemory(weight_entry, weight_md);

// Bias memory.
auto bias_memory = dnnl::memory(bias_md, engine_);
float bias[OC] = {0};
write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
if (has_bias) {
auto bias_entry = node.GetInputs()[2];
BindDNNLMemory(bias_entry, bias_memory);
} else {
float bias[OC] = {0};
write_to_dnnl_memory(bias, bias_memory, OC * sizeof(float));
}

// Output memory.
JSONGraphNodeEntry out_entry(nid, 0);
auto dst_memory = BindDNNLMemory(out_entry, dense_prim_desc.dst_desc());

@@ -335,20 +368,20 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{DNNL_ARG_VARIANCE, variance_memory}});
}

void Relu(const size_t& nid) {
void Eltwise(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];

auto data_entry = node.GetInputs()[0];
dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);

auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference,
dnnl::algorithm::eltwise_relu, data_md, 0);
auto relu_prim_desc = dnnl::eltwise_forward::primitive_desc(relu_desc, engine_);
ICHECK(data_md == relu_prim_desc.dst_desc());
auto elt_desc =
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, data_md, 0);
auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_);
ICHECK(data_md == elt_prim_desc.dst_desc());

auto relu = dnnl::eltwise_forward(relu_prim_desc);
net_.push_back(relu);
auto elt = dnnl::eltwise_forward(elt_prim_desc);
net_.push_back(elt);

auto data_memory = BindDNNLMemory(data_entry, data_md);
JSONGraphNodeEntry out_entry(nid, 0);