Skip to content

Commit c563ed1

Browse files
authored
Fix namespace conflict issue between PIR and custom op, with style of overridding (#74622)
* Fix namespace conflict issue between PIR and custom op, with style of override. * fix miscs. * polish
1 parent 50c3189 commit c563ed1

File tree

7 files changed

+51
-23
lines changed

7 files changed

+51
-23
lines changed

paddle/fluid/eager/api/utils/global_utils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,13 @@ class Controller {
101101
void MergeOpMetaInfoMap(
102102
const std::unordered_map<std::string, std::vector<paddle::OpMetaInfo>>&
103103
map) {
104-
op_meta_info_map_.insert(map.begin(), map.end());
104+
for (const auto& [key, value] : map) {
105+
if (op_meta_info_map_.count(key)) {
106+
VLOG(3) << "Replacing existing OpMetaInfo for op: " << key;
107+
}
108+
VLOG(3) << "Merging OpMetaInfo for op: " << key;
109+
op_meta_info_map_[key] = value;
110+
}
105111
}
106112

107113
std::unordered_map<std::string,

paddle/fluid/framework/custom_operator.cc

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -964,8 +964,10 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
964964
auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta);
965965

966966
if (OpInfoMap::Instance().Has(op_name)) {
967-
LOG(WARNING) << "Operator (" << op_name << ") has been registered.";
968-
return;
967+
LOG(WARNING) << "Operator (" << op_name
968+
<< ") has been registered before as PIR op.";
969+
LOG(WARNING) << "PIR Operator (" << op_name
970+
<< ") has been overridden by Custom op!.";
969971
}
970972

971973
auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta);
@@ -1268,21 +1270,24 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
12681270
OpInfoMap::Instance().Insert(cur_op_name, info);
12691271
}
12701272

1271-
void RegisterOperatorWithMetaInfoMap(
1272-
const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle) {
1273+
std::unordered_map<std::string, std::vector<OpMetaInfo>>
1274+
RegisterOperatorWithMetaInfoMap(const paddle::OpMetaInfoMap& op_meta_info_map,
1275+
void* dso_handle) {
12731276
auto& meta_info_map = op_meta_info_map.GetMap();
12741277
VLOG(3) << "Custom Operator: size of op meta info map - "
12751278
<< meta_info_map.size();
12761279
// pair: {op_type, OpMetaInfo}
12771280
::pir::IrContext* ctx = ::pir::IrContext::Instance();
12781281
auto* custom_dialect =
12791282
ctx->GetOrRegisterDialect<paddle::dialect::CustomOpDialect>();
1283+
std::unordered_map<std::string, std::vector<OpMetaInfo>> diff_map;
12801284
for (auto& pair : meta_info_map) {
12811285
VLOG(3) << "Custom Operator: pair first -> op name: " << pair.first;
12821286

12831287
// Register PIR op
12841288

1285-
if (custom_dialect->HasRegistered(pair.first)) {
1289+
if (custom_dialect->HasRegistered(paddle::framework::kCustomDialectPrefix +
1290+
pair.first)) {
12861291
VLOG(3) << "The operator `" << pair.first
12871292
<< "` has been registered. "
12881293
"Therefore, we will not repeat the registration here.";
@@ -1293,25 +1298,31 @@ void RegisterOperatorWithMetaInfoMap(
12931298
<< OpMetaInfoHelper::GetOpName(meta_info);
12941299
custom_dialect->RegisterCustomOp(meta_info);
12951300
}
1301+
diff_map[pair.first] = pair.second;
12961302

12971303
// Register Fluid op
12981304
RegisterOperatorWithMetaInfo(pair.second, dso_handle);
12991305
}
1306+
return diff_map;
13001307
}
13011308

13021309
////////////////////// User APIs ///////////////////////
13031310

13041311
// load op api
1305-
const std::unordered_map<std::string, std::vector<OpMetaInfo>>&
1312+
std::unordered_map<std::string, std::vector<OpMetaInfo>>
13061313
LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
13071314
void* handle = phi::dynload::GetOpDsoHandle(dso_name);
13081315
VLOG(3) << "load custom_op lib: " << dso_name;
13091316
typedef OpMetaInfoMap& get_op_meta_info_map_t();
13101317
auto* get_op_meta_info_map =
13111318
detail::DynLoad<get_op_meta_info_map_t>(handle, "PD_GetOpMetaInfoMap");
13121319
auto& op_meta_info_map = get_op_meta_info_map();
1313-
RegisterOperatorWithMetaInfoMap(op_meta_info_map, handle);
1314-
return op_meta_info_map.GetMap();
1320+
auto diff_map = RegisterOperatorWithMetaInfoMap(op_meta_info_map, handle);
1321+
for (auto& pair : diff_map) {
1322+
VLOG(3) << "diff op name: " << pair.first;
1323+
}
1324+
// return op_meta_info_map.GetMap();
1325+
return diff_map;
13151326
}
13161327

13171328
} // namespace paddle::framework

paddle/fluid/framework/custom_operator.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,12 +311,13 @@ class CustomGradOpMaker<imperative::OpBase>
311311
};
312312

