Skip to content

Commit

Permalink
[PIR] change IR clone to const and support clone operation successors (
Browse files Browse the repository at this point in the history
…#60752)

* support ir clone const and support clone operation successors

* refine ir_mapping

* refine region clone
  • Loading branch information
huangjiyi authored Jan 12, 2024
1 parent f968050 commit d604bcd
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 38 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/group.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct Group {
std::vector<::pir::Operation*> new_ops;
// Mapper from original to new ops.
std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper;
::pir::CloneOptions clone_options(false, true);
auto clone_options = ::pir::CloneOptions(false, true, false);
for (auto* op : ops) {
VLOG(4) << "clone op :" << op->name();
auto* new_op = op->Clone(ir_mapping, clone_options);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
using OpResultMap =
std::pair<std::vector<pir::OpResult>, std::vector<pir::OpResult>>;
std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
Program &program) { // NOLINT
const Program &program) {
// Limitation of this function:
// 1. don't support Parameters.
pir::IrMapping mapper;
Expand Down Expand Up @@ -1111,7 +1111,7 @@ SplitedResult SplitForwardBackward(
// forward program construct.
VLOG(4) << "start create forward program.";
pir::IrMapping forward_mapper;
auto clone_options = pir::CloneOptions(true, true);
auto clone_options = pir::CloneOptions::All();
range_block_do(
program.block(),
forward_range,
Expand Down
27 changes: 19 additions & 8 deletions paddle/pir/core/ir_mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,22 @@ struct ExactlyOneIrType<T, FirstT, OthersT...> {
typename ExactlyOneIrType<T, OthersT...>::type>;
};
} // namespace detail

class IrMapping {
public:
template <typename T>
using IrType =
typename detail::ExactlyOneIrType<T, Value, Block*, Operation*>::type;
using remove_lowlevel_const_t = std::conditional_t<
std::is_pointer<T>::value,
std::add_pointer_t<std::remove_const_t<std::remove_pointer_t<T>>>,
T>;
template <typename T>
using IrType = typename detail::ExactlyOneIrType<remove_lowlevel_const_t<T>,
Value,
Block*,
Operation*>::type;

template <typename T>
std::unordered_map<T, T>& GetMutableMap() {
auto& GetMutableMap() {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
Expand All @@ -50,8 +59,9 @@ class IrMapping {
IR_THROW("Not support type in IRMapping.");
}
}

template <typename T>
const std::unordered_map<T, T>& GetMap() const {
const auto& GetMap() const {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
Expand All @@ -62,15 +72,16 @@ class IrMapping {
IR_THROW("Not support type in IRMapping.");
}
}

template <typename T, typename S>
void Add(T from, S to) {
if (!from) return;
GetMutableMap<IrType<T>>()[from] = to;
}

template <typename T>
T Lookup(T from) const {
if (!from) return static_cast<T>(nullptr);
IrType<T> Lookup(T from) const {
if (!from) return static_cast<IrType<T>>(nullptr);
IR_ENFORCE(GetMap<IrType<T>>().count(from) > 0,
"Not found key in IRMapping.");
return GetMap<IrType<T>>().at(from);
Expand All @@ -89,8 +100,8 @@ class IrMapping {

private:
std::unordered_map<Value, Value> value_map_;
std::unordered_map<Block*, Block*> block_map_;
std::unordered_map<Operation*, Operation*> operation_map_;
std::unordered_map<const Block*, Block*> block_map_;
std::unordered_map<const Operation*, Operation*> operation_map_;
};

} // namespace pir
17 changes: 11 additions & 6 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,7 @@ Operation *Operation::Create(const std::vector<Value> &inputs,
return op;
}

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
IR_ENFORCE(num_successors_ == 0,
"Operation::Clone is not unimplemented for multiple successors.");

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) const {
auto inputs = operands_source();
if (options.IsCloneOperands()) {
// replace value by IRMapping inplacely.
Expand All @@ -153,7 +150,15 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
for (auto &result : results()) {
output_types.push_back(result.type());
}
auto *new_op = Create(inputs, attributes_, output_types, info_, num_regions_);

std::vector<Block *> successors = {};
if (options.IsCloneSuccessors()) {
for (uint32_t i = 0; i < num_successors_; ++i) {
successors.push_back(ir_mapping.Lookup(successor(i)));
}
}
auto *new_op = Create(
inputs, attributes_, output_types, info_, num_regions_, successors);
ir_mapping.Add(this, new_op);

// record outputs mapping info
Expand Down Expand Up @@ -246,7 +251,7 @@ Operation::Operation(const AttributeMap &attributes,
///
/// \brief op ouput related public interfaces implementation
///
std::vector<OpResult> Operation::results() {
std::vector<OpResult> Operation::results() const {
std::vector<OpResult> res;
for (uint32_t i = 0; i < num_results(); ++i) {
res.push_back(result(i));
Expand Down
22 changes: 17 additions & 5 deletions paddle/pir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,28 @@ class OpOperendImpl;

class CloneOptions {
public:
CloneOptions() : clone_regions_{false}, clone_operands_{false} {}
CloneOptions(bool clone_regions, bool clone_operands)
: clone_regions_(clone_regions), clone_operands_(clone_operands) {}
CloneOptions()
: clone_regions_{false},
clone_operands_{false},
clone_successors_{false} {}
CloneOptions(bool clone_regions, bool clone_operands, bool clone_successors)
: clone_regions_(clone_regions),
clone_operands_(clone_operands),
clone_successors_(clone_successors) {}

bool IsCloneRegions() const { return clone_regions_; }
bool IsCloneOperands() const { return clone_operands_; }
bool IsCloneSuccessors() const { return clone_successors_; }

static CloneOptions &All() {
static CloneOptions all{true, true, true};
return all;
}

private:
bool clone_regions_{true};
bool clone_operands_{true};
bool clone_successors_{true};
};

class IR_API alignas(8) Operation final
Expand All @@ -72,7 +84,7 @@ class IR_API alignas(8) Operation final
/// \brief Deep copy all information and create a new operation.
///
Operation *Clone(IrMapping &ir_mapping,
CloneOptions options = CloneOptions());
CloneOptions options = CloneOptions()) const;
///
/// \brief Destroy the operation objects and free memory by create().
///
Expand Down Expand Up @@ -114,7 +126,7 @@ class IR_API alignas(8) Operation final
T result_type(uint32_t index) const {
return result(index).type().dyn_cast<T>();
}
std::vector<OpResult> results();
std::vector<OpResult> results() const;

///
/// \brief op input related public interfaces
Expand Down
6 changes: 3 additions & 3 deletions paddle/pir/core/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ Program::~Program() {
}
}

std::shared_ptr<Program> Program::Clone(IrMapping& ir_mapping) {
std::shared_ptr<Program> Program::Clone(IrMapping& ir_mapping) const {
pir::IrContext* ctx = pir::IrContext::Instance();
auto new_program = std::make_shared<Program>(ctx);
auto clone_options = CloneOptions(true, true);
for (auto& op : *block()) {
auto clone_options = CloneOptions::All();
for (const auto& op : *block()) {
auto* new_op = op.Clone(ir_mapping, clone_options);
new_program->block()->push_back(new_op);
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class IR_API Program {

static std::unique_ptr<Program> Parse(std::istream& is, IrContext* ctx);

std::shared_ptr<Program> Clone(IrMapping& ir_mapping); // NOLINT
std::shared_ptr<Program> Clone(IrMapping& ir_mapping) const; // NOLINT

Block* block() { return &module_.block(); }
const Block* block() const { return &module_op().block(); }
Expand Down
23 changes: 15 additions & 8 deletions paddle/pir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,30 @@ Region::Iterator Region::erase(ConstIterator position) {
return blocks_.erase(position);
}

void Region::CloneInto(Region &other, IrMapping &ir_mapping) {
void Region::CloneInto(Region &other, IrMapping &ir_mapping) const {
if (empty()) {
return;
}
other.clear();
auto clone_options = CloneOptions(false, false);
// clone blocks, block arguments and sub operations
for (auto &block : *this) {
for (const auto &block : *this) {
auto new_block = new Block;
ir_mapping.Add(&block, new_block);
for (auto &arg : block.args()) {
for (const auto &arg : block.args()) {
ir_mapping.Add(arg, new_block->AddArgument(arg.type()));
}
other.push_back(new_block);
// clone sub operations, but not map operands nor clone regions
for (auto op_iter = block.begin(); op_iter != block.end(); ++op_iter) {
new_block->push_back(op_iter->Clone(ir_mapping, clone_options));
}
// clone sub operations, but not map operands nor clone regions
{
auto clone_options = CloneOptions(false, false, true);
auto iter = begin();
auto new_iter = other.begin();
for (; iter != end(); ++iter, ++new_iter) {
const Block &block = *iter;
Block &new_block = *new_iter;
for (const auto &op : block)
new_block.push_back(op.Clone(ir_mapping, clone_options));
}
}
// after all operation results are mapped, map operands and clone regions.
Expand All @@ -69,7 +76,7 @@ void Region::CloneInto(Region &other, IrMapping &ir_mapping) {
auto op_iter = iter->begin();
auto new_op_iter = new_iter->begin();
for (; op_iter != iter->end(); ++op_iter, ++new_op_iter) {
Operation &op = *op_iter;
const Operation &op = *op_iter;
Operation &new_op = *new_op_iter;
// operands of new_op are same as op, now map them.
for (uint32_t i = 0; i < op.num_operands(); ++i)
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class IR_API Region {
void Walk(FuncT &&callback);

// clone this region into another region, target region will be overwritten.
void CloneInto(Region &other, IrMapping &ir_mapping); // NOLINT
void CloneInto(Region &other, IrMapping &ir_mapping) const; // NOLINT

// take the last block of region.
// if region is empty, return nullptr;
Expand Down
5 changes: 2 additions & 3 deletions test/cpp/pir/core/ir_region_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ TEST(region, clone_op_test) {
pir::Operation &op = *program.module_op();
pir::Block &block = op.region(0).front();
pir::IrMapping mapper;
pir::Operation &new_op = *op.Clone(mapper, pir::CloneOptions(true, true));
pir::Operation &new_op = *op.Clone(mapper, pir::CloneOptions::All());

// (8) Check the cloned op recursively
EXPECT_EQ(mapper.Lookup(&op), &new_op);
Expand All @@ -103,8 +103,7 @@ TEST(region, clone_op_test) {
}
EXPECT_EQ(op.num_results(), new_op.num_results());
for (uint32_t i = 0; i < op.num_results(); ++i) {
EXPECT_EQ(mapper.Lookup(static_cast<pir::Value>(op.result(i))),
static_cast<pir::Value>(new_op.result(i)));
EXPECT_EQ(mapper.Lookup(op.result(i)), new_op.result(i));
}
EXPECT_TRUE(std::equal(op.attributes().begin(),
op.attributes().end(),
Expand Down

0 comments on commit d604bcd

Please sign in to comment.