diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc
index 126d1d5839de..97231dfe3401 100644
--- a/src/relay/backend/contrib/codegen_c/codegen.cc
+++ b/src/relay/backend/contrib/codegen_c/codegen.cc
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -40,7 +41,7 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
  public:
   explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
 
-  void VisitExpr_(const VarNode* node) {
+  void VisitExpr_(const VarNode* node) final {
     ext_func_args_.push_back(GetRef<Var>(node));
     out_.clear();
     Output output;
@@ -48,6 +49,55 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
     out_.push_back(output);
   }
 
+  void VisitExpr_(const ConstantNode* cn) final {
+    Constant constant = GetRef<Constant>(cn);
+    if (visited_.count(constant)) {
+      // Note this is for demonstration purposes. ConstantNode doesn't necessarily
+      // belong to calls. We need to revisit this when tuples come into play.
+      out_.push_back(visited_[constant]);
+      return;
+    }
+
+    std::ostringstream decl_stream;
+    std::ostringstream buf_stream;
+
+    out_.clear();
+    Output output;
+    output.name = "const_" + std::to_string(const_idx_++);
+    out_.push_back(output);
+    visited_[constant] = output;
+
+    runtime::NDArray array = cn->data;
+    const auto& shape = array.Shape();
+    const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor;
+
+    // Get the number of elements.
+    int64_t num_elems = 1;
+    for (auto i : shape) num_elems *= i;
+
+    const auto* type_node = cn->checked_type().as<TensorTypeNode>();
+    CHECK(type_node);
+    const auto& dtype = GetDtypeString(type_node);
+    // Define a const buffer: float const_0[64] = {1.0, 2.0, ...};
+    //
+    // Technically, you may need: static float* const_0 = (float*)malloc(4 * 64)
+    // to avoid possible stack overflow.
+    buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
+    if (dtype == "float") {
+      float* p_flt = static_cast<float*>(dl_tensor.data);
+      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
+      if (num_elems) buf_stream << p_flt[num_elems - 1];
+    } else if (dtype == "int") {
+      int* p_int = static_cast<int*>(dl_tensor.data);
+      for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_int[i] << ", ";
+      if (num_elems) buf_stream << p_int[num_elems - 1];
+    } else {
+      LOG(FATAL) << "Only float and int are supported for now.";
+    }
+    buf_stream << "};";
+    ext_func_body.insert(ext_func_body.begin(), buf_stream.str());
+  }
+
   void VisitExpr_(const CallNode* call) final {
     std::ostringstream macro_stream;
     std::ostringstream decl_stream;
@@ -138,6 +188,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
   int func_idx = 0;
   /*! \brief The index of allocated buffers. */
   int buf_idx_ = 0;
+  /*! \brief The index of global constants. */
+  int const_idx_ = 0;
   /*! \brief The arguments of a C compiler compatible function. */
   Array<Var> ext_func_args_;
   /*! \brief The statements of a C compiler compatible function. */
@@ -148,6 +200,8 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
   std::vector<std::string> buf_decl_;
   /*! \brief The name and index pairs for output. */
   std::vector<Output> out_;
+  /*! \brief The cached expressions. */
+  std::unordered_map<Constant, Output, ObjectHash, ObjectEqual> visited_;
 };
 
 class CSourceCodegen : public CSourceModuleCodegenBase {
diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h
index f003d22e4c42..60cecef0ce3c 100644
--- a/src/relay/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -197,7 +197,7 @@ class CodegenCBase {
   * \return true if the call's name is equivalent to the given name. Otherwise,
   * false.
   */
-  bool IsOp(const CallNode* call, std::string op_name) const {
+  bool IsOp(const CallNode* call, const std::string& op_name) const {
     const auto* op_node = call->op.as<OpNode>();
     CHECK(op_node) << "Expects a single op.";
     Op op = GetRef<Op>(op_node);
@@ -218,7 +218,7 @@ class CodegenCBase {
   *
   * \return The emitted code string.
   */
-  std::string JitImpl(std::string ext_func_id, const Array<Var>& args,
+  std::string JitImpl(const std::string& ext_func_id, const Array<Var>& args,
                       const std::vector<std::string>& buf_decl,
                       const std::vector<std::string>& body,
                       const std::vector<Output>& out) {
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 84f8c2705bab..17f5cfa0778b 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -42,6 +42,8 @@
 #include
 #include
 
+#include "../backend/utils.h"
+
 namespace tvm {
 namespace relay {
 namespace partitioning {
@@ -200,14 +202,20 @@ class Partitioner : public ExprMutator {
     auto input = VisitExpr(call->args[0]);
     Array<Var> params;
     Array<Expr> args;
+    std::unordered_map<std::string, runtime::NDArray> params_bind;
 
     // The subgraph may be merged so we need to update it again.
     subgraph = GetSubgraph(GetRef<Call>(call));
     CHECK(subgraph);
 
+    // Record the constants for propagation.
     for (auto pair : subgraph->args) {
       params.push_back(pair.first);
-      args.push_back(pair.second);
+      if (const auto* cn = pair.second.as<ConstantNode>()) {
+        params_bind[pair.first->name_hint()] = cn->data;
+      } else {
+        args.push_back(pair.second);
+      }
     }
 
     auto subgraph_func =
@@ -223,6 +231,11 @@ class Partitioner : public ExprMutator {
         tvm::tir::StringImmNode::make(compiler_attrs->compiler));
     subgraph_func = WithAttr(std::move(subgraph_func), attr::kInline,
                              tvm::Integer(1));
+
+    // Constant propagation
+    if (!params_bind.empty()) {
+      subgraph_func = backend::BindParamsByName(subgraph_func, params_bind);
+    }
     CHECK(!module_->ContainGlobalVar(name))
         << "Global function " << name << " already exists";
     // Create a global function and add it to the IRModule for the subgraph.
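For reference, here is a rough, hand-written sketch (not output copied from the codegen) of the kind of C source the `ccompiler` backend produces once a constant has been propagated into a partitioned `add`: the constant becomes a buffer definition emitted by `VisitExpr_(const ConstantNode*)`, and only the remaining tensor stays a function parameter. The shape is reduced to 2x2 for brevity; the identifiers `ccompiler_0_`, `const_0`, and `buf_0` follow the naming scheme in `codegen.cc`, but the exact wrapper structure is an illustrative assumption.

```c
/* Illustrative sketch only, not real generated output: roughly what the
 * ccompiler backend would emit for add(const_ones, y) after the constant is
 * propagated into the external function. Shape reduced to 2x2 for brevity. */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

void ccompiler_0_(float* y0, float* out0) {
  /* Corresponds to the new ConstantNode visitor: the constant lives in the
   * generated source instead of arriving as a runtime argument. */
  float const_0[4] = {1, 1, 1, 1};
  float* buf_0 = (float*)malloc(sizeof(float) * 4);
  for (int i = 0; i < 4; ++i) {
    buf_0[i] = const_0[i] + y0[i]; /* the partitioned add */
  }
  memcpy(out0, buf_0, sizeof(float) * 4);
  free(buf_0);
}

int main(void) {
  float y[4] = {0.5f, 1.5f, 2.5f, 3.5f};
  float out[4];
  ccompiler_0_(y, out);
  for (int i = 0; i < 4; ++i) printf("%g ", out[i]); /* prints: 1.5 2.5 3.5 4.5 */
  printf("\n");
  return 0;
}
```

The point of the partition_graph.cc change above is exactly this shape: a constant argument is recorded in `params_bind` and bound into the external function, so the runtime call only passes the non-constant tensors.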
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index c4fbbc1458d9..1f37ab84d4a5 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -634,6 +634,50 @@ def expected():
     assert relay.analysis.alpha_equal(partitioned, ref_mod)
 
 
+def test_constant_propagation():
+    ones = np.ones(shape=(8, 8), dtype="float32")
+
+    def expected():
+        mod = tvm.IRModule()
+        x = relay.const(ones)
+        y = relay.var("y", shape=(8, 8))
+        x0 = relay.const(ones)
+        y0 = relay.var("y0", shape=(8, 8))
+        add = x0 + y0
+        # Function that uses C compiler
+        func = relay.Function([y0], add)
+        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
+        func = func.with_attr("ExternalSymbol",
+                              tvm.tir.StringImm("ccompiler_0"))
+        glb_0 = relay.GlobalVar("ccompiler_0")
+        mod[glb_0] = func
+        add_call = relay.Call(glb_0, [y])
+        log = relay.log(add_call)
+        main = relay.Function([y], log)
+        mod["main"] = main
+        return mod
+
+    x = relay.var("x", shape=(8, 8))
+    y = relay.var("y", shape=(8, 8))
+    add = x + y
+    log = relay.log(add)
+    f = relay.Function([x, y], log)
+    f = relay.build_module.bind_params_by_name(f, {"x": tvm.nd.array(ones)})
+    mod = tvm.IRModule()
+    mod["main"] = f
+    mod = WhiteListAnnotator(["add"], "ccompiler")(mod)
+    mod = transform.PartitionGraph()(mod)
+
+    expected_mod = expected()
+    assert relay.alpha_equal(mod, expected_mod)
+
+    y_data = np.random.rand(8, 8).astype('float32')
+    np_add = ones + y_data
+    check_result(mod, {"y": y_data}, (8, 8), np.log(np_add))
+
+
 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -643,3 +687,4 @@ def expected():
     test_extern_dnnl_mobilenet()
     test_function_lifting()
     test_function_lifting_inline()
+    test_constant_propagation()
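One caveat carried over from the comment in `VisitExpr_(const ConstantNode*)`: declaring a very large constant as a plain local array can overflow the stack. Below is a minimal sketch of the heap-backed alternative that comment hints at; the buffer name, size, and one-time initialization scheme are illustrative assumptions, not part of this patch.

```c
/* Sketch of a heap-backed constant: allocated and filled once, then reused
 * across calls, so a large constant never lands on the stack. CONST_0_ELEMS
 * and the all-ones fill stand in for the real stored values. */
#include <stdlib.h>

#define CONST_0_ELEMS (64 * 1024)

static float* const_0 = NULL;

static const float* get_const_0(void) {
  if (const_0 == NULL) {
    const_0 = (float*)malloc(sizeof(float) * CONST_0_ELEMS);
    for (int i = 0; i < CONST_0_ELEMS; ++i) {
      const_0[i] = 1.0f; /* real generated code would write the literal values */
    }
  }
  return const_0;
}

int main(void) {
  const float* c = get_const_0();
  return c[0] == 1.0f ? 0 : 1; /* trivial use to keep the sketch self-contained */
}
```

Whether the extra indirection is worth it depends on the constant sizes a workload actually embeds; the patch itself keeps the simpler stack-allocated form.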