Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] change IR clone to const and support clone operation successors #60752

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/group.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct Group {
std::vector<::pir::Operation*> new_ops;
// Mapper from original to new ops.
std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper;
::pir::CloneOptions clone_options(false, true);
auto clone_options = ::pir::CloneOptions(false, true, false);
for (auto* op : ops) {
VLOG(4) << "clone op :" << op->name();
auto* new_op = op->Clone(ir_mapping, clone_options);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
using OpResultMap =
std::pair<std::vector<pir::OpResult>, std::vector<pir::OpResult>>;
std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
Program &program) { // NOLINT
const Program &program) {
// Limitation of this function:
// 1. don't support Parameters.
pir::IrMapping mapper;
Expand Down Expand Up @@ -1111,7 +1111,7 @@ SplitedResult SplitForwardBackward(
// forward program construct.
VLOG(4) << "start create forward program.";
pir::IrMapping forward_mapper;
auto clone_options = pir::CloneOptions(true, true);
auto clone_options = pir::CloneOptions::All();
range_block_do(
program.block(),
forward_range,
Expand Down
27 changes: 19 additions & 8 deletions paddle/pir/core/ir_mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,22 @@ struct ExactlyOneIrType<T, FirstT, OthersT...> {
typename ExactlyOneIrType<T, OthersT...>::type>;
};
} // namespace detail

class IrMapping {
public:
template <typename T>
using IrType =
typename detail::ExactlyOneIrType<T, Value, Block*, Operation*>::type;
using remove_lowlevel_const_t = std::conditional_t<
std::is_pointer<T>::value,
std::add_pointer_t<std::remove_const_t<std::remove_pointer_t<T>>>,
T>;
template <typename T>
using IrType = typename detail::ExactlyOneIrType<remove_lowlevel_const_t<T>,
Value,
Block*,
Operation*>::type;

template <typename T>
std::unordered_map<T, T>& GetMutableMap() {
auto& GetMutableMap() {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
Expand All @@ -50,8 +59,9 @@ class IrMapping {
IR_THROW("Not support type in IRMapping.");
}
}

template <typename T>
const std::unordered_map<T, T>& GetMap() const {
const auto& GetMap() const {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
Expand All @@ -62,15 +72,16 @@ class IrMapping {
IR_THROW("Not support type in IRMapping.");
}
}

template <typename T, typename S>
void Add(T from, S to) {
if (!from) return;
GetMutableMap<IrType<T>>()[from] = to;
}

template <typename T>
T Lookup(T from) const {
if (!from) return static_cast<T>(nullptr);
IrType<T> Lookup(T from) const {
huangjiyi marked this conversation as resolved.
Show resolved Hide resolved
if (!from) return static_cast<IrType<T>>(nullptr);
IR_ENFORCE(GetMap<IrType<T>>().count(from) > 0,
"Not found key in IRMapping.");
return GetMap<IrType<T>>().at(from);
Expand All @@ -89,8 +100,8 @@ class IrMapping {

private:
std::unordered_map<Value, Value> value_map_;
std::unordered_map<Block*, Block*> block_map_;
std::unordered_map<Operation*, Operation*> operation_map_;
std::unordered_map<const Block*, Block*> block_map_;
std::unordered_map<const Operation*, Operation*> operation_map_;
};

} // namespace pir
17 changes: 11 additions & 6 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,7 @@ Operation *Operation::Create(const std::vector<Value> &inputs,
return op;
}

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
IR_ENFORCE(num_successors_ == 0,
"Operation::Clone is not unimplemented for multiple successors.");

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) const {
auto inputs = operands_source();
if (options.IsCloneOperands()) {
// replace value by IRMapping inplacely.
Expand All @@ -153,7 +150,15 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
for (auto &result : results()) {
output_types.push_back(result.type());
}
auto *new_op = Create(inputs, attributes_, output_types, info_, num_regions_);

std::vector<Block *> successors = {};
if (options.IsCloneSuccessors()) {
for (uint32_t i = 0; i < num_successors_; ++i) {
successors.push_back(ir_mapping.Lookup(successor(i)));
}
}
auto *new_op = Create(
inputs, attributes_, output_types, info_, num_regions_, successors);
ir_mapping.Add(this, new_op);

// record outputs mapping info
Expand Down Expand Up @@ -246,7 +251,7 @@ Operation::Operation(const AttributeMap &attributes,
///
/// \brief op ouput related public interfaces implementation
///
std::vector<OpResult> Operation::results() {
std::vector<OpResult> Operation::results() const {
std::vector<OpResult> res;
for (uint32_t i = 0; i < num_results(); ++i) {
res.push_back(result(i));
Expand Down
22 changes: 17 additions & 5 deletions paddle/pir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,28 @@ class OpOperendImpl;

class CloneOptions {
public:
CloneOptions() : clone_regions_{false}, clone_operands_{false} {}
CloneOptions(bool clone_regions, bool clone_operands)
: clone_regions_(clone_regions), clone_operands_(clone_operands) {}
CloneOptions()
: clone_regions_{false},
clone_operands_{false},
clone_successors_{false} {}
CloneOptions(bool clone_regions, bool clone_operands, bool clone_successors)
: clone_regions_(clone_regions),
clone_operands_(clone_operands),
clone_successors_(clone_successors) {}

bool IsCloneRegions() const { return clone_regions_; }
bool IsCloneOperands() const { return clone_operands_; }
bool IsCloneSuccessors() const { return clone_successors_; }

static CloneOptions &All() {
static CloneOptions all{true, true, true};
return all;
}

private:
bool clone_regions_{true};
bool clone_operands_{true};
bool clone_successors_{true};
};

class IR_API alignas(8) Operation final
Expand All @@ -72,7 +84,7 @@ class IR_API alignas(8) Operation final
/// \brief Deep copy all information and create a new operation.
///
Operation *Clone(IrMapping &ir_mapping,
CloneOptions options = CloneOptions());
CloneOptions options = CloneOptions()) const;
///
/// \brief Destroy the operation objects and free memory by create().
///
Expand Down Expand Up @@ -114,7 +126,7 @@ class IR_API alignas(8) Operation final
T result_type(uint32_t index) const {
return result(index).type().dyn_cast<T>();
}
std::vector<OpResult> results();
std::vector<OpResult> results() const;

///
/// \brief op input related public interfaces
Expand Down
6 changes: 3 additions & 3 deletions paddle/pir/core/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ Program::~Program() {
}
}

std::shared_ptr<Program> Program::Clone(IrMapping& ir_mapping) {
std::shared_ptr<Program> Program::Clone(IrMapping& ir_mapping) const {
pir::IrContext* ctx = pir::IrContext::Instance();
auto new_program = std::make_shared<Program>(ctx);
auto clone_options = CloneOptions(true, true);
for (auto& op : *block()) {
auto clone_options = CloneOptions::All();
for (const auto& op : *block()) {
auto* new_op = op.Clone(ir_mapping, clone_options);
new_program->block()->push_back(new_op);
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class IR_API Program {

static std::unique_ptr<Program> Parse(std::istream& is, IrContext* ctx);

std::shared_ptr<Program> Clone(IrMapping& ir_mapping); // NOLINT
std::shared_ptr<Program> Clone(IrMapping& ir_mapping) const; // NOLINT

Block* block() { return &module_.block(); }
const Block* block() const { return &module_op().block(); }
Expand Down
23 changes: 15 additions & 8 deletions paddle/pir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,30 @@ Region::Iterator Region::erase(ConstIterator position) {
return blocks_.erase(position);
}

void Region::CloneInto(Region &other, IrMapping &ir_mapping) {
void Region::CloneInto(Region &other, IrMapping &ir_mapping) const {
if (empty()) {
return;
}
other.clear();
auto clone_options = CloneOptions(false, false);
// clone blocks, block arguments and sub operations
for (auto &block : *this) {
for (const auto &block : *this) {
auto new_block = new Block;
ir_mapping.Add(&block, new_block);
for (auto &arg : block.args()) {
for (const auto &arg : block.args()) {
ir_mapping.Add(arg, new_block->AddArgument(arg.type()));
}
other.push_back(new_block);
// clone sub operations, but not map operands nor clone regions
for (auto op_iter = block.begin(); op_iter != block.end(); ++op_iter) {
new_block->push_back(op_iter->Clone(ir_mapping, clone_options));
}
// clone sub operations, but not map operands nor clone regions
{
auto clone_options = CloneOptions(false, false, true);
auto iter = begin();
auto new_iter = other.begin();
for (; iter != end(); ++iter, ++new_iter) {
const Block &block = *iter;
Block &new_block = *new_iter;
for (const auto &op : block)
new_block.push_back(op.Clone(ir_mapping, clone_options));
}
}
// after all operation results are mapped, map operands and clone regions.
Expand All @@ -69,7 +76,7 @@ void Region::CloneInto(Region &other, IrMapping &ir_mapping) {
auto op_iter = iter->begin();
auto new_op_iter = new_iter->begin();
for (; op_iter != iter->end(); ++op_iter, ++new_op_iter) {
Operation &op = *op_iter;
const Operation &op = *op_iter;
Operation &new_op = *new_op_iter;
// operands of new_op are same as op, now map them.
for (uint32_t i = 0; i < op.num_operands(); ++i)
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class IR_API Region {
void Walk(FuncT &&callback);

// clone this region into another region, target region will be overwritten.
void CloneInto(Region &other, IrMapping &ir_mapping); // NOLINT
void CloneInto(Region &other, IrMapping &ir_mapping) const; // NOLINT

// take the last block of region.
// if region is empty, return nullptr;
Expand Down
5 changes: 2 additions & 3 deletions test/cpp/pir/core/ir_region_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ TEST(region, clone_op_test) {
pir::Operation &op = *program.module_op();
pir::Block &block = op.region(0).front();
pir::IrMapping mapper;
pir::Operation &new_op = *op.Clone(mapper, pir::CloneOptions(true, true));
pir::Operation &new_op = *op.Clone(mapper, pir::CloneOptions::All());

// (8) Check the cloned op recursively
EXPECT_EQ(mapper.Lookup(&op), &new_op);
Expand All @@ -103,8 +103,7 @@ TEST(region, clone_op_test) {
}
EXPECT_EQ(op.num_results(), new_op.num_results());
for (uint32_t i = 0; i < op.num_results(); ++i) {
EXPECT_EQ(mapper.Lookup(static_cast<pir::Value>(op.result(i))),
static_cast<pir::Value>(new_op.result(i)));
EXPECT_EQ(mapper.Lookup(op.result(i)), new_op.result(i));
}
EXPECT_TRUE(std::equal(op.attributes().begin(),
op.attributes().end(),
Expand Down