Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] polish the ir_mapping implimentation. #60675

Merged
merged 1 commit into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
pir::IrMapping mapper;
auto cloned_program = program.Clone(mapper);
std::vector<pir::OpResult> associated_array_key, associated_array_value;
for (auto &pair : mapper.Map<pir::Value>()) {
for (auto &pair : mapper.GetMap<pir::Value>()) {
associated_array_key.push_back(pair.first.dyn_cast<pir::OpResult>());
associated_array_value.push_back(pair.second.dyn_cast<pir::OpResult>());
}
Expand Down Expand Up @@ -1119,12 +1119,12 @@ SplitedResult SplitForwardBackward(
auto *cloned_op = op->Clone(forward_mapper, clone_options);
forward_program->block()->push_back(cloned_op);
});
auto &forward_value_map = forward_mapper.MutableMap<pir::Value>();
auto &forward_value_map = forward_mapper.GetMutableMap<pir::Value>();

// backward program construc.
// Step1. insert data op for inputs_values and middle_values
pir::IrMapping backward_mapper;
auto &backward_value_map = backward_mapper.MutableMap<pir::Value>();
auto &backward_value_map = backward_mapper.GetMutableMap<pir::Value>();
int counter = 0;
auto create_data_fn = [&backward_builder,
&backward_inputs,
Expand Down
84 changes: 49 additions & 35 deletions paddle/pir/core/ir_mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,65 @@ namespace pir {
class Block;
class Operation;

namespace detail {
template <typename T, typename... OthersT>
struct ExactlyOneIrType {
using type = void;
};
template <typename T, typename FirstT, typename... OthersT>
struct ExactlyOneIrType<T, FirstT, OthersT...> {
using type =
std::conditional_t<std::is_convertible<T, FirstT>::value,
FirstT,
typename ExactlyOneIrType<T, OthersT...>::type>;
};
} // namespace detail
class IrMapping {
public:
template <typename T>
void Add(T from, T to) {
using IrType =
typename detail::ExactlyOneIrType<T, Value, Block*, Operation*>::type;
template <typename T>
std::unordered_map<T, T>& GetMutableMap() {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
return block_map_;
} else if constexpr (std::is_same<T, Operation*>::value) {
return operation_map_;
} else {
IR_THROW("Not support type in IRMapping.");
}
}
template <typename T>
const std::unordered_map<T, T>& GetMap() const {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
return block_map_;
} else if constexpr (std::is_same<T, Operation*>::value) {
return operation_map_;
} else {
IR_THROW("Not support type in IRMapping.");
}
}
template <typename T, typename S>
void Add(T from, S to) {
if (!from) return;
MutableMap<T>()[from] = to;
GetMutableMap<IrType<T>>()[from] = to;
}

template <typename T>
T Lookup(T from) const {
if (!from) return static_cast<T>(nullptr);
IR_ENFORCE(Map<T>().count(from) > 0, "Not found key in IRMapping.");
return Map<T>().at(from);
IR_ENFORCE(GetMap<IrType<T>>().count(from) > 0,
"Not found key in IRMapping.");
return GetMap<IrType<T>>().at(from);
}

template <typename T>
void Earse(T from) {
MutableMap<T>().erase(from);
GetMutableMap<IrType<T>>().erase(from);
}

void Clear() {
Expand All @@ -46,37 +87,10 @@ class IrMapping {
operation_map_.clear();
}

template <typename T>
using MapType = std::unordered_map<T, T>;

template <typename T>
const MapType<T> &Map() const {
if constexpr (std::is_convertible<T, Value>::value)
return value_map_;
else if constexpr (std::is_convertible<T, Block *>::value)
return block_map_;
else if constexpr (std::is_convertible<T, Operation *>::value)
return operation_map_;
else
IR_THROW("Not support type in IRMapping.");
}

template <typename T>
MapType<T> &MutableMap() {
if constexpr (std::is_convertible<T, Value>::value)
return value_map_;
else if constexpr (std::is_convertible<T, Block *>::value)
return block_map_;
else if constexpr (std::is_convertible<T, Operation *>::value)
return operation_map_;
else
IR_THROW("Not support type in IRMapping.");
}

private:
MapType<Value> value_map_;
MapType<Block *> block_map_;
MapType<Operation *> operation_map_;
std::unordered_map<Value, Value> value_map_;
std::unordered_map<Block*, Block*> block_map_;
std::unordered_map<Operation*, Operation*> operation_map_;
};

} // namespace pir
3 changes: 1 addition & 2 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {

// record outputs mapping info
for (uint32_t i = 0; i < num_results_; ++i) {
ir_mapping.Add(static_cast<Value>(result(i)),
static_cast<Value>(new_op->result(i)));
ir_mapping.Add(result(i), new_op->result(i));
}

if (options.IsCloneRegions()) {
Expand Down