diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 71f19a1c21bcf..88faff22cd310 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -565,6 +565,19 @@ class RelayBuildModule : public runtime::ModuleNode { auto ext_mods = executor_codegen_->GetExternalModules(); ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost(), executor_codegen_->GetMetadata()); + // Remove external params which were stored in metadata module. + for (tvm::runtime::Module mod : ext_mods) { + auto pf_var = mod.GetFunction("get_const_vars"); + if (pf_var != nullptr) { + Array variables = pf_var(); + for (size_t i = 0; i < variables.size(); i++) { + auto it = ret_.params.find(variables[i].operator std::string()); + if (it != ret_.params.end()) { + ret_.params.erase(it); + } + } + } + } } private: diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index be92ef200c31c..156abfc4c22ac 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -352,6 +352,8 @@ def test_load_params_with_constants_in_ext_codegen(): mod = transform.PartitionGraph()(mod) graph_module = relay.build(mod, target="llvm", params=params) + # Params will be stored in metadata module. + assert len(graph_module.get_params()) == 0 lib = update_lib(graph_module.get_lib()) rt_mod = tvm.contrib.graph_executor.create(graph_module.get_graph_json(), lib, tvm.cpu(0)) rt_mod.load_params(runtime.save_param_dict(graph_module.get_params()))