Skip to content

Commit

Permalink
fix test and clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Nov 27, 2019
1 parent 7f5ea88 commit e09d2fa
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 178 deletions.
1 change: 1 addition & 0 deletions cmake/modules/contrib/Extern.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ if(DNNL_IDX GREATER -1)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})

find_library(EXTERN_LIBRARY_DNNL dnnl)
include_directories(/usr/local/include)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
Expand Down
1 change: 0 additions & 1 deletion python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def export_library(self,
path_cc = temp.relpath(c_file_name)
with open(path_cc, "w") as f:
f.write(m.get_source())
print(m.get_source())
files.append(path_cc)
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
Expand Down
47 changes: 47 additions & 0 deletions src/relay/backend/contrib/contrib_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,53 @@ class ExternSourcePrinter {
code_stream_ << "}";
}

virtual std::string JIT(void) = 0;

std::string JitImpl(std::string subgraph_id,
std::vector<std::string> args,
std::vector<std::string> buf_decl,
std::vector<std::string> body,
std::vector<std::pair<std::string, int>> out) {
// Create the signature. For example, it could be:
// extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {}
code_stream_ << "extern \"C\" void " << subgraph_id << "_(";

for (const auto& arg : args) {
code_stream_ << "float* " << arg << ", ";
}
code_stream_ << "float* out) {\n";
this->EnterScope();

// Function body
for (auto decl : buf_decl) {
this->PrintIndents();
code_stream_ << decl << "\n";
}
code_stream_ << "\n";
for (auto stmt : body) {
this->PrintIndents();
code_stream_ << stmt << "\n";
}

// Copy output
CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support.";
this->PrintIndents();
code_stream_ << "std::memcpy(out, " << out[0].first << ", 4 * " << out[0].second << ");\n";

// Free buffers
for (size_t i = 0; i < buf_decl.size(); i++) {
this->PrintIndents();
code_stream_ << "std::free(buf_" << i << ");\n";
}

this->ExitScope();
code_stream_ << "}\n";

// Create the wrapper to call the subgraph
this->GenerateSubgraphWrapper(subgraph_id, args.size() + 1 /* output */);
return code_stream_.str();
}

/*! \brief The external function source code stream. */
std::ostringstream code_stream_;

Expand Down
45 changes: 3 additions & 42 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,47 +164,8 @@ class DnnlBuilder : public ExprVisitor, public ExternSourcePrinter {
out_.push_back({out, out_size});
}

std::string jit_dnnl() {
// Create the signature. For example, it could be:
// extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {}
code_stream_ << "extern \"C\" void " << subgraph_id_ << "_(";

for (const auto& arg : subgraph_args_) {
code_stream_ << "float* " << arg << ", ";
}
code_stream_ << "float* out) {\n";
this->EnterScope();

// Function body
for (auto decl : buf_decl_) {
this->PrintIndents();
code_stream_ << decl << "\n";
}
code_stream_ << "\n";
for (auto stmt : subgraph_body) {
this->PrintIndents();
code_stream_ << stmt << "\n";
}

// Copy output
CHECK_EQ(out_.size(), 1U) << "Internal error: only single output is support yet.";
this->PrintIndents();
code_stream_ << "std::memcpy(out, " << out_[0].first << ", 4 * "
<< out_[0].second << ");\n";

// Free buffers
for (size_t i = 0; i < buf_decl_.size(); i++) {
this->PrintIndents();
code_stream_ << "std::free(buf_" << i << ");\n";
}

this->ExitScope();
code_stream_ << "}\n";

// Create the wrapper to call the subgraph
this->GenerateSubgraphWrapper(subgraph_id_,
subgraph_args_.size() + 1 /* output */);
return code_stream_.str();
std::string JIT(void) {
return JitImpl(subgraph_id_, subgraph_args_, buf_decl_, subgraph_body, out_);
}

private:
Expand Down Expand Up @@ -279,7 +240,7 @@ class DNNLCodegen : public ExternCodegenBase {

auto builder = DnnlBuilder("dnnl_" + sid);
builder.VisitExpr(func->body);
code_stream_ << builder.jit_dnnl();
code_stream_ << builder.JIT();
}

/*!
Expand Down
48 changes: 4 additions & 44 deletions src/relay/backend/contrib/gcc/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,51 +108,12 @@ class GccBuilder : public ExprVisitor, public ExternSourcePrinter {
out_.push_back({out, out_size});
}

std::string jit_csource() {
std::string JIT(void) {
// Write function macros
for (auto decl : func_decl_) {
code_stream_ << decl << "\n";
}

// Write subgraph function declaration
code_stream_ << "extern \"C\" void " << subgraph_id_ << "_(";

for (const auto& arg : subgraph_args_) {
code_stream_ << "float* " << arg << ", ";
}

code_stream_ << "float* out) {\n";
this->EnterScope();

// Function body
for (auto decl : buf_decl_) {
this->PrintIndents();
code_stream_ << decl << "\n";
}
code_stream_ << "\n";
for (auto stmt : subgraph_body) {
this->PrintIndents();
code_stream_ << stmt << "\n";
}

// Copy output
CHECK(out_.size() == 1) << "Internal error";
this->PrintIndents();
code_stream_ << "std::memcpy(out, " << out_[0].first << ", 4 * " << out_[0].second << ");\n";

// Free buffers
for (size_t i = 0; i < buf_decl_.size(); i++) {
this->PrintIndents();
code_stream_ << "std::free(buf_" << i << ");\n";
}

this->ExitScope();
code_stream_ << "}\n";

// Create the wrapper to call the subgraph
this->GenerateSubgraphWrapper(subgraph_id_,
subgraph_args_.size() + 1 /* output */);
return code_stream_.str();
return JitImpl(subgraph_id_, subgraph_args_, buf_decl_, subgraph_body, out_);
}

