Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion paddle/fluid/eager/api/utils/global_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,13 @@ class Controller {
void MergeOpMetaInfoMap(
    const std::unordered_map<std::string, std::vector<paddle::OpMetaInfo>>&
        map) {
  // Merge `map` into op_meta_info_map_. Entries that already exist are
  // overwritten on purpose: a custom-op library may be re-loaded and must
  // be able to replace its previous registration.
  for (const auto& [key, value] : map) {
    // insert_or_assign performs a single lookup instead of the
    // count()-then-operator[] double lookup, and tells us whether the key
    // was newly inserted or replaced.
    auto [it, inserted] = op_meta_info_map_.insert_or_assign(key, value);
    if (!inserted) {
      VLOG(3) << "Replacing existing OpMetaInfo for op: " << key;
    }
    VLOG(3) << "Merging OpMetaInfo for op: " << key;
  }
}

std::unordered_map<std::string,
Expand Down
27 changes: 19 additions & 8 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -964,8 +964,10 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta);

if (OpInfoMap::Instance().Has(op_name)) {
LOG(WARNING) << "Operator (" << op_name << ") has been registered.";
return;
LOG(WARNING) << "Operator (" << op_name
<< ") has been registered before as PIR op.";
LOG(WARNING) << "PIR Operator (" << op_name
<< ") has been overridden by Custom op!.";
}

auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta);
Expand Down Expand Up @@ -1268,21 +1270,24 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
OpInfoMap::Instance().Insert(cur_op_name, info);
}

void RegisterOperatorWithMetaInfoMap(
const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle) {
std::unordered_map<std::string, std::vector<OpMetaInfo>>
RegisterOperatorWithMetaInfoMap(const paddle::OpMetaInfoMap& op_meta_info_map,
void* dso_handle) {
auto& meta_info_map = op_meta_info_map.GetMap();
VLOG(3) << "Custom Operator: size of op meta info map - "
<< meta_info_map.size();
// pair: {op_type, OpMetaInfo}
::pir::IrContext* ctx = ::pir::IrContext::Instance();
auto* custom_dialect =
ctx->GetOrRegisterDialect<paddle::dialect::CustomOpDialect>();
std::unordered_map<std::string, std::vector<OpMetaInfo>> diff_map;
for (auto& pair : meta_info_map) {
VLOG(3) << "Custom Operator: pair first -> op name: " << pair.first;

// Register PIR op

if (custom_dialect->HasRegistered(pair.first)) {
if (custom_dialect->HasRegistered(paddle::framework::kCustomDialectPrefix +
pair.first)) {
VLOG(3) << "The operator `" << pair.first
<< "` has been registered. "
"Therefore, we will not repeat the registration here.";
Expand All @@ -1293,25 +1298,31 @@ void RegisterOperatorWithMetaInfoMap(
<< OpMetaInfoHelper::GetOpName(meta_info);
custom_dialect->RegisterCustomOp(meta_info);
}
diff_map[pair.first] = pair.second;

// Register Fluid op
RegisterOperatorWithMetaInfo(pair.second, dso_handle);
}
return diff_map;
}

////////////////////// User APIs ///////////////////////

// load op api
const std::unordered_map<std::string, std::vector<OpMetaInfo>>&
std::unordered_map<std::string, std::vector<OpMetaInfo>>
LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
void* handle = phi::dynload::GetOpDsoHandle(dso_name);
VLOG(3) << "load custom_op lib: " << dso_name;
typedef OpMetaInfoMap& get_op_meta_info_map_t();
auto* get_op_meta_info_map =
detail::DynLoad<get_op_meta_info_map_t>(handle, "PD_GetOpMetaInfoMap");
auto& op_meta_info_map = get_op_meta_info_map();
RegisterOperatorWithMetaInfoMap(op_meta_info_map, handle);
return op_meta_info_map.GetMap();
auto diff_map = RegisterOperatorWithMetaInfoMap(op_meta_info_map, handle);
for (auto& pair : diff_map) {
VLOG(3) << "diff op name: " << pair.first;
}
// return op_meta_info_map.GetMap();
return diff_map;
}

} // namespace paddle::framework
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/framework/custom_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,13 @@ class CustomGradOpMaker<imperative::OpBase>
};

