Skip to content

Commit

Permalink
[PIR]add block arguement. (PaddlePaddle#57249)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Sep 15, 2023
1 parent 5a64329 commit 362a9c9
Show file tree
Hide file tree
Showing 19 changed files with 349 additions and 85 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
ctx_, defining_op_result, parameter_name_mappings_[var_name]);

pir::Block* block = program_->block();
pir::Block::iterator insert_pos = std::find(
pir::Block::Iterator insert_pos = std::find(
block->begin(), block->end(), defining_op_result.owner());

IR_ENFORCE(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ class InplacePass : public pir::Pass {
.at("op_name")
.dyn_cast<pir::StrAttribute>()
.AsString();
pir::Block::iterator insert_pos =
pir::Block::Iterator insert_pos =
std::find(block->begin(), block->end(), kv.first);
IR_ENFORCE(insert_pos != block->end(),
"Operator %s not found in block.",
Expand Down
4 changes: 1 addition & 3 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,7 @@ void BindOpResult(py::module *m) {
return paddle::dialect::greater_equal(self, other);
})
.def("__hash__",
[](OpResult &self) {
return std::hash<pir::Value>{}(self.dyn_cast<pir::Value>());
})
[](OpResult &self) { return std::hash<pir::Value>{}(self); })
.def("get_defining_op",
&OpResult::GetDefiningOp,
return_value_policy::reference)
Expand Down
21 changes: 17 additions & 4 deletions paddle/pir/core/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@

namespace pir {
Block::~Block() {
assert(use_empty() && "block destroyed still has uses.");
if (!use_empty()) {
LOG(FATAL) << "Destoryed a block that is still in use.";
}
clear();
ClearArguments();
}
void Block::push_back(Operation *op) { insert(ops_.end(), op); }

Expand All @@ -33,13 +36,13 @@ Operation *Block::GetParentOp() const {
return parent_ ? parent_->GetParent() : nullptr;
}

Block::iterator Block::insert(const_iterator iterator, Operation *op) {
Block::iterator iter = ops_.insert(iterator, op);
Block::Iterator Block::insert(ConstIterator iterator, Operation *op) {
Block::Iterator iter = ops_.insert(iterator, op);
op->SetParent(this, iter);
return iter;
}

Block::iterator Block::erase(const_iterator position) {
Block::Iterator Block::erase(ConstIterator position) {
IR_ENFORCE((*position)->GetParent() == this, "iterator not own this block.");
(*position)->Destroy();
return ops_.erase(position);
Expand Down Expand Up @@ -75,6 +78,16 @@ void Block::ResetOpListOrder(const OpListType &new_op_list) {
}
}

void Block::ClearArguments() {
for (auto &argument : arguments_) {
argument.Destroy();
}
arguments_.clear();
}
void Block::AddArgument(Type type) {
arguments_.emplace_back(BlockArgument::Create(type, this, arguments_.size()));
}

bool Block::TopoOrderCheck(const OpListType &op_list) {
std::unordered_set<Value> visited_values;
for (const Operation *op : op_list) {
Expand Down
59 changes: 46 additions & 13 deletions paddle/pir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cstddef>
#include <list>

#include "paddle/pir/core/block_argument.h"
#include "paddle/pir/core/block_operand.h"
#include "paddle/pir/core/dll_decl.h"
#include "paddle/pir/core/region.h"
Expand All @@ -29,9 +30,9 @@ class IR_API Block {
using OpListType = std::list<Operation *>;

public:
using iterator = OpListType::iterator;
using reverse_iterator = OpListType::reverse_iterator;
using const_iterator = OpListType::const_iterator;
using Iterator = OpListType::iterator;
using ReverseIterator = OpListType::reverse_iterator;
using ConstIterator = OpListType::const_iterator;

Block() = default;
~Block();
Expand All @@ -42,19 +43,19 @@ class IR_API Block {
bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); }

const_iterator begin() const { return ops_.begin(); }
const_iterator end() const { return ops_.end(); }
iterator begin() { return ops_.begin(); }
iterator end() { return ops_.end(); }
reverse_iterator rbegin() { return ops_.rbegin(); }
reverse_iterator rend() { return ops_.rend(); }
ConstIterator begin() const { return ops_.begin(); }
ConstIterator end() const { return ops_.end(); }
Iterator begin() { return ops_.begin(); }
Iterator end() { return ops_.end(); }
ReverseIterator rbegin() { return ops_.rbegin(); }
ReverseIterator rend() { return ops_.rend(); }

Operation *back() const { return ops_.back(); }
Operation *front() const { return ops_.front(); }
void push_back(Operation *op);
void push_front(Operation *op);
iterator insert(const_iterator iterator, Operation *op);
iterator erase(const_iterator position);
Iterator insert(ConstIterator iterator, Operation *op);
Iterator erase(ConstIterator position);
void clear();
operator Region::iterator() { return position_; }

Expand All @@ -73,6 +74,29 @@ class IR_API Block {
// This is a unsafe funcion, please use it carefully.
void ResetOpListOrder(const OpListType &new_op_list);

///
/// \brief Block argument management
///
using BlockArgListType = std::vector<BlockArgument>;
using ArgsIterator = BlockArgListType::iterator;

ArgsIterator args_begin() { return arguments_.begin(); }
ArgsIterator args_end() { return arguments_.end(); }
bool args_empty() const { return arguments_.empty(); }
uint32_t args_size() const { return arguments_.size(); }
BlockArgument argument(uint32_t index) { return arguments_[index]; }
Type argument_type(uint32_t index) const { return arguments_[index].type(); }

void ClearArguments();
void AddArgument(Type type);
template <class TypeIter>
void AddArguments(TypeIter first, TypeIter last);

template <class TypeContainer>
void AddArguments(const TypeContainer &container) {
AddArguments(container.begin(), container.end());
}

private:
Block(Block &) = delete;
Block &operator=(const Block &) = delete;
Expand All @@ -84,9 +108,18 @@ class IR_API Block {
static bool TopoOrderCheck(const OpListType &op_list);

private:
Region *parent_; // not owned
OpListType ops_; // owned
Region::iterator position_;
BlockOperand first_use_;
OpListType ops_; // owned
BlockArgListType arguments_; // owned
Region *parent_; // not owned
};

template <class TypeIter>
void Block::AddArguments(TypeIter first, TypeIter last) {
while (first != last) {
AddArgument(*first++);
}
}

} // namespace pir
95 changes: 95 additions & 0 deletions paddle/pir/core/block_argument.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pir/core/block_argument.h"
#include "paddle/pir/core/enforce.h"
#include "paddle/pir/core/value_impl.h"

#define CHECK_NULL_IMPL(func_name) \
IR_ENFORCE(impl_, "impl_ is null when called BlockArgument:" #func_name)

#define IMPL_ static_cast<detail::BlockArgumentImpl *>(impl_)

namespace pir {

namespace detail {
///
/// \brief BlockArgumentImpl is the implementation of an block argument.
///
class BlockArgumentImpl : public ValueImpl {
public:
static bool classof(const ValueImpl &value) {
return value.kind() == BLOCK_ARGUMENT_INDEX;
}

private:
BlockArgumentImpl(Type type, Block *owner, uint32_t index)
: ValueImpl(type, BLOCK_ARGUMENT_INDEX), owner_(owner), index_(index) {}

~BlockArgumentImpl();
// access construction and owner
friend BlockArgument;
Block *owner_;
uint32_t index_;
};

BlockArgumentImpl::~BlockArgumentImpl() {
if (!use_empty()) {
LOG(FATAL) << "Destoryed a blockargument that is still in use.";
}
}

} // namespace detail

BlockArgument::BlockArgument(detail::BlockArgumentImpl *impl) : Value(impl) {}

bool BlockArgument::classof(Value value) {
return value && detail::BlockArgumentImpl::classof(*value.impl());
}

Block *BlockArgument::owner() const {
CHECK_NULL_IMPL(owner);
return IMPL_->owner_;
}

uint32_t BlockArgument::arg_index() const {
CHECK_NULL_IMPL(arg_index);
return IMPL_->index_;
}

BlockArgument BlockArgument::Create(Type type, Block *owner, uint32_t index) {
return new detail::BlockArgumentImpl(type, owner, index);
}
/// Destroy the argument.
void BlockArgument::Destroy() {
if (impl_) {
LOG(WARNING) << "Destroying a null block argument.";
} else {
delete IMPL_;
}
}

void BlockArgument::set_arg_index(uint32_t index) {
CHECK_NULL_IMPL(set_arg_number);
IMPL_->index_ = index;
}

BlockArgument BlockArgument::dyn_cast_from(Value value) {
if (classof(value)) {
return static_cast<detail::BlockArgumentImpl *>(value.impl());
} else {
return nullptr;
}
}

} // namespace pir
55 changes: 55 additions & 0 deletions paddle/pir/core/block_argument.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/core/value.h"
namespace pir {
class Block;

namespace detail {
class BlockArgumentImpl;
} // namespace detail

///
/// \brief BlockArgument class represents the value defined by a result of
/// operation. This class only provides interfaces, for specific implementation,
/// see Impl class.
///
class IR_API BlockArgument : public Value {
public:
BlockArgument() = default;
Block *owner() const;
uint32_t arg_index() const;

private:
/// constructor
BlockArgument(detail::BlockArgumentImpl *impl); // NOLINT

/// create a new argument with the given type and owner.
static BlockArgument Create(Type type, Block *owner, uint32_t index);
/// Destroy the argument.
void Destroy();
/// set the position in the block argument list.
void set_arg_index(uint32_t index);
// Access create annd destroy.
friend Block;

// Access classof annd dyn_cast_from.
friend Value;
static bool classof(Value value);
static BlockArgument dyn_cast_from(Value value);
};

} // namespace pir
12 changes: 6 additions & 6 deletions paddle/pir/core/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class PointerAttribute;
///
class Builder {
public:
Builder(IrContext *context, Block *block, Block::iterator insert_point)
Builder(IrContext *context, Block *block, Block::Iterator insert_point)
: context_(context) {
SetInsertionPoint(block, insert_point);
}
Expand All @@ -57,10 +57,10 @@ class Builder {
: Builder(context, block, block->end()) {}

explicit Builder(IrContext *context)
: Builder(context, nullptr, Block::iterator{}) {}
: Builder(context, nullptr, Block::Iterator{}) {}

/// Set the insertion point to the specified location.
void SetInsertionPoint(Block *block, Block::iterator insert_point) {
void SetInsertionPoint(Block *block, Block::Iterator insert_point) {
// TODO(liuyuanle): check that insertPoint is in this rather than some other
// block.
this->block_ = block;
Expand All @@ -70,13 +70,13 @@ class Builder {
/// Set the insertion point to the specified operation, which will cause
/// subsequent insertions to go right before it.
void SetInsertionPoint(Operation *op) {
SetInsertionPoint(op->GetParent(), Block::iterator{*op});
SetInsertionPoint(op->GetParent(), Block::Iterator{*op});
}

/// Set the insertion point to the node after the specified operation, which
/// will cause subsequent insertions to go right after it.
void SetInsertionPointAfter(Operation *op) {
SetInsertionPoint(op->GetParent(), std::next(Block::iterator{*op}));
SetInsertionPoint(op->GetParent(), std::next(Block::Iterator{*op}));
}

/// Set the insertion point to the start of the specified block.
Expand Down Expand Up @@ -138,7 +138,7 @@ class Builder {
IrContext *context_;
Block *block_;
// The insertion point within the list that this builder is inserting before.
Block::iterator insert_point_;
Block::Iterator insert_point_;
};

} // namespace pir
Loading

0 comments on commit 362a9c9

Please sign in to comment.