@@ -964,8 +964,10 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
964964 auto op_name = OpMetaInfoHelper::GetOpName (base_op_meta);
965965
966966 if (OpInfoMap::Instance ().Has (op_name)) {
967- LOG (WARNING) << " Operator (" << op_name << " ) has been registered." ;
968- return ;
967+ LOG (WARNING) << " Operator (" << op_name
968+ << " ) has been registered before as PIR op." ;
969+ LOG (WARNING) << " PIR Operator (" << op_name
970+ << " ) has been overridden by Custom op!." ;
969971 }
970972
971973 auto & op_inputs = OpMetaInfoHelper::GetInputs (base_op_meta);
@@ -1268,21 +1270,24 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
12681270 OpInfoMap::Instance ().Insert (cur_op_name, info);
12691271}
12701272
1271- void RegisterOperatorWithMetaInfoMap (
1272- const paddle::OpMetaInfoMap& op_meta_info_map, void * dso_handle) {
1273+ std::unordered_map<std::string, std::vector<OpMetaInfo>>
1274+ RegisterOperatorWithMetaInfoMap (const paddle::OpMetaInfoMap& op_meta_info_map,
1275+ void * dso_handle) {
12731276 auto & meta_info_map = op_meta_info_map.GetMap ();
12741277 VLOG (3 ) << " Custom Operator: size of op meta info map - "
12751278 << meta_info_map.size ();
12761279 // pair: {op_type, OpMetaInfo}
12771280 ::pir::IrContext* ctx = ::pir::IrContext::Instance ();
12781281 auto * custom_dialect =
12791282 ctx->GetOrRegisterDialect <paddle::dialect::CustomOpDialect>();
1283+ std::unordered_map<std::string, std::vector<OpMetaInfo>> diff_map;
12801284 for (auto & pair : meta_info_map) {
12811285 VLOG (3 ) << " Custom Operator: pair first -> op name: " << pair.first ;
12821286
12831287 // Register PIR op
12841288
1285- if (custom_dialect->HasRegistered (pair.first )) {
1289+ if (custom_dialect->HasRegistered (paddle::framework::kCustomDialectPrefix +
1290+ pair.first )) {
12861291 VLOG (3 ) << " The operator `" << pair.first
12871292 << " ` has been registered. "
12881293 " Therefore, we will not repeat the registration here." ;
@@ -1293,25 +1298,31 @@ void RegisterOperatorWithMetaInfoMap(
12931298 << OpMetaInfoHelper::GetOpName (meta_info);
12941299 custom_dialect->RegisterCustomOp (meta_info);
12951300 }
1301+ diff_map[pair.first ] = pair.second ;
12961302
12971303 // Register Fluid op
12981304 RegisterOperatorWithMetaInfo (pair.second , dso_handle);
12991305 }
1306+ return diff_map;
13001307}
13011308
13021309// //////////////////// User APIs ///////////////////////
13031310
13041311// load op api
1305- const std::unordered_map<std::string, std::vector<OpMetaInfo>>&
1312+ std::unordered_map<std::string, std::vector<OpMetaInfo>>
13061313LoadOpMetaInfoAndRegisterOp (const std::string& dso_name) {
13071314 void * handle = phi::dynload::GetOpDsoHandle (dso_name);
13081315 VLOG (3 ) << " load custom_op lib: " << dso_name;
13091316 typedef OpMetaInfoMap& get_op_meta_info_map_t ();
13101317 auto * get_op_meta_info_map =
13111318 detail::DynLoad<get_op_meta_info_map_t >(handle, " PD_GetOpMetaInfoMap" );
13121319 auto & op_meta_info_map = get_op_meta_info_map ();
1313- RegisterOperatorWithMetaInfoMap (op_meta_info_map, handle);
1314- return op_meta_info_map.GetMap ();
1320+ auto diff_map = RegisterOperatorWithMetaInfoMap (op_meta_info_map, handle);
1321+ for (auto & pair : diff_map) {
1322+ VLOG (3 ) << " diff op name: " << pair.first ;
1323+ }
1324+ // return op_meta_info_map.GetMap();
1325+ return diff_map;
13151326}
13161327
13171328} // namespace paddle::framework
0 commit comments