Skip to content

Commit

Permalink
【fix bug】fix bug for program_converter (#61051)
Browse files Browse the repository at this point in the history
* fix convert bug

* modify code style
  • Loading branch information
zyt1024 authored Jan 25, 2024
1 parent 4cca092 commit 43c38ed
Showing 1 changed file with 28 additions and 34 deletions.
62 changes: 28 additions & 34 deletions paddle/fluid/framework/program_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,42 +43,36 @@ std::pair<bool, std::unordered_multimap<std::string, OpDesc*>> DetectLegacyOps(

// get *all kinds* of formats of op versions and op version map to a unified
// representation before comparison can be done in a neat way
if (!program->HasOpVersionMap()) {
is_legacy_program = true;
} else {
legacy_op_versions =
paddle::framework::compatible::pb::GetLegacyOpVersions();
legacy_op_versions = paddle::framework::compatible::pb::GetLegacyOpVersions();

const auto* _op_version_map = program->OpVersionMap();
for (int i = 0; i < _op_version_map->pair_size(); ++i) {
auto pair =
std::make_pair(_op_version_map->pair(i).op_name(),
static_cast<uint32_t>(
_op_version_map->pair(i).op_version().version()));
program_op_versions.insert(pair);
}
const auto* _op_version_map = program->OpVersionMap();
for (int i = 0; i < _op_version_map->pair_size(); ++i) {
auto pair = std::make_pair(
_op_version_map->pair(i).op_name(),
static_cast<uint32_t>(_op_version_map->pair(i).op_version().version()));
program_op_versions.insert(pair);
}

const size_t num_blocks = program->Size();
for (size_t i = 0; i < num_blocks; i++) {
BlockDesc* block = program->MutableBlock(i);
const size_t num_ops = block->OpSize();
for (size_t j = 0; j < num_ops; j++) {
OpDesc* op = block->Op(static_cast<int>(j));
const std::string& op_type = op->Type();
if (needConvertedOperators.find(op_type) !=
needConvertedOperators.end()) {
// If an operator (program_op) is in the needConvertedOperators set,
// it indicates that the operator may need to be converted.
// Further judgement: if the operator does not exist in the
// program_op_version_map, the operator needs to be converted.
// Moreover, if the operator does exist and its program_op_version_
// is less than or equal legacy_op_version, the operator also needs to
// be converted.
if (!program_op_versions.count(op_type) ||
program_op_versions[op_type] <= legacy_op_versions[op_type]) {
is_legacy_program = true;
legacy_op_map.insert(std::make_pair(op_type, op));
}
const size_t num_blocks = program->Size();
for (size_t i = 0; i < num_blocks; i++) {
BlockDesc* block = program->MutableBlock(i);
const size_t num_ops = block->OpSize();
for (size_t j = 0; j < num_ops; j++) {
OpDesc* op = block->Op(static_cast<int>(j));
const std::string& op_type = op->Type();
if (needConvertedOperators.find(op_type) !=
needConvertedOperators.end()) {
// If an operator (program_op) is in the needConvertedOperators set,
// it indicates that the operator may need to be converted.
// Further judgement: if the operator does not exist in the
// program_op_version_map, the operator needs to be converted.
// Moreover, if the operator does exist and its program_op_version_
// is less than or equal legacy_op_version, the operator also needs to
// be converted.
if (!program_op_versions.count(op_type) ||
program_op_versions[op_type] <= legacy_op_versions[op_type]) {
is_legacy_program = true;
legacy_op_map.insert(std::make_pair(op_type, op));
}
}
}
Expand Down

0 comments on commit 43c38ed

Please sign in to comment.