diff --git a/paddle/fluid/eager/api/utils/global_utils.h b/paddle/fluid/eager/api/utils/global_utils.h
index 99287e66d5f825..2be972011101fe 100644
--- a/paddle/fluid/eager/api/utils/global_utils.h
+++ b/paddle/fluid/eager/api/utils/global_utils.h
@@ -101,7 +101,13 @@ class Controller {
   void MergeOpMetaInfoMap(
       const std::unordered_map<std::string, std::vector<paddle::OpMetaInfo>>&
          map) {
-    op_meta_info_map_.insert(map.begin(), map.end());
+    for (const auto& [key, value] : map) {
+      if (op_meta_info_map_.count(key)) {
+        VLOG(3) << "Replacing existing OpMetaInfo for op: " << key;
+      }
+      VLOG(3) << "Merging OpMetaInfo for op: " << key;
+      op_meta_info_map_[key] = value;
+    }
   }
 
   std::unordered_map<std::string, std::vector<paddle::OpMetaInfo>>&
diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ ... @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
   auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta);
 
   if (OpInfoMap::Instance().Has(op_name)) {
-    LOG(WARNING) << "Operator (" << op_name << ") has been registered.";
-    return;
+    LOG(WARNING) << "Operator (" << op_name
+                 << ") has already been registered as a PIR op.";
+    LOG(WARNING) << "PIR Operator (" << op_name
+                 << ") has been overridden by the custom op!";
   }
 
   auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta);
@@ -1268,8 +1270,9 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
   OpInfoMap::Instance().Insert(cur_op_name, info);
 }
 
-void RegisterOperatorWithMetaInfoMap(
-    const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle) {
+std::unordered_map<std::string, std::vector<OpMetaInfo>>
+RegisterOperatorWithMetaInfoMap(const paddle::OpMetaInfoMap& op_meta_info_map,
+                                void* dso_handle) {
   auto& meta_info_map = op_meta_info_map.GetMap();
   VLOG(3) << "Custom Operator: size of op meta info map - "
           << meta_info_map.size();
@@ -1277,12 +1280,14 @@ void RegisterOperatorWithMetaInfoMap(
   ::pir::IrContext* ctx = ::pir::IrContext::Instance();
   auto* custom_dialect =
       ctx->GetOrRegisterDialect<paddle::dialect::CustomOpDialect>();
 
+  std::unordered_map<std::string, std::vector<OpMetaInfo>> diff_map;
   for (auto& pair : meta_info_map) {
     VLOG(3) << "Custom Operator: pair first -> op name: " << pair.first;
 
     // Register PIR op
-    if (custom_dialect->HasRegistered(pair.first)) {
+    if (custom_dialect->HasRegistered(paddle::framework::kCustomDialectPrefix +
+                                      pair.first)) {
       VLOG(3) << "The operator `" << pair.first
               << "` has been registered. "
" "Therefore, we will not repeat the registration here."; @@ -1293,16 +1298,18 @@ void RegisterOperatorWithMetaInfoMap( << OpMetaInfoHelper::GetOpName(meta_info); custom_dialect->RegisterCustomOp(meta_info); } + diff_map[pair.first] = pair.second; // Register Fluid op RegisterOperatorWithMetaInfo(pair.second, dso_handle); } + return diff_map; } ////////////////////// User APIs /////////////////////// // load op api -const std::unordered_map>& +std::unordered_map> LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { void* handle = phi::dynload::GetOpDsoHandle(dso_name); VLOG(3) << "load custom_op lib: " << dso_name; @@ -1310,8 +1317,12 @@ LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { auto* get_op_meta_info_map = detail::DynLoad(handle, "PD_GetOpMetaInfoMap"); auto& op_meta_info_map = get_op_meta_info_map(); - RegisterOperatorWithMetaInfoMap(op_meta_info_map, handle); - return op_meta_info_map.GetMap(); + auto diff_map = RegisterOperatorWithMetaInfoMap(op_meta_info_map, handle); + for (auto& pair : diff_map) { + VLOG(3) << "diff op name: " << pair.first; + } + // return op_meta_info_map.GetMap(); + return diff_map; } } // namespace paddle::framework diff --git a/paddle/fluid/framework/custom_operator.h b/paddle/fluid/framework/custom_operator.h index 1226be3df7564a..c779aa44aa8bf9 100644 --- a/paddle/fluid/framework/custom_operator.h +++ b/paddle/fluid/framework/custom_operator.h @@ -311,12 +311,13 @@ class CustomGradOpMaker }; // Load custom op api: register op after user compiled -const std::unordered_map>& +std::unordered_map> LoadOpMetaInfoAndRegisterOp(const std::string& dso_name); // Register custom op api: register op directly -void RegisterOperatorWithMetaInfoMap( - const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle = nullptr); +std::unordered_map> +RegisterOperatorWithMetaInfoMap(const paddle::OpMetaInfoMap& op_meta_info_map, + void* dso_handle = nullptr); // Interface for selective register custom op. void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h index a48eb2edbcfccb..a23c7a06dcb597 100644 --- a/paddle/fluid/framework/op_info.h +++ b/paddle/fluid/framework/op_info.h @@ -138,11 +138,12 @@ class TEST_API OpInfoMap { } void Insert(const std::string& type, const OpInfo& info) { - PADDLE_ENFORCE_NE(Has(type), - true, - common::errors::AlreadyExists( - "Operator (%s) has been registered.", type)); - map_.insert({type, info}); + if (Has(type)) { + map_[type] = info; // override ops + VLOG(0) << "Overriding op: " << type; + } else { + map_.insert({type, info}); + } } const OpInfo& Get(const std::string& type) const { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 19034ba6459c13..805509a9b05890 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2969,8 +2969,16 @@ All parameter, weight, gradient are variables in Paddle. m.def("init_glog", framework::InitGLOG); m.def("init_memory_method", framework::InitMemoryMethod); m.def("load_op_meta_info_and_register_op", [](const std::string dso_name) { - egr::Controller::Instance().MergeOpMetaInfoMap( - framework::LoadOpMetaInfoAndRegisterOp(dso_name)); + const auto &new_op_meta_info_map = + framework::LoadOpMetaInfoAndRegisterOp(dso_name); + // Merging failed? 
+    egr::Controller::Instance().MergeOpMetaInfoMap(new_op_meta_info_map);
+
+    py::list key_list;
+    for (const auto &pair : new_op_meta_info_map) {
+      key_list.append(pair.first);
+    }
+    return key_list;
   });
   m.def("init_devices", []() { framework::InitDevices(); });
   m.def("init_default_kernel_signatures",
diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py
index 973063a331d007..4058b7a626cf57 100644
--- a/python/paddle/base/framework.py
+++ b/python/paddle/base/framework.py
@@ -3141,14 +3141,14 @@ def get_op_proto(self, type):
             raise ValueError(f'Operator "{type}" has not been registered.')
         return self.op_proto_map[type]
 
-    def update_op_proto(self):
+    def update_op_proto(self, new_op_list):
         op_protos = get_all_op_protos()
         custom_op_names = []
         for proto in op_protos:
             if proto.type not in self.op_proto_map:
                 self.op_proto_map[proto.type] = proto
                 custom_op_names.append(proto.type)
-
+        custom_op_names = list(set(custom_op_names).union(set(new_op_list)))
         return custom_op_names
 
     def has_op_proto(self, type):
diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py
index 2a2a84d0d736c0..6a9b1f40af7ae3 100644
--- a/python/paddle/utils/cpp_extension/extension_utils.py
+++ b/python/paddle/utils/cpp_extension/extension_utils.py
@@ -164,8 +164,9 @@ def bootstrap_context():
 
 
 def load_op_meta_info_and_register_op(lib_filename: str) -> list[str]:
-    core.load_op_meta_info_and_register_op(lib_filename)
-    return OpProtoHolder.instance().update_op_proto()
+    new_list = core.load_op_meta_info_and_register_op(lib_filename)
+    proto_sync_ops = OpProtoHolder.instance().update_op_proto(new_list)
+    return proto_sync_ops
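For reference, a minimal usage sketch of the behavior this patch enables. The library path ./custom_relu.so and the op name custom_relu are hypothetical stand-ins for any user-compiled custom-op library:

# Minimal sketch (hypothetical .so path and op name): with this patch,
# load_op_meta_info_and_register_op returns the names of the ops registered
# from the loaded library, and loading a rebuilt library with the same op
# name overrides the previous registration instead of raising
# "Operator (...) has been registered.".
from paddle.utils.cpp_extension.extension_utils import (
    load_op_meta_info_and_register_op,
)

# First load: registers the ops and reports their names back to Python.
op_names = load_op_meta_info_and_register_op("./custom_relu.so")
print(op_names)  # e.g. ['custom_relu']

# Re-loading a recompiled library now replaces the old OpInfo/OpMetaInfo
# entries (see OpInfoMap::Insert and Controller::MergeOpMetaInfoMap above),
# and update_op_proto still receives the op names so the proto map resyncs.
op_names = load_op_meta_info_and_register_op("./custom_relu.so")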