// Load custom op api: register op after user compiled
const std::unordered_map<std::string, std::vector<OpMetaInfo>>&
std::unordered_map<std::string, std::vector<OpMetaInfo>>
LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);

// Register custom op api: register op directly
void RegisterOperatorWithMetaInfoMap(
const paddle::OpMetaInfoMap& op_meta_info_map, void* dso_handle = nullptr);
std::unordered_map<std::string, std::vector<OpMetaInfo>>
RegisterOperatorWithMetaInfoMap(const paddle::OpMetaInfoMap& op_meta_info_map,
void* dso_handle = nullptr);

// Interface for selective register custom op.
void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/framework/op_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,12 @@ class TEST_API OpInfoMap {
}

void Insert(const std::string& type, const OpInfo& info) {
  // Re-registration of an existing op type is allowed: the new OpInfo
  // replaces the old one (custom ops may override built-in/PIR ops).
  // A single insert() gives us both the lookup and the insertion, instead
  // of the Has()-then-operator[] double lookup.
  auto [it, inserted] = map_.insert({type, info});
  if (!inserted) {
    it->second = info;  // override ops
    VLOG(0) << "Overriding op: " << type;
  }
}

const OpInfo& Get(const std::string& type) const {
Expand Down
12 changes: 10 additions & 2 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2969,8 +2969,16 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_glog", framework::InitGLOG);
m.def("init_memory_method", framework::InitMemoryMethod);
m.def("load_op_meta_info_and_register_op", [](const std::string &dso_name) {
  // Register all ops found in the shared library; the returned map holds
  // only the ops newly registered by this call.
  const auto &new_op_meta_info_map =
      framework::LoadOpMetaInfoAndRegisterOp(dso_name);
  // Merge into the global controller map (existing entries are replaced).
  egr::Controller::Instance().MergeOpMetaInfoMap(new_op_meta_info_map);
  // Hand the newly registered op names back to Python so the caller can
  // update its proto cache with exactly these ops.
  py::list key_list;
  for (const auto &pair : new_op_meta_info_map) {
    key_list.append(pair.first);
  }
  return key_list;
});
m.def("init_devices", []() { framework::InitDevices(); });
m.def("init_default_kernel_signatures",
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/base/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -3141,14 +3141,14 @@ def get_op_proto(self, type):
raise ValueError(f'Operator "{type}" has not been registered.')
return self.op_proto_map[type]

def update_op_proto(self, new_op_list=()):
    """Refresh the cached op protos and return the names of custom ops.

    Args:
        new_op_list: iterable of op names that were just registered by a
            custom-op library; they are included in the result even when
            their proto was already cached. Defaults to an empty tuple so
            existing callers without the argument keep working.

    Returns:
        list[str]: names of ops whose protos were newly cached, unioned
        with ``new_op_list`` (order is unspecified).
    """
    op_protos = get_all_op_protos()
    custom_op_names = []
    for proto in op_protos:
        if proto.type not in self.op_proto_map:
            self.op_proto_map[proto.type] = proto
            custom_op_names.append(proto.type)

    return list(set(custom_op_names).union(set(new_op_list)))

def has_op_proto(self, type):
Expand Down
5 changes: 3 additions & 2 deletions python/paddle/utils/cpp_extension/extension_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,9 @@ def bootstrap_context():


def load_op_meta_info_and_register_op(lib_filename: str) -> list[str]:
    """Load a custom-op shared library and register the ops it contains.

    Args:
        lib_filename: path to the compiled custom-op shared library.

    Returns:
        list[str]: names of the custom ops made available by this call,
        as reported by the proto holder after syncing with the newly
        registered op list.
    """
    # The C++ side returns the names of ops newly registered by this
    # library; feed them to the proto holder so its cache stays in sync.
    new_list = core.load_op_meta_info_and_register_op(lib_filename)
    proto_sync_ops = OpProtoHolder.instance().update_op_proto(new_list)
    return proto_sync_ops


def custom_write_stub(resource, pyfile):
Expand Down