Skip to content

Commit

Permalink
relu of dnnl json runtime only support 4-dims input (#9122)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunwayforever authored Sep 26, 2021
1 parent 80de123 commit f573007
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
5 changes: 2 additions & 3 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

auto data_entry = node.GetInputs()[0];
dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
auto data_md = dnnl::memory::desc{{shape}, dt::f32, tag::abcd};
dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32);

auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference,
dnnl::algorithm::eltwise_relu, data_md, 0);
Expand All @@ -349,9 +349,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
net_.push_back(relu);

auto data_memory = BindDNNLMemory(data_entry, data_md);
auto out_md = dnnl::memory::desc(shape, dt::f32, tag::abcd);
JSONGraphNodeEntry out_entry(nid, 0);
auto out_memory = BindDNNLMemory(out_entry, out_md);
auto out_memory = BindDNNLMemory(out_entry, data_md);

net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}});
}
Expand Down
28 changes: 16 additions & 12 deletions tests/python/relay/test_json_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_relu():
dtype = "float32"
shape = (1, 32, 14, 14)

def gen_relu():
def gen_relu(shape):
data0 = relay.var("data0", shape=shape, dtype=dtype)
out = relay.nn.relu(data0)

Expand All @@ -250,18 +250,22 @@ def gen_relu():

return mod, ref_mod

mod, ref_mod = gen_relu()
def check(shape):
mod, ref_mod = gen_relu(shape)

data0 = np.random.uniform(-1, 1, shape).astype(dtype)
check_result(
mod,
ref_mod,
{
"data0": data0,
},
shape,
tol=1e-5,
)

data0 = np.random.uniform(-1, 1, shape).astype(dtype)
check_result(
mod,
ref_mod,
{
"data0": data0,
},
(1, 32, 14, 14),
tol=1e-5,
)
check(shape=(1, 32, 14, 14))
check(shape=(1, 32))


def test_dense():
Expand Down

0 comments on commit f573007

Please sign in to comment.