diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 678a79a5540b8..7b61cde6f62ae 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -218,7 +218,7 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { ctx_, defining_op_result, parameter_name_mappings_[var_name]); pir::Block* block = program_->block(); - pir::Block::iterator insert_pos = std::find( + pir::Block::Iterator insert_pos = std::find( block->begin(), block->end(), defining_op_result.owner()); IR_ENFORCE( diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc index adfa5866799b9..030240d46e37f 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -318,7 +318,7 @@ class InplacePass : public pir::Pass { .at("op_name") .dyn_cast() .AsString(); - pir::Block::iterator insert_pos = + pir::Block::Iterator insert_pos = std::find(block->begin(), block->end(), kv.first); IR_ENFORCE(insert_pos != block->end(), "Operator %s not found in block.", diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 465a8719b3c7f..5e76cca0da5bb 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -452,9 +452,7 @@ void BindOpResult(py::module *m) { return paddle::dialect::greater_equal(self, other); }) .def("__hash__", - [](OpResult &self) { - return std::hash{}(self.dyn_cast()); - }) + [](OpResult &self) { return std::hash{}(self); }) .def("get_defining_op", &OpResult::GetDefiningOp, return_value_policy::reference) diff --git a/paddle/pir/core/block.cc b/paddle/pir/core/block.cc index f92d532298150..5561ea345b688 100644 --- a/paddle/pir/core/block.cc +++ b/paddle/pir/core/block.cc @@ -22,8 +22,11 @@ namespace pir { Block::~Block() { - assert(use_empty() && "block destroyed still has uses."); + if (!use_empty()) { + LOG(FATAL) << "Destoryed a block that is still in use."; + } clear(); + ClearArguments(); } void Block::push_back(Operation *op) { insert(ops_.end(), op); } @@ -33,13 +36,13 @@ Operation *Block::GetParentOp() const { return parent_ ? parent_->GetParent() : nullptr; } -Block::iterator Block::insert(const_iterator iterator, Operation *op) { - Block::iterator iter = ops_.insert(iterator, op); +Block::Iterator Block::insert(ConstIterator iterator, Operation *op) { + Block::Iterator iter = ops_.insert(iterator, op); op->SetParent(this, iter); return iter; } -Block::iterator Block::erase(const_iterator position) { +Block::Iterator Block::erase(ConstIterator position) { IR_ENFORCE((*position)->GetParent() == this, "iterator not own this block."); (*position)->Destroy(); return ops_.erase(position); @@ -75,6 +78,16 @@ void Block::ResetOpListOrder(const OpListType &new_op_list) { } } +void Block::ClearArguments() { + for (auto &argument : arguments_) { + argument.Destroy(); + } + arguments_.clear(); +} +void Block::AddArgument(Type type) { + arguments_.emplace_back(BlockArgument::Create(type, this, arguments_.size())); +} + bool Block::TopoOrderCheck(const OpListType &op_list) { std::unordered_set visited_values; for (const Operation *op : op_list) { diff --git a/paddle/pir/core/block.h b/paddle/pir/core/block.h index 3a8b4fafc345d..5eccdec9e3c03 100644 --- a/paddle/pir/core/block.h +++ b/paddle/pir/core/block.h @@ -17,6 +17,7 @@ #include #include +#include "paddle/pir/core/block_argument.h" #include "paddle/pir/core/block_operand.h" #include "paddle/pir/core/dll_decl.h" #include "paddle/pir/core/region.h" @@ -29,9 +30,9 @@ class IR_API Block { using OpListType = std::list; public: - using iterator = OpListType::iterator; - using reverse_iterator = OpListType::reverse_iterator; - using const_iterator = OpListType::const_iterator; + using Iterator = OpListType::iterator; + using ReverseIterator = OpListType::reverse_iterator; + using ConstIterator = OpListType::const_iterator; Block() = default; ~Block(); @@ -42,19 +43,19 @@ class IR_API Block { bool empty() const { return ops_.empty(); } size_t size() const { return ops_.size(); } - const_iterator begin() const { return ops_.begin(); } - const_iterator end() const { return ops_.end(); } - iterator begin() { return ops_.begin(); } - iterator end() { return ops_.end(); } - reverse_iterator rbegin() { return ops_.rbegin(); } - reverse_iterator rend() { return ops_.rend(); } + ConstIterator begin() const { return ops_.begin(); } + ConstIterator end() const { return ops_.end(); } + Iterator begin() { return ops_.begin(); } + Iterator end() { return ops_.end(); } + ReverseIterator rbegin() { return ops_.rbegin(); } + ReverseIterator rend() { return ops_.rend(); } Operation *back() const { return ops_.back(); } Operation *front() const { return ops_.front(); } void push_back(Operation *op); void push_front(Operation *op); - iterator insert(const_iterator iterator, Operation *op); - iterator erase(const_iterator position); + Iterator insert(ConstIterator iterator, Operation *op); + Iterator erase(ConstIterator position); void clear(); operator Region::iterator() { return position_; } @@ -73,6 +74,29 @@ class IR_API Block { // This is a unsafe funcion, please use it carefully. void ResetOpListOrder(const OpListType &new_op_list); + /// + /// \brief Block argument management + /// + using BlockArgListType = std::vector; + using ArgsIterator = BlockArgListType::iterator; + + ArgsIterator args_begin() { return arguments_.begin(); } + ArgsIterator args_end() { return arguments_.end(); } + bool args_empty() const { return arguments_.empty(); } + uint32_t args_size() const { return arguments_.size(); } + BlockArgument argument(uint32_t index) { return arguments_[index]; } + Type argument_type(uint32_t index) const { return arguments_[index].type(); } + + void ClearArguments(); + void AddArgument(Type type); + template + void AddArguments(TypeIter first, TypeIter last); + + template + void AddArguments(const TypeContainer &container) { + AddArguments(container.begin(), container.end()); + } + private: Block(Block &) = delete; Block &operator=(const Block &) = delete; @@ -84,9 +108,18 @@ class IR_API Block { static bool TopoOrderCheck(const OpListType &op_list); private: - Region *parent_; // not owned - OpListType ops_; // owned Region::iterator position_; BlockOperand first_use_; + OpListType ops_; // owned + BlockArgListType arguments_; // owned + Region *parent_; // not owned }; + +template +void Block::AddArguments(TypeIter first, TypeIter last) { + while (first != last) { + AddArgument(*first++); + } +} + } // namespace pir diff --git a/paddle/pir/core/block_argument.cc b/paddle/pir/core/block_argument.cc new file mode 100644 index 0000000000000..4d05cc54b279e --- /dev/null +++ b/paddle/pir/core/block_argument.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/pir/core/block_argument.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/value_impl.h" + +#define CHECK_NULL_IMPL(func_name) \ + IR_ENFORCE(impl_, "impl_ is null when called BlockArgument:" #func_name) + +#define IMPL_ static_cast(impl_) + +namespace pir { + +namespace detail { +/// +/// \brief BlockArgumentImpl is the implementation of an block argument. +/// +class BlockArgumentImpl : public ValueImpl { + public: + static bool classof(const ValueImpl &value) { + return value.kind() == BLOCK_ARGUMENT_INDEX; + } + + private: + BlockArgumentImpl(Type type, Block *owner, uint32_t index) + : ValueImpl(type, BLOCK_ARGUMENT_INDEX), owner_(owner), index_(index) {} + + ~BlockArgumentImpl(); + // access construction and owner + friend BlockArgument; + Block *owner_; + uint32_t index_; +}; + +BlockArgumentImpl::~BlockArgumentImpl() { + if (!use_empty()) { + LOG(FATAL) << "Destoryed a blockargument that is still in use."; + } +} + +} // namespace detail + +BlockArgument::BlockArgument(detail::BlockArgumentImpl *impl) : Value(impl) {} + +bool BlockArgument::classof(Value value) { + return value && detail::BlockArgumentImpl::classof(*value.impl()); +} + +Block *BlockArgument::owner() const { + CHECK_NULL_IMPL(owner); + return IMPL_->owner_; +} + +uint32_t BlockArgument::arg_index() const { + CHECK_NULL_IMPL(arg_index); + return IMPL_->index_; +} + +BlockArgument BlockArgument::Create(Type type, Block *owner, uint32_t index) { + return new detail::BlockArgumentImpl(type, owner, index); +} +/// Destroy the argument. +void BlockArgument::Destroy() { + if (impl_) { + LOG(WARNING) << "Destroying a null block argument."; + } else { + delete IMPL_; + } +} + +void BlockArgument::set_arg_index(uint32_t index) { + CHECK_NULL_IMPL(set_arg_number); + IMPL_->index_ = index; +} + +BlockArgument BlockArgument::dyn_cast_from(Value value) { + if (classof(value)) { + return static_cast(value.impl()); + } else { + return nullptr; + } +} + +} // namespace pir diff --git a/paddle/pir/core/block_argument.h b/paddle/pir/core/block_argument.h new file mode 100644 index 0000000000000..27f1779650ef1 --- /dev/null +++ b/paddle/pir/core/block_argument.h @@ -0,0 +1,55 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/value.h" +namespace pir { +class Block; + +namespace detail { +class BlockArgumentImpl; +} // namespace detail + +/// +/// \brief BlockArgument class represents the value defined by a result of +/// operation. This class only provides interfaces, for specific implementation, +/// see Impl class. +/// +class IR_API BlockArgument : public Value { + public: + BlockArgument() = default; + Block *owner() const; + uint32_t arg_index() const; + + private: + /// constructor + BlockArgument(detail::BlockArgumentImpl *impl); // NOLINT + + /// create a new argument with the given type and owner. + static BlockArgument Create(Type type, Block *owner, uint32_t index); + /// Destroy the argument. + void Destroy(); + /// set the position in the block argument list. + void set_arg_index(uint32_t index); + // Access create annd destroy. + friend Block; + + // Access classof annd dyn_cast_from. + friend Value; + static bool classof(Value value); + static BlockArgument dyn_cast_from(Value value); +}; + +} // namespace pir diff --git a/paddle/pir/core/builder.h b/paddle/pir/core/builder.h index acb621e7808e7..81e25a0d365f0 100644 --- a/paddle/pir/core/builder.h +++ b/paddle/pir/core/builder.h @@ -48,7 +48,7 @@ class PointerAttribute; /// class Builder { public: - Builder(IrContext *context, Block *block, Block::iterator insert_point) + Builder(IrContext *context, Block *block, Block::Iterator insert_point) : context_(context) { SetInsertionPoint(block, insert_point); } @@ -57,10 +57,10 @@ class Builder { : Builder(context, block, block->end()) {} explicit Builder(IrContext *context) - : Builder(context, nullptr, Block::iterator{}) {} + : Builder(context, nullptr, Block::Iterator{}) {} /// Set the insertion point to the specified location. - void SetInsertionPoint(Block *block, Block::iterator insert_point) { + void SetInsertionPoint(Block *block, Block::Iterator insert_point) { // TODO(liuyuanle): check that insertPoint is in this rather than some other // block. this->block_ = block; @@ -70,13 +70,13 @@ class Builder { /// Set the insertion point to the specified operation, which will cause /// subsequent insertions to go right before it. void SetInsertionPoint(Operation *op) { - SetInsertionPoint(op->GetParent(), Block::iterator{*op}); + SetInsertionPoint(op->GetParent(), Block::Iterator{*op}); } /// Set the insertion point to the node after the specified operation, which /// will cause subsequent insertions to go right after it. void SetInsertionPointAfter(Operation *op) { - SetInsertionPoint(op->GetParent(), std::next(Block::iterator{*op})); + SetInsertionPoint(op->GetParent(), std::next(Block::Iterator{*op})); } /// Set the insertion point to the start of the specified block. @@ -138,7 +138,7 @@ class Builder { IrContext *context_; Block *block_; // The insertion point within the list that this builder is inserting before. - Block::iterator insert_point_; + Block::Iterator insert_point_; }; } // namespace pir diff --git a/paddle/pir/core/op_result.cc b/paddle/pir/core/op_result.cc index 510f98d99b526..b90f67093321c 100644 --- a/paddle/pir/core/op_result.cc +++ b/paddle/pir/core/op_result.cc @@ -15,38 +15,40 @@ #include "paddle/pir/core/enforce.h" #include "paddle/pir/core/op_result_impl.h" -#define CHECK_NULL_IMPL(class_name, func_name) \ - IR_ENFORCE(impl_, \ - "impl_ pointer is null when call func:" #func_name \ - " , in class: " #class_name ".") - -#define CHECK_OPRESULT_NULL_IMPL(func_name) CHECK_NULL_IMPL(OpResult, func_name) +#define CHECK_OPRESULT_NULL_IMPL(func_name) \ + IR_ENFORCE(impl_, "impl_ pointer is null when call OpResult::" #func_name) +#define IMPL_ static_cast(impl_) namespace pir { - // OpResult bool OpResult::classof(Value value) { - return value && pir::isa(value.impl()); + return value && detail::OpResultImpl::classof(*value.impl()); } Operation *OpResult::owner() const { CHECK_OPRESULT_NULL_IMPL(owner); - return impl()->owner(); + return IMPL_->owner(); } uint32_t OpResult::GetResultIndex() const { CHECK_OPRESULT_NULL_IMPL(GetResultIndex); - return impl()->GetResultIndex(); + return IMPL_->GetResultIndex(); } -detail::OpResultImpl *OpResult::impl() const { - return reinterpret_cast(impl_); +OpResult OpResult::dyn_cast_from(Value value) { + if (classof(value)) { + return static_cast(value.impl()); + } else { + return nullptr; + } } bool OpResult::operator==(const OpResult &other) const { return impl_ == other.impl_; } +// OpResult::OpResult(const detail::OpResultImpl *impl) : Value(impl) {} + uint32_t OpResult::GetValidInlineIndex(uint32_t index) { uint32_t max_inline_index = pir::detail::OpResultImpl::GetMaxInlineResultIndex(); diff --git a/paddle/pir/core/op_result.h b/paddle/pir/core/op_result.h index 1a5f14a9a17fe..50d638c4276fd 100644 --- a/paddle/pir/core/op_result.h +++ b/paddle/pir/core/op_result.h @@ -28,22 +28,22 @@ class OpResultImpl; /// class IR_API OpResult : public Value { public: - using Value::Value; - - static bool classof(Value value); - + OpResult() = default; Operation *owner() const; - uint32_t GetResultIndex() const; - bool operator==(const OpResult &other) const; + // OpResult(const detail::OpResultImpl *impl); // NOLINT - friend Operation; - - detail::OpResultImpl *impl() const; + // This func will remove in next pr. + OpResult(const detail::ValueImpl *impl) : Value(impl) {} // NOLINT private: + friend Operation; static uint32_t GetValidInlineIndex(uint32_t index); + // Access classof annd dyn_cast_from. + friend Value; + static bool classof(Value value); + static OpResult dyn_cast_from(Value value); }; } // namespace pir diff --git a/paddle/pir/core/op_result_impl.h b/paddle/pir/core/op_result_impl.h index 99601a27911af..0da46507acb58 100644 --- a/paddle/pir/core/op_result_impl.h +++ b/paddle/pir/core/op_result_impl.h @@ -63,7 +63,7 @@ class OpInlineResultImpl : public OpResultImpl { } } - static bool classof(const OpResultImpl &value) { + static bool classof(const ValueImpl &value) { return value.kind() < OUTLINE_OP_RESULT_INDEX; } @@ -80,7 +80,7 @@ class OpOutlineResultImpl : public OpResultImpl { : OpResultImpl(type, OUTLINE_OP_RESULT_INDEX), outline_index_(outline_index) {} - static bool classof(const OpResultImpl &value) { + static bool classof(const ValueImpl &value) { return value.kind() == OUTLINE_OP_RESULT_INDEX; } diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index fdb850bc1f415..30737a12a8df6 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -131,7 +131,7 @@ void Operation::Destroy() { // 2. Deconstruct Result. for (size_t idx = 0; idx < num_results_; ++idx) { - detail::OpResultImpl *impl = result(idx).impl(); + detail::ValueImpl *impl = result(idx).impl(); if (detail::OpOutlineResultImpl::classof(*impl)) { static_cast(impl)->~OpOutlineResultImpl(); } else { @@ -275,7 +275,7 @@ const Region &Operation::region(unsigned index) const { return regions_[index]; } -void Operation::SetParent(Block *parent, const Block::iterator &position) { +void Operation::SetParent(Block *parent, const Block::Iterator &position) { parent_ = parent; position_ = position; } diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index 28c0b42671c96..10ee80c7fa867 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -29,10 +29,6 @@ class Program; class OpOperand; class OpResult; -namespace detial { -class BlockOperandImpl; -} // namespace detial - class IR_API alignas(8) Operation final { public: /// @@ -142,9 +138,9 @@ class IR_API alignas(8) Operation final { const_cast(this)->GetParentProgram()); } - operator Block::iterator() { return position_; } + operator Block::Iterator() { return position_; } - operator Block::const_iterator() const { return position_; } + operator Block::ConstIterator() const { return position_; } /// Replace all uses of results of this operation with the provided 'values'. void ReplaceAllUsesWith(const std::vector &values); @@ -179,7 +175,7 @@ class IR_API alignas(8) Operation final { // Allow access to 'SetParent'. friend class Block; - void SetParent(Block *parent, const Block::iterator &position); + void SetParent(Block *parent, const Block::Iterator &position); template struct CastUtil< @@ -200,7 +196,7 @@ class IR_API alignas(8) Operation final { detail::BlockOperandImpl *block_operands_{nullptr}; Region *regions_{nullptr}; Block *parent_{nullptr}; - Block::iterator position_; + Block::Iterator position_; }; } // namespace pir diff --git a/paddle/pir/core/value.h b/paddle/pir/core/value.h index d92678ed3a8ed..ee908f2355b08 100644 --- a/paddle/pir/core/value.h +++ b/paddle/pir/core/value.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/pir/core/cast_utils.h" #include "paddle/pir/core/op_operand.h" #include "paddle/pir/core/type.h" #include "paddle/pir/core/use_iterator.h" @@ -46,14 +45,14 @@ class IR_API Value { explicit operator bool() const; - template + template bool isa() const { - return pir::isa(*this); + return U::classof(*this); } template U dyn_cast() const { - return pir::dyn_cast(*this); + return U::dyn_cast_from(*this); } Type type() const; @@ -87,9 +86,7 @@ class IR_API Value { Value new_value, const std::function &should_replace) const; void ReplaceAllUsesWith(Value new_value) const; - - detail::ValueImpl *impl() { return impl_; } - const detail::ValueImpl *impl() const { return impl_; } + detail::ValueImpl *impl() const { return impl_; } protected: detail::ValueImpl *impl_{nullptr}; diff --git a/paddle/pir/core/value_impl.cc b/paddle/pir/core/value_impl.cc index f98c1ac75ea3a..c17e44ffa3aa5 100644 --- a/paddle/pir/core/value_impl.cc +++ b/paddle/pir/core/value_impl.cc @@ -18,11 +18,11 @@ namespace pir { namespace detail { void ValueImpl::set_first_use(OpOperandImpl *first_use) { uint32_t offset = kind(); - first_use_offseted_by_index_ = reinterpret_cast( + first_use_offseted_by_kind_ = reinterpret_cast( reinterpret_cast(first_use) + offset); VLOG(4) << "The index of this value is " << offset << ". Offset and set first use: " << first_use << " -> " - << first_use_offseted_by_index_ << "."; + << first_use_offseted_by_kind_ << "."; } std::string ValueImpl::PrintUdChain() { @@ -40,16 +40,17 @@ std::string ValueImpl::PrintUdChain() { result << "nullptr"; return result.str(); } -ValueImpl::ValueImpl(Type type, uint32_t index) { - if (index > OUTLINE_OP_RESULT_INDEX) { - throw("The value of index must not exceed 6"); +ValueImpl::ValueImpl(Type type, uint32_t kind) { + if (kind > BLOCK_ARGUMENT_INDEX) { + LOG(FATAL) << "The kind of value_impl(" << kind + << "), is bigger than BLOCK_ARGUMENT_INDEX(7)"; } type_ = type; - first_use_offseted_by_index_ = reinterpret_cast( - reinterpret_cast(nullptr) + index); - VLOG(4) << "Construct a ValueImpl whose's index is " << index + first_use_offseted_by_kind_ = reinterpret_cast( + reinterpret_cast(nullptr) + kind); + VLOG(4) << "Construct a ValueImpl whose's kind is " << kind << ". The offset first_use address is: " - << first_use_offseted_by_index_; + << first_use_offseted_by_kind_; } } // namespace detail diff --git a/paddle/pir/core/value_impl.h b/paddle/pir/core/value_impl.h index f560aa4362d4d..ccd5e835abdeb 100644 --- a/paddle/pir/core/value_impl.h +++ b/paddle/pir/core/value_impl.h @@ -18,7 +18,9 @@ #include "paddle/pir/core/value.h" namespace pir { -static const uint32_t OUTLINE_OP_RESULT_INDEX = 6; +constexpr const uint32_t OUTLINE_OP_RESULT_INDEX = 6; +constexpr const uint32_t BLOCK_ARGUMENT_INDEX = OUTLINE_OP_RESULT_INDEX + 1; + class Operation; namespace detail { @@ -41,12 +43,12 @@ class alignas(8) ValueImpl { OpOperandImpl *first_use() const { return reinterpret_cast( - reinterpret_cast(first_use_offseted_by_index_) & (~0x07)); + reinterpret_cast(first_use_offseted_by_kind_) & (~0x07)); } void set_first_use(OpOperandImpl *first_use); - OpOperandImpl **first_use_addr() { return &first_use_offseted_by_index_; } + OpOperandImpl **first_use_addr() { return &first_use_offseted_by_kind_; } bool use_empty() const { return first_use() == nullptr; } @@ -57,17 +59,22 @@ class alignas(8) ValueImpl { std::string PrintUdChain(); /// - /// \brief Interface functions of "first_use_offseted_by_index_" attribute. + /// \brief Interface functions of "first_use_offseted_by_kind_" attribute. /// uint32_t kind() const { - return reinterpret_cast(first_use_offseted_by_index_) & 0x07; + return reinterpret_cast(first_use_offseted_by_kind_) & 0x07; + } + + template + bool isa() { + return T::classof(*this); } protected: /// /// \brief Only can be constructed by derived classes such as OpResultImpl. /// - ValueImpl(Type type, uint32_t index); + ValueImpl(Type type, uint32_t kind); /// /// \brief Attribute1: Type of value. @@ -83,7 +90,7 @@ class alignas(8) ValueImpl { /// output(OpInlineResultImpl); (2) index = 6: represent the position >=6 /// outline output(OpOutlineResultImpl); (3) index = 7 is reserved. /// - OpOperandImpl *first_use_offseted_by_index_ = nullptr; + OpOperandImpl *first_use_offseted_by_kind_ = nullptr; }; } // namespace detail diff --git a/test/cpp/pir/core/CMakeLists.txt b/test/cpp/pir/core/CMakeLists.txt index b3d815e59bb6c..51fffd20dc9e5 100644 --- a/test/cpp/pir/core/CMakeLists.txt +++ b/test/cpp/pir/core/CMakeLists.txt @@ -131,3 +131,12 @@ cc_test_old( test_dialect gtest pir) + +cc_test_old( + block_argument_test + SRCS + block_argument_test.cc + DEPS + test_dialect + gtest + pir) diff --git a/test/cpp/pir/core/block_argument_test.cc b/test/cpp/pir/core/block_argument_test.cc new file mode 100644 index 0000000000000..65810319160e0 --- /dev/null +++ b/test/cpp/pir/core/block_argument_test.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/program.h" +#include "test/cpp/pir/tools/test_dialect.h" +#include "test/cpp/pir/tools/test_op.h" + +TEST(block_argument_test, base) { + pir::IrContext ctx; + ctx.GetOrRegisterDialect(); + + pir::Program program(&ctx); + pir::Block* block = program.block(); + pir::Builder builder(&ctx, block); + + std::vector types(3, builder.float32_type()); + block->AddArguments(types); + + EXPECT_FALSE(block->args_empty()); + EXPECT_EQ(block->args_size(), types.size()); + + uint32_t index = 0; + for (auto iter = block->args_begin(); iter != block->args_end(); ++iter) { + EXPECT_EQ(iter->arg_index(), index++); + } + + pir::Value value = block->argument(0); + pir::BlockArgument argument = value.dyn_cast(); + EXPECT_TRUE(argument); + EXPECT_EQ(argument.owner(), block); + EXPECT_EQ(block->argument_type(0), types[0]); + pir::OpResult op_result = value.dyn_cast(); + EXPECT_FALSE(op_result); + + auto op = builder.Build(builder.double_attr(1.0), + builder.float64_type()); + value = op.result(0); + argument = value.dyn_cast(); + EXPECT_FALSE(argument); + op_result = value.dyn_cast(); + EXPECT_TRUE(op_result); +} diff --git a/test/cpp/pir/core/op_info_test.cc b/test/cpp/pir/core/op_info_test.cc index d02566237876a..fec5b71396095 100644 --- a/test/cpp/pir/core/op_info_test.cc +++ b/test/cpp/pir/core/op_info_test.cc @@ -34,7 +34,7 @@ TEST(ir_op_info_test, op_op_info_test) { pir::Operation* op = block->back(); - EXPECT_EQ(block->end(), ++pir::Block::iterator(*op)); + EXPECT_EQ(block->end(), ++pir::Block::Iterator(*op)); auto& info_map = context->registered_op_info_map(); EXPECT_FALSE(info_map.empty());