Skip to content

Commit

Permalink
[BYOC] support arbitrary input dims for add/mul/relu of dnnl c_src co…
Browse files Browse the repository at this point in the history
…degen (apache#9127)

* support arbitrary input dims for add/mul/relu of dnnl c_src codegen

* fix lint

* fix

Co-authored-by: sunway <wei.sun@hexintek.com>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent e369d0a commit ad24256
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 34 deletions.
41 changes: 27 additions & 14 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ inline size_t GetShape1DSize(const Type& type) {
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
}

inline std::string GetShapeString(std::vector<int> shape) {
std::string v = "std::vector<long int>{";
for (auto s : shape) {
v += std::to_string(s) + ",";
}
v += "}";
return v;
}

std::vector<std::string> Conv2d(const CallNode* call) {
std::vector<std::string> args;
const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();
Expand Down Expand Up @@ -98,12 +107,8 @@ std::vector<std::string> Dense(const CallNode* call) {
std::vector<std::string> Relu(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());

// Args: N, C, H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}

args.push_back(GetShapeString(ishape));
return args;
}

Expand All @@ -123,15 +128,25 @@ std::vector<std::string> BatchNorm(const CallNode* call) {
return args;
}

// should comply with src/runtime/contrib/dnnl/dnnl.cc
#define DNNL_BINARY_ADD 0
#define DNNL_BINARY_MUL 1

std::vector<std::string> Add(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());

args.push_back(std::to_string(DNNL_BINARY_ADD));
// Args: H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}
args.push_back(GetShapeString(ishape));
return args;
}

std::vector<std::string> Multiply(const CallNode* call) {
std::vector<std::string> args;
auto ishape = GetShape(call->args[0]->checked_type());
args.push_back(std::to_string(DNNL_BINARY_MUL));
// Args: H, W
args.push_back(GetShapeString(ishape));
return args;
}

Expand Down Expand Up @@ -239,11 +254,9 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C

using ArgFunType = std::function<std::vector<std::string>(const CallNode*)>;
static const std::map<std::string, std::pair<std::string, ArgFunType>> op_map = {
{"nn.conv2d", {"dnnl_conv2d", Conv2d}},
{"nn.dense", {"dnnl_dense", Dense}},
{"nn.relu", {"dnnl_relu", Relu}},
{"nn.batch_norm", {"dnnl_bn", BatchNorm}},
{"add", {"dnnl_add", Add}},
{"nn.conv2d", {"dnnl_conv2d", Conv2d}}, {"nn.dense", {"dnnl_dense", Dense}},
{"nn.relu", {"dnnl_relu", Relu}}, {"nn.batch_norm", {"dnnl_bn", BatchNorm}},
{"add", {"dnnl_binary_op", Add}}, {"multiply", {"dnnl_binary_op", Multiply}},
};

const auto op_name = GetRef<Op>(op_node)->name;
Expand Down
69 changes: 52 additions & 17 deletions src/runtime/contrib/dnnl/dnnl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,32 @@ typedef struct {
void** data;
} DnnlPackedArgs;

inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape,
memory::data_type dtype) {
using tag = memory::format_tag;

dnnl::memory::desc data_md;

switch (shape.size()) {
case 2:
data_md = dnnl::memory::desc({shape, dtype, tag::ab});
break;
case 3:
data_md = dnnl::memory::desc({shape, dtype, tag::abc});
break;
case 4:
data_md = dnnl::memory::desc({shape, dtype, tag::abcd});
break;
case 5:
data_md = dnnl::memory::desc({shape, dtype, tag::abcde});
break;
default:
LOG(FATAL) << "Unsupported data shape dimension: " << shape.size();
break;
}
return data_md;
}

// Read from memory, write to handle
inline void read_from_dnnl_memory(void* handle, const memory& mem) {
size_t bytes = mem.get_desc().get_size();
Expand Down Expand Up @@ -175,16 +201,13 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int
read_from_dnnl_memory(out, dst_memory);
}

extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_) {
using tag = memory::format_tag;
extern "C" void dnnl_relu(float* data, float* out, std::vector<int64_t> shape) {
using dt = memory::data_type;

engine eng(engine::kind::cpu, 0);
stream s(eng);

memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_};

auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw};
auto data_md = GenDNNLMemDescByShape(shape, dt::f32);

auto data_memory = memory(data_md, eng, data);
auto dst_memory = memory(data_md, eng);
Expand Down Expand Up @@ -241,27 +264,39 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo
free(weight);
}

extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, int p_C_, int p_H_,
int p_W_) {
using tag = memory::format_tag;
// should comply with src/relay/backend/contrib/dnnl/codegen.cc
#define DNNL_BINARY_ADD 0
#define DNNL_BINARY_MUL 1

extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_type,
std::vector<int64_t> shape) {
using dt = memory::data_type;

engine eng(engine::kind::cpu, 0);
stream s(eng);

memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_};

auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw};
auto weight_md = memory::desc({{data_tz}, dt::f32, tag::nchw});
auto dst_md = memory::desc({{data_tz}, dt::f32, tag::nchw});
auto data_md = GenDNNLMemDescByShape(shape, dt::f32);

auto data_memory = memory(data_md, eng, data);
auto weight_memory = memory(weight_md, eng, weight);
auto dst_memory = memory(dst_md, eng);
auto weight_memory = memory(data_md, eng, weight);
auto dst_memory = memory(data_md, eng);

auto add_desc = binary::desc(algorithm::binary_add, data_md, weight_md, dst_md);
algorithm algo = algorithm::undef;
switch (algo_type) {
case DNNL_BINARY_ADD:
algo = algorithm::binary_add;
break;
case DNNL_BINARY_MUL:
algo = algorithm::binary_mul;
break;
default:
LOG(FATAL) << "Unsupported dnnl algorithm: " << algo_type;
break;
}

auto add_desc = binary::desc(algo, data_md, data_md, data_md);
auto add_prim_desc = binary::primitive_desc(add_desc, eng);
assert(dst_md == add_prim_desc.dst_desc());
assert(data_md == add_prim_desc.dst_desc());

auto add = binary(add_prim_desc);
add.execute(
Expand Down
9 changes: 6 additions & 3 deletions src/runtime/contrib/dnnl/dnnl_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>

#include <vector>

#include "dnnl.hpp"

Expand Down Expand Up @@ -54,14 +57,14 @@ extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights,
extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_,
int p_O_);

extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_);
extern "C" TVM_DLL void dnnl_relu(float* data, float* out, std::vector<int64_t> shape);

extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean,
float* variance, float* out, float* new_mean, float* new_variance,
int p_n_, int p_c_, int p_h_, int p_w_, int p_e_);

extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_,
int p_h_, int p_w_);
extern "C" TVM_DLL void dnnl_binary_op(float* data, float* weight, float* out, int binary_algo,
std::vector<int64_t> shape);

} // namespace contrib
} // namespace runtime
Expand Down

0 comments on commit ad24256

Please sign in to comment.