Skip to content

Commit

Permalink
[PIR] polish the ir_mapping implimentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Jan 10, 2024
1 parent da5399a commit 6a6d6a7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 40 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
pir::IrMapping mapper;
auto cloned_program = program.Clone(mapper);
std::vector<pir::OpResult> associated_array_key, associated_array_value;
for (auto &pair : mapper.Map<pir::Value>()) {
for (auto &pair : mapper.GetMap<pir::Value>()) {
associated_array_key.push_back(pair.first.dyn_cast<pir::OpResult>());
associated_array_value.push_back(pair.second.dyn_cast<pir::OpResult>());
}
Expand Down Expand Up @@ -1119,12 +1119,12 @@ SplitedResult SplitForwardBackward(
auto *cloned_op = op->Clone(forward_mapper, clone_options);
forward_program->block()->push_back(cloned_op);
});
auto &forward_value_map = forward_mapper.MutableMap<pir::Value>();
auto &forward_value_map = forward_mapper.GetMutableMap<pir::Value>();

// backward program construc.
// Step1. insert data op for inputs_values and middle_values
pir::IrMapping backward_mapper;
auto &backward_value_map = backward_mapper.MutableMap<pir::Value>();
auto &backward_value_map = backward_mapper.GetMutableMap<pir::Value>();
int counter = 0;
auto create_data_fn = [&backward_builder,
&backward_inputs,
Expand Down
84 changes: 49 additions & 35 deletions paddle/pir/core/ir_mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,65 @@ namespace pir {
class Block;
class Operation;

namespace detail {
template <typename T, typename... OthersT>
struct ExactlyOneIrType {
using type = void;
};
template <typename T, typename FirstT, typename... OthersT>
struct ExactlyOneIrType<T, FirstT, OthersT...> {
using type =
std::conditional_t<std::is_convertible<T, FirstT>::value,
FirstT,
typename ExactlyOneIrType<T, OthersT...>::type>;
};
} // namespace detail
class IrMapping {
public:
template <typename T>
void Add(T from, T to) {
using IrType =
typename detail::ExactlyOneIrType<T, Value, Block*, Operation*>::type;
template <typename T>
std::unordered_map<T, T>& GetMutableMap() {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
return block_map_;
} else if constexpr (std::is_same<T, Operation*>::value) {
return operation_map_;
} else {
IR_THROW("Not support type in IRMapping.");
}
}
template <typename T>
const std::unordered_map<T, T>& GetMap() const {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
return block_map_;
} else if constexpr (std::is_same<T, Operation*>::value) {
return operation_map_;
} else {
IR_THROW("Not support type in IRMapping.");
}
}
template <typename T, typename S>
void Add(T from, S to) {
if (!from) return;
MutableMap<T>()[from] = to;
GetMutableMap<IrType<T>>()[from] = to;
}

template <typename T>
T Lookup(T from) const {
if (!from) return static_cast<T>(nullptr);
IR_ENFORCE(Map<T>().count(from) > 0, "Not found key in IRMapping.");
return Map<T>().at(from);
IR_ENFORCE(GetMap<IrType<T>>().count(from) > 0,
"Not found key in IRMapping.");
return GetMap<IrType<T>>().at(from);
}

template <typename T>
void Earse(T from) {
MutableMap<T>().erase(from);
GetMutableMap<IrType<T>>().erase(from);
}

void Clear() {
Expand All @@ -46,37 +87,10 @@ class IrMapping {
operation_map_.clear();
}

template <typename T>
using MapType = std::unordered_map<T, T>;

template <typename T>
const MapType<T> &Map() const {
if constexpr (std::is_convertible<T, Value>::value)
return value_map_;
else if constexpr (std::is_convertible<T, Block *>::value)
return block_map_;
else if constexpr (std::is_convertible<T, Operation *>::value)
return operation_map_;
else
IR_THROW("Not support type in IRMapping.");
}

template <typename T>
MapType<T> &MutableMap() {
if constexpr (std::is_convertible<T, Value>::value)
return value_map_;
else if constexpr (std::is_convertible<T, Block *>::value)
return block_map_;
else if constexpr (std::is_convertible<T, Operation *>::value)
return operation_map_;
else
IR_THROW("Not support type in IRMapping.");
}

private:
MapType<Value> value_map_;
MapType<Block *> block_map_;
MapType<Operation *> operation_map_;
std::unordered_map<Value, Value> value_map_;
std::unordered_map<Block*, Block*> block_map_;
std::unordered_map<Operation*, Operation*> operation_map_;
};

} // namespace pir
3 changes: 1 addition & 2 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {

// record outputs mapping info
for (uint32_t i = 0; i < num_results_; ++i) {
ir_mapping.Add(static_cast<Value>(result(i)),
static_cast<Value>(new_op->result(i)));
ir_mapping.Add(result(i), new_op->result(i));
}

if (options.IsCloneRegions()) {
Expand Down

0 comments on commit 6a6d6a7

Please sign in to comment.