diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 89edf0ddf29b3..66a52f99c6b44 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -1081,6 +1081,27 @@ void IfOp::Print(pir::IrPrinter &printer) { os << "\n }"; } void IfOp::Verify() {} + +void WhileOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + const std::vector &inputs, + const std::vector &output_types) { + // auto insert_point = builder.insert_point(); + argument.AddInputs(inputs); + argument.AddOutputs(output_types); + argument.AddRegion(nullptr); + argument.AddRegion(nullptr); +} +pir::Block *WhileOp::cond_block() { + pir::Region &cond_region = (*this)->region(0); + if (cond_region.empty()) cond_region.emplace_back(); + return cond_region.front(); +} +pir::Block *WhileOp::body_block() { + pir::Region &body_region = (*this)->region(1); + if (body_region.empty()) body_region.emplace_back(); + return body_region.front(); +} } // namespace dialect } // namespace paddle @@ -1091,3 +1112,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index b9f8474755ef7..93f24e80cb524 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -193,6 +193,22 @@ class IfOp : public pir::Op { void Verify(); }; +class WhileOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd.while"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + const std::vector &inputs, + const std::vector &output_types); + void Verify() {} + pir::Block *cond_block(); + pir::Block *body_block(); +}; + } // namespace dialect } // namespace paddle @@ -203,3 +219,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 6e7e49e3bcee9..a250d43ecae84 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -56,7 +56,8 @@ void OperatorDialect::initialize() { paddle::dialect::FusedGemmEpilogueOp, paddle::dialect::FusedGemmEpilogueGradOp, paddle::dialect::SplitGradOp, - paddle::dialect::IfOp>(); + paddle::dialect::IfOp, + paddle::dialect::WhileOp>(); RegisterInterfaces(); } diff --git a/paddle/pir/core/builder.cc b/paddle/pir/core/builder.cc index 7b04de3ee3759..6a1608c84ab85 100644 --- a/paddle/pir/core/builder.cc +++ b/paddle/pir/core/builder.cc @@ -33,15 +33,19 @@ Operation *Builder::Build(const std::vector &inputs, } Operation *Builder::Insert(Operation *op) { - if (block_) { - block_->insert(insert_point_, op); + if (insert_point_.first) { + insert_point_.first->insert(insert_point_.second, op); } else { LOG(WARNING) << "Builder's Block is nullptr, insert failed."; } return op; } + +BoolType Builder::bool_type() { return BoolType::get(context_); } UInt8Type Builder::uint8_type() { return UInt8Type::get(context_); } Int8Type Builder::int8_type() { return Int8Type::get(context_); } +Int16Type Builder::int16_type() { return Int16Type::get(context_); } +Int32Type Builder::int32_type() { return Int32Type::get(context_); } VectorType Builder::vec_type(const std::vector &value) { return VectorType::get(context_, value); } @@ -50,8 +54,6 @@ Float32Type Builder::float32_type() { return Float32Type::get(context_); } Float64Type Builder::float64_type() { return Float64Type::get(context_); } IndexType Builder::index_type() { return IndexType::get(context_); } -Int16Type Builder::int16_type() { return Int16Type::get(context_); } -BoolType Builder::bool_type() { return BoolType::get(context_); } Complex64Type Builder::complex64_type() { return Complex64Type::get(context_); } Complex128Type Builder::complex128_type() { return Complex128Type::get(context_); diff --git a/paddle/pir/core/builder.h b/paddle/pir/core/builder.h index dab0d2b4ffafe..72c8494cf8906 100644 --- a/paddle/pir/core/builder.h +++ b/paddle/pir/core/builder.h @@ -24,11 +24,12 @@ namespace pir { class Type; class UInt8Type; class Int8Type; +class Int16Type; +class Int32Type; class VectorType; class BFloat16Type; class Float32Type; class Float64Type; -class Int16Type; class IndexType; class BoolType; class Complex64Type; @@ -42,6 +43,7 @@ class Int64Attribute; class ArrayAttribute; class PointerAttribute; +using InsertPoint = std::pair; /// /// \brief Unified interface of the Attribute class. Derivation of all Attribute /// classes only derives interfaces, not members. @@ -49,9 +51,7 @@ class PointerAttribute; class Builder { public: Builder(IrContext *context, Block *block, Block::Iterator insert_point) - : context_(context) { - SetInsertionPoint(block, insert_point); - } + : context_(context), insert_point_(block, insert_point) {} Builder(IrContext *context, Block *block) : Builder(context, block, block->end()) {} @@ -59,12 +59,22 @@ class Builder { explicit Builder(IrContext *context) : Builder(context, nullptr, Block::Iterator{}) {} + Builder(IrContext *context, const InsertPoint &insert_point) + : context_(context), insert_point_(insert_point) {} + + void SetInsertionPoint(const InsertPoint &insert_point) { + insert_point_ = insert_point; + } + + /// Set the insert point to the start of the specified block. + void SetInsertionPointToStart(Block *block) { + SetInsertionPoint(block, block->begin()); + } + /// Set the insertion point to the specified location. void SetInsertionPoint(Block *block, Block::Iterator insert_point) { - // TODO(liuyuanle): check that insertPoint is in this rather than some other - // block. - this->block_ = block; - this->insert_point_ = insert_point; + insert_point_.first = block; + insert_point_.second = insert_point; } /// Set the insertion point to the specified operation, which will cause @@ -79,11 +89,6 @@ class Builder { SetInsertionPoint(op->GetParent(), std::next(Block::Iterator{*op})); } - /// Set the insertion point to the start of the specified block. - void SetInsertionPointToStart(Block *block) { - SetInsertionPoint(block, block->begin()); - } - /// Set the insertion point to the end of the specified block. void SetInsertionPointToEnd(Block *block) { SetInsertionPoint(block, block->end()); @@ -91,7 +96,9 @@ class Builder { IrContext *ir_context() const { return context_; } - Block *block() const { return block_; } + Block *block() const { return insert_point_.first; } + + const InsertPoint &insert_point() const { return insert_point_; } /// Creates an operation given the fields represented as an OperationState. IR_API Operation *Build(OperationArgument &&argument); @@ -104,22 +111,18 @@ class Builder { /// Create an operation of specific op type at the current insertion point. template - OpTy Build(Args &&...args) { - OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); - OpTy::Build(*this, argument, std::forward(args)...); - Operation *op = Build(std::move(argument)); - return OpTy(op); - } + OpTy Build(Args &&...args); + IR_API BoolType bool_type(); IR_API UInt8Type uint8_type(); IR_API Int8Type int8_type(); + IR_API Int16Type int16_type(); + IR_API Int32Type int32_type(); IR_API VectorType vec_type(const std::vector &); IR_API BFloat16Type bfloat16_type(); IR_API IndexType index_type(); IR_API Float32Type float32_type(); IR_API Float64Type float64_type(); - IR_API Int16Type int16_type(); - IR_API BoolType bool_type(); IR_API Complex64Type complex64_type(); IR_API Complex128Type complex128_type(); @@ -136,9 +139,16 @@ class Builder { Operation *Insert(Operation *op); IrContext *context_; - Block *block_; - // The insertion point within the list that this builder is inserting before. - Block::Iterator insert_point_; + + InsertPoint insert_point_; }; +template +OpTy Builder::Build(Args &&...args) { + OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); + OpTy::Build(*this, argument, std::forward(args)...); + Operation *op = Build(std::move(argument)); + return OpTy(op); +} + } // namespace pir diff --git a/paddle/pir/dialect/control_flow/ir/cf_dialect.cc b/paddle/pir/dialect/control_flow/ir/cf_dialect.cc index ed36c0c81cca6..7166af2ece636 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_dialect.cc +++ b/paddle/pir/dialect/control_flow/ir/cf_dialect.cc @@ -15,6 +15,6 @@ #include "paddle/pir/dialect/control_flow/ir/cf_ops.h" namespace pir { -void ControlFlowDialect::initialize() { RegisterOps(); } +void ControlFlowDialect::initialize() { RegisterOps(); } } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::ControlFlowDialect) diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.cc b/paddle/pir/dialect/control_flow/ir/cf_ops.cc index 7981a6ab96396..69dce41e62bad 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.cc +++ b/paddle/pir/dialect/control_flow/ir/cf_ops.cc @@ -24,3 +24,4 @@ void YieldOp::Build(Builder &builder, } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::YieldOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::CondYieldOp) diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.h b/paddle/pir/dialect/control_flow/ir/cf_ops.h index 3689920e1bce6..898f954e09d5f 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.h +++ b/paddle/pir/dialect/control_flow/ir/cf_ops.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once - +#include #include "paddle/pir/core/builder.h" #include "paddle/pir/core/op_base.h" @@ -30,6 +30,31 @@ class IR_API YieldOp : public Op { const std::vector &Value); void Verify() {} }; + +class IR_API CondYieldOp : public Op { + public: + using Op::Op; + static const char *name() { return "cf.cond_yield"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + template + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value cond, + const ValueContainer &inputs); + void Verify() {} +}; + +template +void CondYieldOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value cond, + const ValueContainer &inputs) { + argument.AddInput(cond); + argument.AddInputs(inputs); +} } // namespace pir IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::YieldOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::CondYieldOp); diff --git a/test/cpp/pir/control_flow_dialect/CMakeLists.txt b/test/cpp/pir/control_flow_dialect/CMakeLists.txt index fa6b0a5ae7fca..64af30a54d0ee 100644 --- a/test/cpp/pir/control_flow_dialect/CMakeLists.txt +++ b/test/cpp/pir/control_flow_dialect/CMakeLists.txt @@ -1,8 +1,17 @@ cc_test_old( - test_if_op + if_op_test SRCS if_op_test.cc DEPS pir pd_op_dialect gtest) + +cc_test_old( + while_op_test + SRCS + while_op_test.cc + DEPS + pir + pd_op_dialect + gtest) diff --git a/test/cpp/pir/control_flow_dialect/while_op_test.cc b/test/cpp/pir/control_flow_dialect/while_op_test.cc new file mode 100644 index 0000000000000..6c558cc982926 --- /dev/null +++ b/test/cpp/pir/control_flow_dialect/while_op_test.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" + +using namespace paddle::dialect; // NOLINT +TEST(while_op_test, base) { + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + pir::Program program(ctx); + pir::Block* block = program.block(); + pir::Builder builder(ctx, block); + + auto i = + builder.Build(std::vector{1}, 1, phi::DataType::INT32) + .out(); + + auto ten = + builder.Build(std::vector{1}, 10, phi::DataType::INT32) + .out(); + + auto while_op = builder.Build( + std::vector{i, ten}, + std::vector{builder.int32_type(), builder.int32_type()}); + + // while(i < ten) + pir::Block* cond_block = while_op.cond_block(); + auto cond_i_argument = cond_block->AddArgument(i.type()); + auto cond_ten_argument = cond_block->AddArgument(ten.type()); + builder.SetInsertionPointToStart(cond_block); + auto cond_value = + builder.Build(cond_i_argument, cond_ten_argument).out(); + builder.Build( + cond_value, std::vector{cond_i_argument, cond_ten_argument}); + + // { i = i + 1} + pir::Block* body_block = while_op.body_block(); + auto body_i_argument = body_block->AddArgument(i.type()); + auto body_ten_argument = body_block->AddArgument(ten.type()); + builder.SetInsertionPointToStart(body_block); + auto one = + builder.Build(std::vector{1}, 1, phi::DataType::INT32) + .out(); + auto new_i = builder.Build(body_i_argument, one).out(); + builder.Build( + std::vector{new_i, body_ten_argument}); + + builder.SetInsertionPointAfter(while_op); + std::stringstream ss; + program.Print(ss); + + LOG(INFO) << ss.str(); +}