[BYOC] Support Tuple Output in C/DNNL Codegen (#5701)
* Support tuple output runtime

* fix unit test
comaniac authored May 30, 2020
1 parent 879158a commit 910edef
Showing 4 changed files with 96 additions and 30 deletions.
19 changes: 19 additions & 0 deletions src/relay/backend/contrib/codegen_c/codegen.cc
@@ -56,6 +56,25 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
return {output};
}

std::vector<Output> VisitExpr_(const TupleNode* node) final {
std::vector<Output> outs;
for (auto field : node->fields) {
auto res = VisitExpr(field);
CHECK_EQ(res.size(), 1U) << "Nested tuples are not supported";
outs.push_back(res[0]);
}
return outs;
}

std::vector<Output> VisitExpr_(const TupleGetItemNode* op) final {
auto res = VisitExpr(op->tuple);
CHECK_GT(res.size(), static_cast<size_t>(op->index));

// Only keep the item we want for the child node.
// FIXME(@comaniac): The other items should still be required for the primary outputs.
return {res[op->index]};
}

std::vector<Output> VisitExpr_(const ConstantNode* cn) final {
// Note this is for demonstration purpose. ConstantNode doesn't necessarily
// belong to calls. We need to revisit this when tuples come into play.
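The two tuple visitors added above flatten a tuple into one `Output` per field and let `TupleGetItem` keep a single flattened entry. Their behavior can be sketched in isolation as follows (a standalone illustration with a simplified stand-in for the codegen's `Output` record; the helper names are hypothetical, not TVM APIs):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Simplified stand-in for the codegen's Output record (illustrative only).
struct Output {
  std::string name;
  std::string dtype;
};

// Tuple(e0, e1, ...): each field contributes exactly one Output, in order.
std::vector<Output> FlattenTuple(const std::vector<std::vector<Output>>& fields) {
  std::vector<Output> outs;
  for (const auto& res : fields) {
    assert(res.size() == 1 && "nested tuples are not supported");
    outs.push_back(res[0]);
  }
  return outs;
}

// TupleGetItem(t, i): keep only the i-th flattened entry for the child node.
std::vector<Output> GetTupleItem(const std::vector<Output>& tuple_outs, size_t index) {
  assert(index < tuple_outs.size());
  return {tuple_outs[index]};
}
```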
47 changes: 28 additions & 19 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -125,17 +125,19 @@ class CodegenCBase {
* \endcode
*/
void GenerateBackendCFunc(const std::string& func_name, const Array<Var>& args,
const Output& out) {
const std::vector<Output>& outs) {
// Print signature
code_stream_ << "\n";
code_stream_ << "extern \"C\" int " << func_name << "_wrapper_(";
for (size_t i = 0; i < args.size(); i++) {
code_stream_ << "DLTensor* arg" << i << ",\n";
code_stream_ << "\t";
}
if (args.size() > 0) {
code_stream_ << "DLTensor* arg" << args.size() << ") {\n";
for (size_t i = 0; i < outs.size() - 1; i++) {
code_stream_ << "DLTensor* out" << i << ",\n";
code_stream_ << "\t";
}
code_stream_ << "DLTensor* out" << outs.size() - 1 << ") {\n";

EnterScope();

@@ -147,10 +149,12 @@ class CodegenCBase {
code_stream_ << "static_cast<" << dtype_str << "*>(arg" << i << "->data),\n";
PrintIndents();
}
if (args.size() > 0) {
code_stream_ << "static_cast<" << out.dtype << "*>(arg" << args.size() << "->data)";
for (size_t i = 0; i < outs.size() - 1; i++) {
code_stream_ << "static_cast<" << outs[i].dtype << "*>(out" << i << "->data),\n";
PrintIndents();
}
code_stream_ << ");\n";
code_stream_ << "static_cast<" << outs.back().dtype << "*>(out" << outs.size() - 1
<< "->data));\n";
PrintIndents();
code_stream_ << "return 0;\n";
ExitScope();
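For a subgraph with two outputs, the wrapper emitted by the updated GenerateBackendCFunc would look roughly like the following (a hand-written sketch; the kernel name `dnnl_0_` and the `float` dtypes are illustrative assumptions, not text taken from this diff):

```cpp
#include <dlpack/dlpack.h>

// Hypothetical generated kernel with two inputs and two outputs.
extern "C" void dnnl_0_(float* in0, float* in1, float* out0, float* out1);

// Sketch of the wrapper the updated GenerateBackendCFunc would emit.
extern "C" int dnnl_0_wrapper_(DLTensor* arg0, DLTensor* arg1,
                               DLTensor* out0, DLTensor* out1) {
  dnnl_0_(static_cast<float*>(arg0->data),
          static_cast<float*>(arg1->data),
          static_cast<float*>(out0->data),
          static_cast<float*>(out1->data));
  return 0;
}
```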
@@ -186,18 +190,19 @@
*/
std::string JitImpl(const std::string& ext_func_id, const Array<Var>& args,
const std::vector<std::string>& buf_decl,
const std::vector<std::string>& body, const std::vector<Output>& out) {
const std::vector<std::string>& body, const std::vector<Output>& outs) {
// Create the signature. For example, it could be:
// extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {}
// extern "C" void dnnl_0_(float* in0, float* in1, float* out0, float* out1) {}
code_stream_ << "extern \"C\" void " << ext_func_id << "_(";

CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support.";

for (const auto& arg : args) {
const auto& dtype_str = GetDtypeString(arg);
code_stream_ << dtype_str << "* " << arg->name_hint() << ", ";
}
code_stream_ << out[0].dtype << "* out) {\n";
for (size_t i = 0; i < outs.size() - 1; ++i) {
code_stream_ << outs[i].dtype << "* out" << i << ", ";
}
code_stream_ << outs.back().dtype << "* out" << outs.size() - 1 << ") {\n";
this->EnterScope();

// Function body
@@ -212,22 +217,26 @@ class CodegenCBase {
}

// Copy output
if (out[0].need_copy) {
for (size_t i = 0; i < outs.size(); ++i) {
if (!outs[i].need_copy) {
continue;
}
this->PrintIndents();
code_stream_ << "std::memcpy(out, " << out[0].name << ", 4 * " << out[0].size << ");\n";
code_stream_ << "std::memcpy(out" << i << ", " << outs[i].name << ", 4 * " << outs[i].size
<< ");\n";
}

// Free buffers
for (size_t i = 0; i < buf_decl.size(); i++) {
this->PrintIndents();
code_stream_ << "std::free(buf_" << i << ");\n";
}
// Free buffers
for (size_t i = 0; i < buf_decl.size(); i++) {
this->PrintIndents();
code_stream_ << "std::free(buf_" << i << ");\n";
}

this->ExitScope();
code_stream_ << "}\n";

// Create the wrapper to call the ext_func
this->GenerateBackendCFunc(ext_func_id, args, out[0]);
this->GenerateBackendCFunc(ext_func_id, args, outs);
return code_stream_.str();
}
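End to end, `JitImpl` would now produce something along these lines for a two-output subgraph (an illustrative sketch; the function name, buffer sizes, and kernel body are placeholders rather than output copied from this diff):

```cpp
#include <cstdlib>
#include <cstring>

// Sketch of a generated external function with two outputs.
extern "C" void dnnl_0_(float* in0, float* in1, float* out0, float* out1) {
  float* buf_0 = static_cast<float*>(std::malloc(4 * 100));
  float* buf_1 = static_cast<float*>(std::malloc(4 * 100));
  // ... body generated from the subgraph, writing results into buf_0 and buf_1 ...
  std::memcpy(out0, buf_0, 4 * 100);  // one copy per output that needs it
  std::memcpy(out1, buf_1, 4 * 100);
  std::free(buf_0);
  std::free(buf_1);
}
```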

12 changes: 10 additions & 2 deletions src/relay/backend/contrib/dnnl/codegen.cc
@@ -144,6 +144,16 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
return {output};
}

std::vector<Output> VisitExpr_(const TupleNode* node) final {
std::vector<Output> outs;
for (auto field : node->fields) {
auto res = VisitExpr(field);
CHECK_EQ(res.size(), 1U) << "Nested tuples are not supported";
outs.push_back(res[0]);
}
return outs;
}

std::vector<Output> VisitExpr_(const TupleGetItemNode* op) final {
auto res = VisitExpr(op->tuple);
CHECK_GT(res.size(), static_cast<size_t>(op->index));
@@ -347,8 +357,6 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
// Create a corresponding DNNL function for the given relay Function.
void GenDNNLFunc(const Function& func) {
CHECK(func.defined()) << "Input error: expect a Relay function.";
const auto* call = func->body.as<CallNode>();
CHECK(call) << "DNNL expects a single convolution or dense op";

// Record the external symbol for runtime lookup.
auto sid = GetExtSymbol(func);
48 changes: 39 additions & 9 deletions tests/python/relay/test_pass_partition_graph.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for graph partitioning."""
# pylint: disable=not-callable
import os
import sys

@@ -201,8 +202,11 @@ def check_vm_result():
exe = runtime.vm.Executable.load_exec(code, lib)
vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
outs = vm.run(**map_inputs)
outs = outs if isinstance(outs, runtime.container.ADT) else [outs]
results = result if isinstance(result, list) else [result]
for out, ref in zip(outs, results):
tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=tol, atol=tol)

def check_graph_runtime_result():
compile_engine.get().clear()
@@ -215,10 +219,14 @@ def check_graph_runtime_result():
rt_mod.set_input(name, data)
rt_mod.set_input(**param)
rt_mod.run()
out = tvm.nd.empty(out_shape, ctx=ctx)
out = rt_mod.get_output(0, out)

tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
out_shapes = out_shape if isinstance(out_shape, list) else [out_shape]
results = result if isinstance(result, list) else [result]

for idx, shape in enumerate(out_shapes):
out = tvm.nd.empty(shape, ctx=ctx)
out = rt_mod.get_output(idx, out)
tvm.testing.assert_allclose(out.asnumpy(), results[idx], rtol=tol, atol=tol)

check_vm_result()
check_graph_runtime_result()
@@ -1082,11 +1090,11 @@ def test_duplicate_merge_and_tuplegetitem():
target = "test_duplicate_merge_and_tuplegetitem"

@reg.register("nn.batch_norm", "target." + target)
def abs(attrs, args): # pylint: disable=unused-variable
def batch_norm(attrs, args): # pylint: disable=unused-variable
return True

@reg.register("nn.relu", "target." + target)
def abs(attrs, args): # pylint: disable=unused-variable
def relu(attrs, args): # pylint: disable=unused-variable
return True

def create_graph():
@@ -1195,11 +1203,11 @@ def test_flatten_tuple_output():
target = "test_flatten_tuple_output"

@reg.register("split", "target." + target)
def foo(attrs, args): # pylint: disable=unused-variable
def split(attrs, args): # pylint: disable=unused-variable
return True

@reg.register("abs", "target." + target)
def foo(attrs, args): # pylint: disable=unused-variable
def abs(attrs, args): # pylint: disable=unused-variable
return True

def create_graph():
@@ -1259,6 +1267,27 @@ def expected():
partitioned = seq(create_graph())
assert tvm.ir.structural_equal(partitioned, expected(), map_free_vars=True)

def test_tuple_output_exec():
"""Test C codegen and runtime for a subgraph with a tuple output"""
a = relay.var('a', shape=(10, 10), dtype='float32')
b = relay.var('b', shape=(10, 10), dtype='float32')
ba = relay.annotation.compiler_begin(a, 'ccompiler')
bb = relay.annotation.compiler_begin(b, 'ccompiler')
add = relay.add(ba, bb)
sub = relay.subtract(ba, bb)
out = relay.Tuple((add, sub))
eout = relay.annotation.compiler_end(out, 'ccompiler')
func = relay.Function([a, b], eout)
mod = tvm.IRModule()
mod["main"] = func
mod = transform.PartitionGraph()(mod)

a_data = np.random.rand(10, 10).astype('float32')
b_data = np.random.rand(10, 10).astype('float32')

check_result(mod, {'a': a_data, 'b': b_data},
[(10, 10), (10, 10)],
[(a_data + b_data), (a_data - b_data)])

if __name__ == "__main__":
test_multi_node_compiler()
@@ -1278,3 +1307,4 @@ def expected():
test_duplicate_merge_and_tuplegetitem()
test_constant_tuples()
test_flatten_tuple_output()
test_tuple_output_exec()
