Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] add while op. #57475

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,27 @@ void IfOp::Print(pir::IrPrinter &printer) {
os << "\n }";
}
void IfOp::Verify() {}

void WhileOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types) {
// auto insert_point = builder.insert_point();
argument.AddInputs(inputs);
argument.AddOutputs(output_types);
argument.AddRegion(nullptr);
argument.AddRegion(nullptr);
}
pir::Block *WhileOp::cond_block() {
pir::Region &cond_region = (*this)->region(0);
if (cond_region.empty()) cond_region.emplace_back();
return cond_region.front();
}
pir::Block *WhileOp::body_block() {
pir::Region &body_region = (*this)->region(1);
if (body_region.empty()) body_region.emplace_back();
return body_region.front();
}
} // namespace dialect
} // namespace paddle

Expand All @@ -1091,3 +1112,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp)
17 changes: 17 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,22 @@ class IfOp : public pir::Op<IfOp> {
void Verify();
};

class WhileOp : public pir::Op<WhileOp> {
public:
using Op::Op;
static const char *name() { return "pd.while"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;

static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types);
void Verify() {}
pir::Block *cond_block();
pir::Block *body_block();
};

} // namespace dialect
} // namespace paddle

Expand All @@ -203,3 +219,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp)
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ void OperatorDialect::initialize() {
paddle::dialect::FusedGemmEpilogueOp,
paddle::dialect::FusedGemmEpilogueGradOp,
paddle::dialect::SplitGradOp,
paddle::dialect::IfOp>();
paddle::dialect::IfOp,
paddle::dialect::WhileOp>();

RegisterInterfaces<ParameterConvertInterface>();
}
Expand Down
10 changes: 6 additions & 4 deletions paddle/pir/core/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ Operation *Builder::Build(const std::vector<Value> &inputs,
}

Operation *Builder::Insert(Operation *op) {
if (block_) {
block_->insert(insert_point_, op);
if (insert_point_.first) {
insert_point_.first->insert(insert_point_.second, op);
} else {
LOG(WARNING) << "Builder's Block is nullptr, insert failed.";
}
return op;
}

BoolType Builder::bool_type() { return BoolType::get(context_); }
UInt8Type Builder::uint8_type() { return UInt8Type::get(context_); }
Int8Type Builder::int8_type() { return Int8Type::get(context_); }
Int16Type Builder::int16_type() { return Int16Type::get(context_); }
Int32Type Builder::int32_type() { return Int32Type::get(context_); }
VectorType Builder::vec_type(const std::vector<Type> &value) {
return VectorType::get(context_, value);
}
Expand All @@ -50,8 +54,6 @@ Float32Type Builder::float32_type() { return Float32Type::get(context_); }

Float64Type Builder::float64_type() { return Float64Type::get(context_); }
IndexType Builder::index_type() { return IndexType::get(context_); }
Int16Type Builder::int16_type() { return Int16Type::get(context_); }
BoolType Builder::bool_type() { return BoolType::get(context_); }
Complex64Type Builder::complex64_type() { return Complex64Type::get(context_); }
Complex128Type Builder::complex128_type() {
return Complex128Type::get(context_);
Expand Down
60 changes: 35 additions & 25 deletions paddle/pir/core/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ namespace pir {
class Type;
class UInt8Type;
class Int8Type;
class Int16Type;
class Int32Type;
class VectorType;
class BFloat16Type;
class Float32Type;
class Float64Type;
class Int16Type;
class IndexType;
class BoolType;
class Complex64Type;
Expand All @@ -42,29 +43,38 @@ class Int64Attribute;
class ArrayAttribute;
class PointerAttribute;

using InsertPoint = std::pair<Block *, Block::Iterator>;
///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute
/// classes only derives interfaces, not members.
///
class Builder {
public:
Builder(IrContext *context, Block *block, Block::Iterator insert_point)
: context_(context) {
SetInsertionPoint(block, insert_point);
}
: context_(context), insert_point_(block, insert_point) {}

Builder(IrContext *context, Block *block)
: Builder(context, block, block->end()) {}

explicit Builder(IrContext *context)
: Builder(context, nullptr, Block::Iterator{}) {}

Builder(IrContext *context, const InsertPoint &insert_point)
: context_(context), insert_point_(insert_point) {}

void SetInsertionPoint(const InsertPoint &insert_point) {
insert_point_ = insert_point;
}

/// Set the insert point to the start of the specified block.
void SetInsertionPointToStart(Block *block) {
SetInsertionPoint(block, block->begin());
}

/// Set the insertion point to the specified location.
void SetInsertionPoint(Block *block, Block::Iterator insert_point) {
// TODO(liuyuanle): check that insertPoint is in this rather than some other
// block.
this->block_ = block;
this->insert_point_ = insert_point;
insert_point_.first = block;
insert_point_.second = insert_point;
}

/// Set the insertion point to the specified operation, which will cause
Expand All @@ -79,19 +89,16 @@ class Builder {
SetInsertionPoint(op->GetParent(), std::next(Block::Iterator{*op}));
}

/// Set the insertion point to the start of the specified block.
void SetInsertionPointToStart(Block *block) {
SetInsertionPoint(block, block->begin());
}

/// Set the insertion point to the end of the specified block.
void SetInsertionPointToEnd(Block *block) {
SetInsertionPoint(block, block->end());
}

IrContext *ir_context() const { return context_; }

Block *block() const { return block_; }
Block *block() const { return insert_point_.first; }

const InsertPoint &insert_point() const { return insert_point_; }

/// Creates an operation given the fields represented as an OperationState.
IR_API Operation *Build(OperationArgument &&argument);
Expand All @@ -104,22 +111,18 @@ class Builder {

/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = Build(std::move(argument));
return OpTy(op);
}
OpTy Build(Args &&...args);

IR_API BoolType bool_type();
IR_API UInt8Type uint8_type();
IR_API Int8Type int8_type();
IR_API Int16Type int16_type();
IR_API Int32Type int32_type();
IR_API VectorType vec_type(const std::vector<Type> &);
IR_API BFloat16Type bfloat16_type();
IR_API IndexType index_type();
IR_API Float32Type float32_type();
IR_API Float64Type float64_type();
IR_API Int16Type int16_type();
IR_API BoolType bool_type();
IR_API Complex64Type complex64_type();
IR_API Complex128Type complex128_type();

Expand All @@ -136,9 +139,16 @@ class Builder {
Operation *Insert(Operation *op);

IrContext *context_;
Block *block_;
// The insertion point within the list that this builder is inserting before.
Block::Iterator insert_point_;

InsertPoint insert_point_;
};

template <typename OpTy, typename... Args>
OpTy Builder::Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = Build(std::move(argument));
return OpTy(op);
}

} // namespace pir
2 changes: 1 addition & 1 deletion paddle/pir/dialect/control_flow/ir/cf_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"

namespace pir {
void ControlFlowDialect::initialize() { RegisterOps<YieldOp>(); }
void ControlFlowDialect::initialize() { RegisterOps<YieldOp, CondYieldOp>(); }
} // namespace pir
IR_DEFINE_EXPLICIT_TYPE_ID(pir::ControlFlowDialect)
1 change: 1 addition & 0 deletions paddle/pir/dialect/control_flow/ir/cf_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ void YieldOp::Build(Builder &builder,
} // namespace pir

IR_DEFINE_EXPLICIT_TYPE_ID(pir::YieldOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::CondYieldOp)
27 changes: 26 additions & 1 deletion paddle/pir/dialect/control_flow/ir/cf_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

#pragma once

#include <functional>
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/op_base.h"

Expand All @@ -30,6 +30,31 @@ class IR_API YieldOp : public Op<YieldOp> {
const std::vector<Value> &Value);
void Verify() {}
};

class IR_API CondYieldOp : public Op<CondYieldOp> {
public:
using Op::Op;
static const char *name() { return "cf.cond_yield"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;

template <class ValueContainer>
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value cond,
const ValueContainer &inputs);
void Verify() {}
};

template <class ValueContainer>
void CondYieldOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value cond,
const ValueContainer &inputs) {
argument.AddInput(cond);
argument.AddInputs(inputs);
}
} // namespace pir

IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::YieldOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::CondYieldOp);
11 changes: 10 additions & 1 deletion test/cpp/pir/control_flow_dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
cc_test_old(
test_if_op
if_op_test
SRCS
if_op_test.cc
DEPS
pir
pd_op_dialect
gtest)

cc_test_old(
while_op_test
SRCS
while_op_test.cc
DEPS
pir
pd_op_dialect
gtest)
75 changes: 75 additions & 0 deletions test/cpp/pir/control_flow_dialect/while_op_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <iostream>

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/dialect/control_flow/ir/cf_dialect.h"
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"

using namespace paddle::dialect; // NOLINT
TEST(while_op_test, base) {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<pir::ControlFlowDialect>();
ctx->GetOrRegisterDialect<OperatorDialect>();

pir::Program program(ctx);
pir::Block* block = program.block();
pir::Builder builder(ctx, block);

auto i =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
.out();

auto ten =
builder.Build<FullOp>(std::vector<int64_t>{1}, 10, phi::DataType::INT32)
.out();

auto while_op = builder.Build<WhileOp>(
std::vector<pir::Value>{i, ten},
std::vector<pir::Type>{builder.int32_type(), builder.int32_type()});

// while(i < ten)
pir::Block* cond_block = while_op.cond_block();
auto cond_i_argument = cond_block->AddArgument(i.type());
auto cond_ten_argument = cond_block->AddArgument(ten.type());
builder.SetInsertionPointToStart(cond_block);
auto cond_value =
builder.Build<LessThanOp>(cond_i_argument, cond_ten_argument).out();
builder.Build<pir::CondYieldOp>(
cond_value, std::vector<pir::Value>{cond_i_argument, cond_ten_argument});

// { i = i + 1}
pir::Block* body_block = while_op.body_block();
auto body_i_argument = body_block->AddArgument(i.type());
auto body_ten_argument = body_block->AddArgument(ten.type());
builder.SetInsertionPointToStart(body_block);
auto one =
builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32)
.out();
auto new_i = builder.Build<AddOp>(body_i_argument, one).out();
builder.Build<pir::YieldOp>(
std::vector<pir::Value>{new_i, body_ten_argument});

builder.SetInsertionPointAfter(while_op);
std::stringstream ss;
program.Print(ss);

LOG(INFO) << ss.str();
}