[Relay][BYOCG] Propagate constant to subgraphs (apache#5094)
* bind constant to subgraphs

* con -> constant
zhiics authored and Trevor Morris committed Apr 16, 2020
1 parent eb276dd commit 397d007
Showing 4 changed files with 116 additions and 4 deletions.
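In brief: the graph partitioner now records Constant arguments of each partitioned subgraph and binds them into the subgraph function via backend::BindParamsByName, instead of threading them through as call arguments; the example C codegen gains a ConstantNode visitor that materializes each bound constant as an initialized C array; and a new test, test_constant_propagation, covers the end-to-end behavior.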
56 changes: 55 additions & 1 deletion src/relay/backend/contrib/codegen_c/codegen.cc
@@ -19,6 +19,7 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 #include <tvm/relay/type.h>
+#include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/module.h>
 #include <tvm/runtime/object.h>

@@ -40,14 +41,63 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
  public:
   explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
 
-  void VisitExpr_(const VarNode* node) {
+  void VisitExpr_(const VarNode* node) final {
     ext_func_args_.push_back(GetRef<Var>(node));
     out_.clear();
     Output output;
     output.name = node->name_hint();
     out_.push_back(output);
   }

+  void VisitExpr_(const ConstantNode* cn) final {
+    Constant constant = GetRef<Constant>(cn);
+    if (visited_.count(constant)) {
+      // Note this is for demonstration purposes. ConstantNode doesn't
+      // necessarily belong to calls. We need to revisit this when tuples
+      // come into play.
+      out_.push_back(visited_[constant]);
+      return;
+    }
+
+    std::ostringstream decl_stream;
+    std::ostringstream buf_stream;
+
+    out_.clear();
+    Output output;
+    output.name = "const_" + std::to_string(const_idx_++);
+    out_.push_back(output);
+    visited_[constant] = output;
+
+    runtime::NDArray array = cn->data;
+    const auto& shape = array.Shape();
+    const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor;
+
+    // Get the number of elements.
+    int64_t num_elems = 1;
+    for (auto i : shape) num_elems *= i;
+
+    const auto* type_node = cn->checked_type().as<TensorTypeNode>();
+    CHECK(type_node);
+    const auto& dtype = GetDtypeString(type_node);
+    // Define a const buffer: float const_0[64] = {1.0, 2.0, ...};
+    //
+    // Technically, you may need: static float* const_0 = (float*)malloc(4 * 64)
+    // to avoid a possible stack overflow.
+    buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
+    if (dtype == "float") {
+      float* p_flt = static_cast<float*>(dl_tensor.data);
+      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
+      if (num_elems) buf_stream << p_flt[num_elems - 1];
+    } else if (dtype == "int") {
+      int* p_int = static_cast<int*>(dl_tensor.data);
+      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_int[i] << ", ";
+      if (num_elems) buf_stream << p_int[num_elems - 1];
+    } else {
+      LOG(FATAL) << "Only float and int are supported for now.";
+    }
+    buf_stream << "};";
+    ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
+  }

   void VisitExpr_(const CallNode* call) final {
     std::ostringstream macro_stream;
     std::ostringstream decl_stream;
@@ -138,6 +188,8 @@
   int func_idx = 0;
   /*! \brief The index of allocated buffers. */
   int buf_idx_ = 0;
+  /*! \brief The index of global constants. */
+  int const_idx_ = 0;
   /*! \brief The arguments of a C compiler compatible function. */
   Array<Var> ext_func_args_;
   /*! \brief The statements of a C compiler compatible function. */
@@ -148,6 +200,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
   std::vector<std::string> buf_decl_;
   /*! \brief The name and index pairs for output. */
   std::vector<Output> out_;
+  /*! \brief The cached expressions. */
+  std::unordered_map<Expr, Output, ObjectHash, ObjectEqual> visited_;
 };

 class CSourceCodegen : public CSourceModuleCodegenBase {
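As a reading aid, here is a small Python model (illustrative only, not part of the commit; emit_const_buffer is a hypothetical name) of the initializer string the ConstantNode visitor above writes into the generated C source for a float32 constant:

import numpy as np

def emit_const_buffer(name, array):
    # Mirrors the C++ emission: float const_0[4] = {1.0, 1.0, 1.0, 1.0};
    flat = np.asarray(array, dtype="float32").ravel()
    vals = ", ".join(str(v) for v in flat)
    return "float {}[{}] = {{{}}};".format(name, flat.size, vals)

print(emit_const_buffer("const_0", np.ones((2, 2))))
# float const_0[4] = {1.0, 1.0, 1.0, 1.0};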
4 changes: 2 additions & 2 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -197,7 +197,7 @@ class CodegenCBase {
    * \return true if the call's name is equivalent to the given name. Otherwise,
    * false.
    */
-  bool IsOp(const CallNode* call, std::string op_name) const {
+  bool IsOp(const CallNode* call, const std::string& op_name) const {
     const auto* op_node = call->op.as<OpNode>();
     CHECK(op_node) << "Expects a single op.";
     Op op = GetRef<Op>(op_node);
@@ -218,7 +218,7 @@
    *
    * \return The emitted code string.
    */
-  std::string JitImpl(std::string ext_func_id, const Array<Var>& args,
+  std::string JitImpl(const std::string& ext_func_id, const Array<Var>& args,
                       const std::vector<std::string>& buf_decl,
                       const std::vector<std::string>& body,
                       const std::vector<Output>& out) {
15 changes: 14 additions & 1 deletion src/relay/transforms/partition_graph.cc
@@ -42,6 +42,8 @@
 #include <utility>
 #include <vector>
 
+#include "../backend/utils.h"
+
 namespace tvm {
 namespace relay {
 namespace partitioning {
@@ -200,14 +202,20 @@ class Partitioner : public ExprMutator {
     auto input = VisitExpr(call->args[0]);
     Array<Var> params;
     Array<Expr> args;
+    std::unordered_map<std::string, runtime::NDArray> params_bind;
 
     // The subgraph may be merged so we need to update it again.
     subgraph = GetSubgraph(GetRef<Call>(call));
     CHECK(subgraph);
 
+    // Record the constants for propagation.
     for (auto pair : subgraph->args) {
       params.push_back(pair.first);
-      args.push_back(pair.second);
+      if (const auto* cn = pair.second.as<ConstantNode>()) {
+        params_bind[pair.first->name_hint()] = cn->data;
+      } else {
+        args.push_back(pair.second);
+      }
     }
 
     auto subgraph_func =
@@ -223,6 +231,11 @@
         tvm::tir::StringImmNode::make(compiler_attrs->compiler));
     subgraph_func =
         WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1));
+
+    // Constant propagation
+    if (!params_bind.empty()) {
+      subgraph_func = backend::BindParamsByName(subgraph_func, params_bind);
+    }
     CHECK(!module_->ContainGlobalVar(name))
         << "Global function " << name << " already exists";
     // Create a global function and add it to the IRModule for the subgraph.
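For intuition about the binding step above, here is a minimal sketch using the same bind_params_by_name API the new test exercises (shapes and names are arbitrary): a parameter whose argument is a constant is folded into the function body and dropped from the signature, which is what backend::BindParamsByName does for each partitioned subgraph.

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
func = relay.Function([x, y], x + y)  # fn (%x, %y) { add(%x, %y) }

ones = tvm.nd.array(np.ones((8, 8), dtype="float32"))
bound = relay.build_module.bind_params_by_name(func, {"x": ones})
print(bound)  # roughly: fn (%y) { add(meta[relay.Constant][0], %y) }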
45 changes: 45 additions & 0 deletions tests/python/relay/test_pass_partition_graph.py
@@ -634,6 +634,50 @@ def expected():
     assert relay.analysis.alpha_equal(partitioned, ref_mod)


+def test_constant_propagation():
+    ones = np.ones(shape=(8, 8), dtype="float32")
+
+    def expected():
+        mod = tvm.IRModule()
+        x = relay.const(ones)
+        y = relay.var("y", shape=(8, 8))
+        x0 = relay.const(ones)
+        y0 = relay.var("y0", shape=(8, 8))
+        add = x0 + y0
+        # Function that uses C compiler
+        func = relay.Function([y0], add)
+        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
+        func = func.with_attr("ExternalSymbol",
+                              tvm.tir.StringImm("ccompiler_0"))
+        glb_0 = relay.GlobalVar("ccompiler_0")
+        mod[glb_0] = func
+        add_call = relay.Call(glb_0, [y])
+        log = relay.log(add_call)
+        main = relay.Function([y], log)
+        mod["main"] = main
+        return mod
+
+    x = relay.var("x", shape=(8, 8))
+    y = relay.var("y", shape=(8, 8))
+    add = x + y
+    log = relay.log(add)
+    f = relay.Function([x, y], log)
+    f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)})
+    mod = tvm.IRModule()
+    mod["main"] = f
+    mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
+    mod = transform.PartitionGraph()(mod)
+
+    expected_mod = expected()
+    assert relay.alpha_equal(mod, expected_mod)
+
+    y_data = np.random.rand(8, 8).astype('float32')
+    np_add = ones + y_data
+    check_result(mod, {"y": y_data}, (8, 8), np.log(np_add))


 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -643,3 +687,4 @@ def expected():
     test_extern_dnnl_mobilenet()
     test_function_lifting()
     test_function_lifting_inline()
+    test_constant_propagation()