313313
// Load custom op api: register op after user compiled
314-
const std::unordered_map<std::string, std::vector<OpMetaInfo>>&
314+
std::unordered_map<std::string, std::vector<OpMetaInfo>>
315315
LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
316316

317317
// Register custom op api: register op directly
318-
void RegisterOperatorWithMetaInfoMap(
319-
const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle = nullptr);
318+
std::unordered_map<std::string, std::vector<OpMetaInfo>>
319+
RegisterOperatorWithMetaInfoMap(const paddle::OpMetaInfoMap& op_meta_info_map,
320+
void* dso_handle = nullptr);
320321

321322
// Interface for selective register custom op.
322323
void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,

paddle/fluid/framework/op_info.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,12 @@ class TEST_API OpInfoMap {
138138
}
139139

140140
void Insert(const std::string& type, const OpInfo& info) {
141-
PADDLE_ENFORCE_NE(Has(type),
142-
true,
143-
common::errors::AlreadyExists(
144-
"Operator (%s) has been registered.", type));
145-
map_.insert({type, info});
141+
if (Has(type)) {
142+
map_[type] = info; // override ops
143+
VLOG(0) << "Overriding op: " << type;
144+
} else {
145+
map_.insert({type, info});
146+
}
146147
}
147148

148149
const OpInfo& Get(const std::string& type) const {

paddle/fluid/pybind/pybind.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3033,8 +3033,16 @@ All parameter, weight, gradient are variables in Paddle.
30333033
m.def("init_glog", framework::InitGLOG);
30343034
m.def("init_memory_method", framework::InitMemoryMethod);
30353035
m.def("load_op_meta_info_and_register_op", [](const std::string dso_name) {
3036-
egr::Controller::Instance().MergeOpMetaInfoMap(
3037-
framework::LoadOpMetaInfoAndRegisterOp(dso_name));
3036+
const auto &new_op_meta_info_map =
3037+
framework::LoadOpMetaInfoAndRegisterOp(dso_name);
3038+
// Merging failed?
3039+
egr::Controller::Instance().MergeOpMetaInfoMap(new_op_meta_info_map);
3040+
3041+
py::list key_list;
3042+
for (const auto &pair : new_op_meta_info_map) {
3043+
key_list.append(pair.first);
3044+
}
3045+
return key_list;
30383046
});
30393047
m.def("init_devices", []() { framework::InitDevices(); });
30403048
m.def("init_default_kernel_signatures",

python/paddle/base/framework.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3143,14 +3143,14 @@ def get_op_proto(self, type):
31433143
raise ValueError(f'Operator "{type}" has not been registered.')
31443144
return self.op_proto_map[type]
31453145

3146-
def update_op_proto(self):
3146+
def update_op_proto(self, new_op_list):
31473147
op_protos = get_all_op_protos()
31483148
custom_op_names = []
31493149
for proto in op_protos:
31503150
if proto.type not in self.op_proto_map:
31513151
self.op_proto_map[proto.type] = proto
31523152
custom_op_names.append(proto.type)
3153-
3153+
custom_op_names = list(set(custom_op_names).union(set(new_op_list)))
31543154
return custom_op_names
31553155

31563156
def has_op_proto(self, type):

python/paddle/utils/cpp_extension/extension_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,9 @@ def bootstrap_context():
164164

165165

166166
def load_op_meta_info_and_register_op(lib_filename: str) -> list[str]:
167-
core.load_op_meta_info_and_register_op(lib_filename)
168-
return OpProtoHolder.instance().update_op_proto()
167+
new_list = core.load_op_meta_info_and_register_op(lib_filename)
168+
proto_sync_ops = OpProtoHolder.instance().update_op_proto(new_list)
169+
return proto_sync_ops
169170

170171

171172
def custom_write_stub(resource, pyfile):

0 commit comments

Comments
 (0)