private:
Expand Down Expand Up @@ -189,7 +150,7 @@ class GccCodegen : public ExternCodegenBase {

auto builder = GccBuilder("gcc_" + sid);
builder.VisitExpr(func->body);
code_stream_ << builder.jit_csource();
code_stream_ << builder.JIT();
}

runtime::Module CreateExternModule(const NodeRef& ref) {
Expand Down Expand Up @@ -217,7 +178,6 @@ class GccCodegen : public ExternCodegenBase {
for (int64_t j = 0; j < p_DIM2_; ++j) { \
int64_t k = i * p_DIM2_ + j; \
out[k] = a[k] p_OP_ b[k]; \
std::cout << a[k] << " " << b[k] << out[k] << std::endl; \
} \
} \
}
Expand All @@ -236,7 +196,7 @@ class GccCodegen : public ExternCodegenBase {
LOG(FATAL) << "The input ref is expected to be a Relay function or module"
<< "\n";
}
LOG(INFO) << code_stream_.str();

// Create a CSourceModule
const auto* pf = runtime::Registry::Get("module.csource_module_create");
CHECK(pf != nullptr) << "Cannot find csource module to create the external function";
Expand Down
129 changes: 38 additions & 91 deletions tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,28 @@ def visit_call(self, call):
new_call = relay.Call(call.op, params, call.attrs)
return new_call

def check_result(mod, map_inputs, out_shape, result, tol=1e-7):
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, _ = relay.build(mod, "llvm")
kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11"]
tmp_path = util.tempdir()
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.module.load(lib_path)

ctx = tvm.cpu()
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)

for name, data in map_inputs.items():
rt_mod.set_input(name, data)
rt_mod.run()
out = tvm.nd.empty(out_shape, ctx=ctx)
out = rt_mod.get_output(0, out)

tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)


def test_multi_node_subgraph():
x = relay.var('x', shape=(10, 10))
Expand Down Expand Up @@ -175,34 +197,14 @@ def test_multi_node_subgraph():
for _ in range(8):
w_data.append(np.random.rand(10, 10).astype('float32'))

with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, _ = relay.build(mod, "llvm")
kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11"]
tmp_path = util.tempdir()
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.module.load(lib_path)

ctx = tvm.cpu()
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
for i in range(8):
data = np.random.rand(10, 10).astype('float32')
w_data.append(data)
var = "w" + str(i)
rt_mod.set_input(var, data)
rt_mod.run()
out = tvm.nd.empty((30, 10), ctx=ctx)
out = rt_mod.get_output(0, out)

tvm.testing.assert_allclose(
out.asnumpy(),
np.concatenate(
(((x_data + w_data[0]) - w_data[1]) * w_data[2],
((x_data + w_data[3]) - w_data[4]) * w_data[5],
x_data + w_data[6] - w_data[7]),
axis=0))
map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
map_inputs["x"] = x_data
check_result(
mod, map_inputs, (30, 10),
np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2],
((x_data + w_data[3]) - w_data[4]) * w_data[5],
x_data + w_data[6] - w_data[7]),
axis=0))


def test_extern_gcc_single_op():
Expand All @@ -216,25 +218,7 @@ def test_extern_gcc_single_op():
mod["main"] = f
mod = relay.build_extern(mod, "gcc")

with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, _ = relay.build(mod, "llvm")
kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11"]
tmp_path = util.tempdir()
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.module.load(lib_path)

ctx = tvm.cpu()
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
rt_mod.set_input("x", x_data)
rt_mod.set_input("y", y_data)
rt_mod.run()
out = tvm.nd.empty((8, 8), ctx=ctx)
out = rt_mod.get_output(0, out)

tvm.testing.assert_allclose(out.asnumpy(), (x_data + y_data))
check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data)


def test_extern_gcc():
Expand All @@ -249,26 +233,7 @@ def test_extern_gcc():
mod["main"] = f
mod = relay.build_extern(mod, "gcc")

with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, _ = relay.build(mod, "llvm")
kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11"]
tmp_path = util.tempdir()
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.module.load(lib_path)

ctx = tvm.cpu()
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
rt_mod.set_input("x", x_data)
rt_mod.set_input("y", y_data)
rt_mod.run()
out = tvm.nd.empty((2, 2), ctx=ctx)
out = rt_mod.get_output(0, out)

tvm.testing.assert_allclose(out.asnumpy(),
(y_data * y_data) - (x_data + x_data))
check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data))


def test_extern_dnnl():
Expand Down Expand Up @@ -301,28 +266,10 @@ def test_extern_dnnl():
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)

with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, _ = relay.build(mod, "llvm")
kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11"]
tmp_path = util.tempdir()
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = tvm.module.load(lib_path)

ctx = tvm.cpu()
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
rt_mod.set_input("data", i_data)
rt_mod.set_input("weight1", w1_data)
rt_mod.run()
out = tvm.nd.empty((1, 32, 14, 14), ctx=ctx)
out = rt_mod.get_output(0, out)

ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=ctx)
ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
ref_res = ref_ex.evaluate()(i_data, w1_data)

tvm.testing.assert_allclose(out.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
check_result(mod, {"data": i_data, "weight1": w1_data},
(1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)


@nottest
Expand Down Expand Up @@ -351,8 +298,8 @@ def test_extern_dnnl_mobilenet():


if __name__ == "__main__":
# test_multi_node_subgraph()
# test_extern_gcc_single_op()
# test_extern_gcc()
test_multi_node_subgraph()
test_extern_gcc_single_op()
test_extern_gcc()
test_extern_dnnl()
# test_extern_dnnl_mobilenet()

0 comments on commit e09d2fa

Please sign in to comment.