Skip to content

Commit

Permalink
Remove ccompiler constant updater
Browse files Browse the repository at this point in the history
Change-Id: Iea9ee0f689683512fa114afeadeccb7fc9048e4f
  • Loading branch information
mbaret committed Oct 27, 2020
1 parent c808324 commit 9f4d039
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 30 deletions.
29 changes: 0 additions & 29 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,37 +290,8 @@ runtime::Module CCompiler(const ObjectRef& ref) {
return csource.CreateCSourceModule(ref);
}

/*!
* \brief A visitor to add the constants used as params for MetadataModule.
*/
struct CCompilerConstantUpdater : public ExprVisitor {
public:
explicit CCompilerConstantUpdater(const std::string& symbol) : symbol_(symbol) {}

Map<String, runtime::NDArray> GetConstants(const Expr& expr) {
VisitExpr(expr);
return this->params_;
}

void VisitExpr_(const ConstantNode* cn) final {
std::string name = symbol_ + "_p" + std::to_string(const_idx_++);
params_.Set(name, cn->data);
}

private:
int const_idx_{0};
std::string symbol_;
Map<String, runtime::NDArray> params_;
};

Map<String, runtime::NDArray> GetConstants(const Expr& expr, const std::string symbol) {
return CCompilerConstantUpdater(symbol).GetConstants(expr);
}

TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler);

TVM_REGISTER_GLOBAL("relay.ext.ccompiler.constant_updater").set_body_typed(GetConstants);

} // namespace contrib
} // namespace relay
} // namespace tvm
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/codegen_c/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class CodegenCBase {
* \return The created variable name
*/
std::string CreateConstVar(const std::string& symbol, int const_id) const {
return symbol + "_p" + std::to_string(const_id++);
return symbol + "_const_" + std::to_string(const_id++);
}

/*! \brief The external function source code stream. */
Expand Down
5 changes: 5 additions & 0 deletions tests/python/relay/test_external_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ def test_extern_gcc():


def test_extern_gcc_consts():
@tvm._ffi.register_func("relay.ext.ccompiler.constant_updater")
def constant_updater(expr, symbol):
"""A dummy constant updater just to test that a custom one works."""
return {"ccompiler_0_p0": tvm.nd.array(y0_data)}

x = relay.var("x", shape=(8, 8))
y0_data = np.random.uniform(0, 1, (8, 8)).astype("float32")

Expand Down

0 comments on commit 9f4d039

Please sign in to comment.