Skip to content

Commit

Permalink
[PIR] support while op (PaddlePaddle#57475)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored and jiahy0825 committed Oct 16, 2023
1 parent d03ebc7 commit 3eaa065
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 33 deletions.
22 changes: 22 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,27 @@ void IfOp::Print(pir::IrPrinter &printer) {
os << "\n }";
}
void IfOp::Verify() {}

void WhileOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types) {
// auto insert_point = builder.insert_point();
argument.AddInputs(inputs);
argument.AddOutputs(output_types);
argument.AddRegion(nullptr);
argument.AddRegion(nullptr);
}
pir::Block *WhileOp::cond_block() {
pir::Region &cond_region = (*this)->region(0);
if (cond_region.empty()) cond_region.emplace_back();
return cond_region.front();
}
pir::Block *WhileOp::body_block() {
pir::Region &body_region = (*this)->region(1);
if (body_region.empty()) body_region.emplace_back();
return body_region.front();
}
} // namespace dialect
} // namespace paddle

Expand All @@ -1091,3 +1112,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp)
17 changes: 17 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,22 @@ class IfOp : public pir::Op<IfOp> {
void Verify();
};

class WhileOp : public pir::Op<WhileOp> {
public:
using Op::Op;
static const char *name() { return "pd.while"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;

static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types);
void Verify() {}
pir::Block *cond_block();
pir::Block *body_block();
};

} // namespace dialect
} // namespace paddle

Expand All @@ -203,3 +219,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp)
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ void OperatorDialect::initialize() {
paddle::dialect::FusedGemmEpilogueOp,
paddle::dialect::FusedGemmEpilogueGradOp,
paddle::dialect::SplitGradOp,
paddle::dialect::IfOp>();
paddle::dialect::IfOp,
paddle::dialect::WhileOp>();

RegisterInterfaces<ParameterConvertInterface>();
}
Expand Down
10 changes: 6 additions & 4 deletions paddle/pir/core/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ Operation *Builder::Build(const std::vector<Value> &inputs,
}

Operation *Builder::Insert(Operation *op) {
if (block_) {
block_->insert(insert_point_, op);
if (insert_point_.first) {
insert_point_.first->insert(insert_point_.second, op);
} else {
LOG(WARNING) << "Builder's Block is nullptr, insert failed.";
}
return op;
}

BoolType Builder::bool_type() { return BoolType::get(context_); }
UInt8Type Builder::uint8_type() { return UInt8Type::get(context_); }
Int8Type Builder::int8_type() { return Int8Type::get(context_); }
Int16Type Builder::int16_type() { return Int16Type::get(context_); }
Int32Type Builder::int32_type() { return Int32Type::get(context_); }
VectorType Builder::vec_type(const std::vector<Type> &value) {
return VectorType::get(context_, value);
}
Expand All @@ -50,8 +54,6 @@ Float32Type Builder::float32_type() { return Float32Type::get(context_); }

Float64Type Builder::float64_type() { return Float64Type::get(context_); }
IndexType Builder::index_type() { return IndexType::get(context_); }
Int16Type Builder::int16_type() { return Int16Type::get(context_); }
BoolType Builder::bool_type() { return BoolType::get(context_); }
Complex64Type Builder::complex64_type() { return Complex64Type::get(context_); }
Complex128Type Builder::complex128_type() {
return Complex128Type::get(context_);
Expand Down
60 changes: 35 additions & 25 deletions paddle/pir/core/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ namespace pir {
class Type;
class UInt8Type;
class Int8Type;
class Int16Type;
class Int32Type;
class VectorType;
class BFloat16Type;
class Float32Type;
class Float64Type;
class Int16Type;
class IndexType;
class BoolType;
class Complex64Type;
Expand All @@ -42,29 +43,38 @@ class Int64Attribute;
class ArrayAttribute;
class PointerAttribute;

using InsertPoint = std::pair<Block *, Block::Iterator>;
///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute
/// classes only derives interfaces, not members.
///
class Builder {
public:
Builder(IrContext *context, Block *block, Block::Iterator insert_point)
: context_(context) {
SetInsertionPoint(block, insert_point);
}
: context_(context), insert_point_(block, insert_point) {}

Builder(IrContext *context, Block *block)
: Builder(context, block, block->end()) {}

explicit Builder(IrContext *context)
: Builder(context, nullptr, Block::Iterator{}) {}

Builder(IrContext *context, const InsertPoint &insert_point)
: context_(context), insert_point_(insert_point) {}

void SetInsertionPoint(const InsertPoint &insert_point) {
insert_point_ = insert_point;
}

/// Set the insert point to the start of the specified block.
void SetInsertionPointToStart(Block *block) {
SetInsertionPoint(block, block->begin());
}

/// Set the insertion point to the specified location.
void SetInsertionPoint(Block *block, Block::Iterator insert_point) {
// TODO(liuyuanle): check that insertPoint is in this rather than some other
// block.
this->block_ = block;
this->insert_point_ = insert_point;
insert_point_.first = block;
insert_point_.second = insert_point;
}

/// Set the insertion point to the specified operation, which will cause
Expand All @@ -79,19 +89,16 @@ class Builder {
SetInsertionPoint(op->GetParent(), std::next(Block::Iterator{*op}));
}

/// Set the insertion point to the start of the specified block.
void SetInsertionPointToStart(Block *block) {
SetInsertionPoint(block, block->begin());
}

/// Set the insertion point to the end of the specified block.
void SetInsertionPointToEnd(Block *block) {
SetInsertionPoint(block, block->end());
}

IrContext *ir_context() const { return context_; }

Block *block() const { return block_; }
Block *block() const { return insert_point_.first; }

const InsertPoint &insert_point() const { return insert_point_; }

/// Creates an operation given the fields represented as an OperationState.
IR_API Operation *Build(OperationArgument &&argument);
Expand All @@ -104,22 +111,18 @@ class Builder {

/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = Build(std::move(argument));
return OpTy(op);
}
OpTy Build(Args &&...args);

IR_API BoolType bool_type();
IR_API UInt8Type uint8_type();
IR_API Int8Type int8_type();
IR_API Int16Type int16_type();
IR_API Int32Type int32_type();
IR_API VectorType vec_type(const std::vector<Type> &);
IR_API BFloat16Type bfloat16_type();
IR_API IndexType index_type();
IR_API Float32Type float32_type();
IR_API Float64Type float64_type();
IR_API Int16Type int16_type();
IR_API BoolType bool_type();
IR_API Complex64Type complex64_type();
IR_API Complex128Type complex128_type();

Expand All @@ -136,9 +139,16 @@ class Builder {
Operation *Insert(Operation *op);

IrContext *context_;
Block *block_;
// The insertion point within the list that this builder is inserting before.
Block::Iterator insert_point_;

InsertPoint insert_point_;
};

template <typename OpTy, typename... Args>
OpTy Builder::Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = Build(std::move(argument));
return OpTy(op);
}

} // namespace pir
2 changes: 1 addition & 1 deletion paddle/pir/dialect/control_flow/ir/cf_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"

namespace pir {
void ControlFlowDialect::initialize() { RegisterOps<YieldOp>(); }
void ControlFlowDialect::initialize() { RegisterOps<YieldOp, CondYieldOp>(); }
} // namespace pir
IR_DEFINE_EXPLICIT_TYPE_ID(pir::ControlFlowDialect)
1 change: 1 addition & 0 deletions paddle/pir/dialect/control_flow/ir/cf_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ void YieldOp::Build(Builder &builder,
} // namespace pir

IR_DEFINE_EXPLICIT_TYPE_ID(pir::YieldOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::CondYieldOp)
27 changes: 26 additions & 1 deletion paddle/pir/dialect/control_flow/ir/cf_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

#pragma once

#include <functional>
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/op_base.h"

Expand All @@ -30,6 +30,31 @@ class IR_API YieldOp : public Op<YieldOp> {
const std::vector<Value> &Value);
void Verify() {}
};

class IR_API CondYieldOp : public Op<CondYieldOp> {
public:
using Op::Op;
static const char *name() { return "cf.cond_yield"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;

template <class ValueContainer>
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value cond,
const ValueContainer &inputs);
void Verify() {}
};

template <class ValueContainer>
void CondYieldOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value cond,
const ValueContainer &inputs) {
argument.AddInput(cond);
argument.AddInputs(inputs);
}
} // namespace pir

IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::YieldOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::CondYieldOp);
11 changes: 10 additions & 1 deletion test/cpp/pir/control_flow_dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
cc_test_old(
test_if_op
if_op_test
SRCS
if_op_test.cc
DEPS
pir
pd_op_dialect
gtest)

cc_test_old(
while_op_test
SRCS
while_op_test.cc
DEPS
pir
pd_op_dialect
gtest)
75 changes: 75 additions & 0 deletions test/cpp/pir/control_flow_dialect/while_op_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"

using namespace paddle::dialect; // NOLINT
TEST(while_op_test, base) {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<pir::ControlFlowDialect>();
ctx->GetOrRegisterDialect<OperatorDialect>();

pir::Program program(ctx);
pir::Block* block = program.block();
pir::Builder builder(ctx, block);

auto i =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
.out();

auto ten =
builder.Build<FullOp>(std::vector<int64_t>{1}, 10, phi::DataType::INT32)
.out();

auto while_op = builder.Build<WhileOp>(
std::vector<pir::Value>{i, ten},
std::vector<pir::Type>{builder.int32_type(), builder.int32_type()});

// while(i < ten)
pir::Block* cond_block = while_op.cond_block();
auto cond_i_argument = cond_block->AddArgument(i.type());
auto cond_ten_argument = cond_block->AddArgument(ten.type());
builder.SetInsertionPointToStart(cond_block);
auto cond_value =
builder.Build<LessThanOp>(cond_i_argument, cond_ten_argument).out();
builder.Build<pir::CondYieldOp>(
cond_value, std::vector<pir::Value>{cond_i_argument, cond_ten_argument});

// { i = i + 1}
pir::Block* body_block = while_op.body_block();
auto body_i_argument = body_block->AddArgument(i.type());
auto body_ten_argument = body_block->AddArgument(ten.type());
builder.SetInsertionPointToStart(body_block);
auto one =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
.out();
auto new_i = builder.Build<AddOp>(body_i_argument, one).out();
builder.Build<pir::YieldOp>(
std::vector<pir::Value>{new_i, body_ten_argument});

builder.SetInsertionPointAfter(while_op);
std::stringstream ss;
program.Print(ss);

LOG(INFO) << ss.str();
}

0 comments on commit 3eaa065

Please sign in to comment.