Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] change IR clone to const and support clone operation successors #60752

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/pir/group.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct Group {
std::vector<::pir::Operation*> new_ops;
// Mapper from original to new ops.
std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper;
::pir::CloneOptions clone_options(false, true);
auto clone_options = ::pir::CloneOptions(false, true, false);
for (auto* op : ops) {
VLOG(4) << "clone op :" << op->name();
auto* new_op = op->Clone(ir_mapping, clone_options);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
using OpResultMap =
std::pair<std::vector<pir::OpResult>, std::vector<pir::OpResult>>;
std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
Program &program) { // NOLINT
const Program &program) {
// Limitation of this function:
// 1. don't support Parameters.
pir::IrMapping mapper;
Expand Down Expand Up @@ -1111,7 +1111,7 @@ SplitedResult SplitForwardBackward(
// forward program construct.
VLOG(4) << "start create forward program.";
pir::IrMapping forward_mapper;
auto clone_options = pir::CloneOptions(true, true);
auto clone_options = pir::CloneOptions::All();
range_block_do(
program.block(),
forward_range,
Expand Down
41 changes: 32 additions & 9 deletions paddle/pir/core/ir_mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,46 @@ class Block;
class Operation;

namespace detail {
template <typename T, typename... OthersT>
struct ExactlyOneConstIrType {
using type = void;
};
template <typename T, typename FirstT, typename... OthersT>
struct ExactlyOneConstIrType<T, FirstT, OthersT...> {
using type =
std::conditional_t<std::is_convertible<T, FirstT>::value,
FirstT,
typename ExactlyOneConstIrType<T, OthersT...>::type>;
};
template <typename T>
using ConstIrType = typename detail::
ExactlyOneConstIrType<T, Value, const Block*, const Operation*>::type;

template <typename T, typename... OthersT>
struct ExactlyOneIrType {
using type = void;
};
template <typename T, typename FirstT, typename... OthersT>
struct ExactlyOneIrType<T, FirstT, OthersT...> {
using type =
std::conditional_t<std::is_convertible<T, FirstT>::value,
std::conditional_t<std::is_convertible<FirstT, ConstIrType<T>>::value,
FirstT,
typename ExactlyOneIrType<T, OthersT...>::type>;
};
template <typename T>
using IrType =
typename detail::ExactlyOneIrType<T, Value, Block*, Operation*>::type;
} // namespace detail

class IrMapping {
public:
template <typename T>
using IrType =
typename detail::ExactlyOneIrType<T, Value, Block*, Operation*>::type;
using ConstIrType = detail::ConstIrType<T>;
template <typename T>
std::unordered_map<T, T>& GetMutableMap() {
using IrType = detail::IrType<T>;

template <typename T>
std::unordered_map<ConstIrType<T>, T>& GetMutableMap() {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
Expand All @@ -50,8 +71,9 @@ class IrMapping {
IR_THROW("Not support type in IRMapping.");
}
}

template <typename T>
const std::unordered_map<T, T>& GetMap() const {
const std::unordered_map<ConstIrType<T>, T>& GetMap() const {
if constexpr (std::is_same<T, Value>::value) {
return value_map_;
} else if constexpr (std::is_same<T, Block*>::value) {
Expand All @@ -62,15 +84,16 @@ class IrMapping {
IR_THROW("Not support type in IRMapping.");
}
}

template <typename T, typename S>
void Add(T from, S to) {
if (!from) return;
GetMutableMap<IrType<T>>()[from] = to;
}

template <typename T>
T Lookup(T from) const {
if (!from) return static_cast<T>(nullptr);
IrType<T> Lookup(T from) const {
huangjiyi marked this conversation as resolved.
Show resolved Hide resolved
if (!from) return static_cast<IrType<T>>(nullptr);
IR_ENFORCE(GetMap<IrType<T>>().count(from) > 0,
"Not found key in IRMapping.");
return GetMap<IrType<T>>().at(from);
Expand All @@ -89,8 +112,8 @@ class IrMapping {

private:
std::unordered_map<Value, Value> value_map_;
std::unordered_map<Block*, Block*> block_map_;
std::unordered_map<Operation*, Operation*> operation_map_;
std::unordered_map<const Block*, Block*> block_map_;
std::unordered_map<const Operation*, Operation*> operation_map_;
};

} // namespace pir
17 changes: 11 additions & 6 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,7 @@ Operation *Operation::Create(const std::vector<Value> &inputs,
return op;
}

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
IR_ENFORCE(num_successors_ == 0,
"Operation::Clone is not unimplemented for multiple successors.");

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) const {
auto inputs = operands_source();
if (options.IsCloneOperands()) {
// replace value by IRMapping inplacely.
Expand All @@ -153,7 +150,15 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
for (auto &result : results()) {
output_types.push_back(result.type());
}
auto *new_op = Create(inputs, attributes_, output_types, info_, num_regions_);

std::vector<Block *> successors = {};
if (options.IsCloneSuccessors()) {
for (uint32_t i = 0; i < num_successors_; ++i) {
successors.push_back(ir_mapping.Lookup(successor(i)));
}
}
auto *new_op = Create(
inputs, attributes_, output_types, info_, num_regions_, successors);
ir_mapping.Add(this, new_op);

// record outputs mapping info
Expand Down Expand Up @@ -246,7 +251,7 @@ Operation::Operation(const AttributeMap &attributes,
///
/// \brief op ouput related public interfaces implementation
///
std::vector<OpResult> Operation::results() {
std::vector<OpResult> Operation::results() const {
std::vector<OpResult> res;
for (uint32_t i = 0; i < num_results(); ++i) {
res.push_back(result(i));
Expand Down
22 changes: 17 additions & 5 deletions paddle/pir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,28 @@ class OpOperendImpl;

class CloneOptions {
public:
CloneOptions() : clone_regions_{false}, clone_operands_{false} {}
CloneOptions(bool clone_regions, bool clone_operands)
: clone_regions_(clone_regions), clone_operands_(clone_operands) {}
CloneOptions()
: clone_regions_{false},
clone_operands_{false},
clone_successors_{false} {}
CloneOptions(bool clone_regions, bool clone_operands, bool clone_successors)
: clone_regions_(clone_regions),
clone_operands_(clone_operands),
clone_successors_(clone_successors) {}

bool IsCloneRegions() const { return clone_regions_; }
bool IsCloneOperands() const { return clone_operands_; }
bool IsCloneSuccessors() const { return clone_successors_; }

static CloneOptions &All() {
static CloneOptions all{true, true, true};
return all;
}

private:
bool clone_regions_{true};
bool clone_operands_{true};
bool clone_successors_{true};
};

class IR_API alignas(8) Operation final
Expand All @@ -72,7 +84,7 @@ class IR_API alignas(8) Operation final
/// \brief Deep copy all information and create a new operation.
///
Operation *Clone(IrMapping &ir_mapping,
CloneOptions options = CloneOptions());
CloneOptions options = CloneOptions()) const;
///
/// \brief Destroy the operation objects and free memory by create().
///
Expand Down Expand Up @@ -114,7 +126,7 @@ class IR_API alignas(8) Operation final
T result_type(uint32_t index) const {
return result(index).type().dyn_cast<T>();
}
std::vector<OpResult> results();
std::vector<OpResult> results() const;

///
/// \brief op input related public interfaces
Expand Down
6 changes: 3 additions & 3 deletions paddle/pir/core/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ Program::~Program() {
}
}

std::shared_ptr<Program> Program::Clone(IrMapping& ir_mapping) {
std::shared_ptr<Program> Program::Clone(IrMapping& ir_mapping) const {
pir::IrContext* ctx = pir::IrContext::Instance();
auto new_program = std::make_shared<Program>(ctx);
auto clone_options = CloneOptions(true, true);
for (auto& op : *block()) {
auto clone_options = CloneOptions::All();
for (const auto& op : *block()) {
auto* new_op = op.Clone(ir_mapping, clone_options);
new_program->block()->push_back(new_op);
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class IR_API Program {

static std::unique_ptr<Program> Parse(std::istream& is, IrContext* ctx);

std::shared_ptr<Program> Clone(IrMapping& ir_mapping); // NOLINT
std::shared_ptr<Program> Clone(IrMapping& ir_mapping) const; // NOLINT

Block* block() { return &module_.block(); }
const Block* block() const { return &module_op().block(); }
Expand Down
10 changes: 5 additions & 5 deletions paddle/pir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@ Region::Iterator Region::erase(ConstIterator position) {
return blocks_.erase(position);
}

void Region::CloneInto(Region &other, IrMapping &ir_mapping) {
void Region::CloneInto(Region &other, IrMapping &ir_mapping) const {
if (empty()) {
return;
}
other.clear();
auto clone_options = CloneOptions(false, false);
auto clone_options = CloneOptions();
// clone blocks, block arguments and sub operations
for (auto &block : *this) {
for (const auto &block : *this) {
auto new_block = new Block;
ir_mapping.Add(&block, new_block);
for (auto &arg : block.args()) {
for (const auto &arg : block.args()) {
ir_mapping.Add(arg, new_block->AddArgument(arg.type()));
}
other.push_back(new_block);
Expand All @@ -69,7 +69,7 @@ void Region::CloneInto(Region &other, IrMapping &ir_mapping) {
auto op_iter = iter->begin();
auto new_op_iter = new_iter->begin();
for (; op_iter != iter->end(); ++op_iter, ++new_op_iter) {
Operation &op = *op_iter;
const Operation &op = *op_iter;
Operation &new_op = *new_op_iter;
// operands of new_op are same as op, now map them.
for (uint32_t i = 0; i < op.num_operands(); ++i)
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class IR_API Region {
void Walk(FuncT &&callback);

// clone this region into another region, target region will be overwritten.
void CloneInto(Region &other, IrMapping &ir_mapping); // NOLINT
void CloneInto(Region &other, IrMapping &ir_mapping) const; // NOLINT

// take the last block of region.
// if region is empty, return nullptr;
Expand Down
5 changes: 2 additions & 3 deletions test/cpp/pir/core/ir_region_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ TEST(region, clone_op_test) {
pir::Operation &op = *program.module_op();
pir::Block &block = op.region(0).front();
pir::IrMapping mapper;
pir::Operation &new_op = *op.Clone(mapper, pir::CloneOptions(true, true));
pir::Operation &new_op = *op.Clone(mapper, pir::CloneOptions::All());

// (8) Check the cloned op recursively
EXPECT_EQ(mapper.Lookup(&op), &new_op);
Expand All @@ -103,8 +103,7 @@ TEST(region, clone_op_test) {
}
EXPECT_EQ(op.num_results(), new_op.num_results());
for (uint32_t i = 0; i < op.num_results(); ++i) {
EXPECT_EQ(mapper.Lookup(static_cast<pir::Value>(op.result(i))),
static_cast<pir::Value>(new_op.result(i)));
EXPECT_EQ(mapper.Lookup(op.result(i)), new_op.result(i));
}
EXPECT_TRUE(std::equal(op.attributes().begin(),
op.attributes().end(),
Expand Down