Skip to content

Commit

Permalink
[PIR] support verify in trait or interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Sep 20, 2023
1 parent 3e8da40 commit debe468
Show file tree
Hide file tree
Showing 15 changed files with 295 additions and 204 deletions.
14 changes: 4 additions & 10 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ Operation *BuildOpFrom(
std::back_inserter(to_create_argument.inputs),
[&value_map](const pir::OpOperand &operand) {
// Operand -> OpResult
return OpResult::dyn_cast_from(value_map[operand.source()]);
return value_map[operand.source()];
});
auto *cloned_op = Operation::Create(std::move(to_create_argument));

Expand Down Expand Up @@ -780,11 +780,8 @@ SplitedResult ForwardBackwardSplit(
pir::StrAttribute::get(
ctx, std::string("output_") + std::to_string(counter))},
};
pir::Operation *operation =
pir::Operation::Create({OpResult::dyn_cast_from(forward_value_map[v])},
attribute_map,
{},
op_info);
pir::Operation *operation = pir::Operation::Create(
{forward_value_map[v]}, attribute_map, {}, op_info);
forward_program->block()->push_back(operation);
counter += 1;
};
Expand All @@ -803,10 +800,7 @@ SplitedResult ForwardBackwardSplit(
ctx, std::string("output_") + std::to_string(counter))},
};
pir::Operation *operation = pir::Operation::Create(
{OpResult::dyn_cast_from(backward_value_map.at(v))},
attribute_map,
{},
op_info);
{backward_value_map.at(v)}, attribute_map, {}, op_info);
backward_program->block()->push_back(operation);
counter += 1;
};
Expand Down
8 changes: 7 additions & 1 deletion paddle/pir/core/op_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ class OpInterfaceBase : public OpBase {
return ConcreteInterface(nullptr, nullptr);
}
};
template <class T>
void VerifyTraitOrInterface(Operation *) {}
template <class T>
decltype(T::Verify(nullptr)) VerifyTraitOrInterface(Operation *op) {
return T::Verify(op);
}

template <typename ConcreteOp, class... TraitOrInterface>
class Op : public OpBase {
Expand Down Expand Up @@ -139,12 +145,12 @@ class Op : public OpBase {
class EmptyOp : public Op<EmptyOp, TraitOrInterface...> {};
return sizeof(ConcreteOp) == sizeof(EmptyOp);
}

// Implementation of `VerifyInvariantsFn` OperationName hook.
static void VerifyInvariants(Operation *op) {
static_assert(HasNoDataMembers(),
"Op class shouldn't define new data members");
op->dyn_cast<ConcreteOp>().Verify();
(VerifyTraitOrInterface<TraitOrInterface>(op), ...);
}
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/op_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ class IR_API OpResult : public Value {
Operation *owner() const;
uint32_t index() const;
bool operator==(const OpResult &other) const;
static OpResult dyn_cast_from(Value value);

private:
friend Operation;
OpResult(detail::OpResultImpl *impl); // NOLINT
// Access classof annd dyn_cast_from.
friend Value;
static bool classof(Value value);
static OpResult dyn_cast_from(Value value);
};

} // namespace pir
6 changes: 6 additions & 0 deletions paddle/pir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ struct OperationArgument {
/// Add an array of named attributes.
template <class InputIt>
void AddAttributes(InputIt first, InputIt last);

template <class AttrContainer>
void AddAttributes(const AttrContainer& attr_container) {
AddAttributes(std::begin(attr_container), std::end(attr_container));
}

/// Get the context held by this operation state.
IrContext* getContext() const { return info.ir_context(); }

Expand Down
9 changes: 8 additions & 1 deletion test/cpp/pir/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@ cc_test_old(
pd_op_dialect)
cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS pir gtest)
cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS pir gtest)
cc_test_old(ir_op_test SRCS ir_op_test.cc DEPS pir gtest)
cc_test_old(
ir_op_test
SRCS
ir_op_test.cc
DEPS
pir
gtest
test_dialect)
cc_test_old(ir_region_test SRCS ir_region_test.cc DEPS pir gtest)
cc_test_old(ir_builder_test SRCS ir_builder_test.cc DEPS pir gtest)
cc_test_old(
Expand Down
203 changes: 26 additions & 177 deletions test/cpp/pir/core/ir_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,49 +27,8 @@
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/core/region.h"

/// \brief Define built-in Trait, derived from OpTraitBase.
class ReadOnlyTrait : public pir::OpTraitBase<ReadOnlyTrait> {
public:
explicit ReadOnlyTrait(pir::Operation *op)
: pir::OpTraitBase<ReadOnlyTrait>(op) {}
};
IR_DECLARE_EXPLICIT_TYPE_ID(ReadOnlyTrait)
IR_DEFINE_EXPLICIT_TYPE_ID(ReadOnlyTrait)

/// \brief Define built-in Interface, derived from OpInterfaceBase. Concepts and
/// Models need to be defined within the class. Concept defines abstract
/// interface functions, and Model is a template class that defines the specific
/// implementation of interface functions based on template parameters.
class InferShapeInterface : public pir::OpInterfaceBase<InferShapeInterface> {
public:
struct Concept {
explicit Concept(void (*infer_shape)(pir::Operation *))
: infer_shape_(infer_shape) {}
void (*infer_shape_)(pir::Operation *);
};

template <class ConcreteOp>
struct Model : public Concept {
static void InferShape(pir::Operation *op) {
ConcreteOp concret_op = ConcreteOp(op);
if (concret_op == nullptr) throw("concret_op is nullptr");
concret_op.InferShape();
}

Model() : Concept(InferShape) {}
};

InferShapeInterface(pir::Operation *op, Concept *impl)
: pir::OpInterfaceBase<InferShapeInterface>(op), impl_(impl) {}

void InferShape() { impl_->infer_shape_(operation()); }

private:
Concept *impl_;
};
IR_DECLARE_EXPLICIT_TYPE_ID(InferShapeInterface)
IR_DEFINE_EXPLICIT_TYPE_ID(InferShapeInterface)
#include "test/cpp/pir/tools/test_dialect.h"
#include "test/cpp/pir/tools/test_op.h"

pir::AttributeMap CreateAttributeMap(
const std::vector<std::string> &attribute_names,
Expand All @@ -84,139 +43,15 @@ pir::AttributeMap CreateAttributeMap(
return attr_map;
}

// Define op1.
class Operation1 : public pir::Op<Operation1> {
public:
using Op::Op;
static const char *name() { return "test.operation1"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; // NOLINT
void Verify() {
auto &attributes = this->attributes();
if (attributes.count("op1_attr1") == 0 ||
!attributes.at("op1_attr1").isa<pir::StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
}
if (attributes.count("op1_attr2") == 0 ||
!attributes.at("op1_attr2").isa<pir::StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
}
}
static void Build(const pir::Builder &builder,
pir::OperationArgument &argument) { // NOLINT
std::vector<pir::Type> output_types = {
pir::Float32Type::get(builder.ir_context())};
std::unordered_map<std::string, pir::Attribute> attributes =
CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"});
argument.AddOutputs(output_types.begin(), output_types.end());
argument.AddAttributes(attributes.begin(), attributes.end());
}
};
const char *Operation1::attributes_name[attributes_num] = { // NOLINT
"op1_attr1",
"op1_attr2"};

IR_DECLARE_EXPLICIT_TYPE_ID(Operation1)
IR_DEFINE_EXPLICIT_TYPE_ID(Operation1)

// Define op2.
class Operation2
: public pir::Op<Operation2, ReadOnlyTrait, InferShapeInterface> {
public:
using Op::Op;
static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; // NOLINT
void Verify() {
auto &attributes = this->attributes();
if (attributes.count("op2_attr1") == 0 ||
(!attributes.at("op2_attr1").isa<pir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
}
if (attributes.count("op2_attr2") == 0 ||
(!attributes.at("op2_attr2").isa<pir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
}
}
static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; }
};
const char *Operation2::attributes_name[attributes_num] = { // NOLINT
"op2_attr1",
"op2_attr2"};
IR_DECLARE_EXPLICIT_TYPE_ID(Operation2)
IR_DEFINE_EXPLICIT_TYPE_ID(Operation2)

// Define a dialect, op1 and op2 will be registered by this dialect.
class TestDialect : public pir::Dialect {
public:
explicit TestDialect(pir::IrContext *context)
: pir::Dialect(name(), context, pir::TypeId::get<TestDialect>()) {
initialize();
}
static const char *name() { return "test"; }

void PrintOperation(pir::Operation *op,
pir::IrPrinter &printer) const override {
printer.PrintOpResult(op);
printer.os << " =";

printer.os << " \"" << op->name() << "\"";
printer.PrintOpOperands(op);
}

private:
void initialize() { RegisterOps<Operation1, Operation2>(); }
};
IR_DECLARE_EXPLICIT_TYPE_ID(TestDialect)
IR_DEFINE_EXPLICIT_TYPE_ID(TestDialect)

TEST(op_test, op_test) {
// (1) Register Dialect, Operation1, Operation2 into IrContext.
pir::IrContext *ctx = pir::IrContext::Instance();
pir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
EXPECT_EQ(test_dialect != nullptr, true);

// (2) Get registered operations.
std::string op1_name = Operation1::name();
pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
EXPECT_TRUE(op1_info);
std::string op2_name = Operation2::name();
pir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
EXPECT_TRUE(op2_info);
EXPECT_EQ(op1_info.HasTrait<ReadOnlyTrait>(), false);
EXPECT_EQ(op1_info.HasInterface<InferShapeInterface>(), false);
EXPECT_EQ(op2_info.HasTrait<ReadOnlyTrait>(), true);
EXPECT_EQ(op2_info.HasInterface<InferShapeInterface>(), true);

// (3) Test uses for op.
std::vector<pir::Value> op_inputs = {};
std::vector<pir::Type> op_output_types = {pir::Float32Type::get(ctx)};
pir::Operation *op2 =
pir::Operation::Create(op_inputs,
CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"}),
op_output_types,
op2_info);

ReadOnlyTrait trait = op2->dyn_cast<ReadOnlyTrait>();
EXPECT_EQ(trait.operation(), op2);
InferShapeInterface interface = op2->dyn_cast<InferShapeInterface>();
interface.InferShape();
Operation2 Op2 = op2->dyn_cast<Operation2>();
EXPECT_EQ(Op2.operation(), op2);
op2->Destroy();
}

TEST(op_test, region_test) {
// (1) Register Dialect, Operation1, Operation2 into IrContext.
pir::IrContext *ctx = pir::IrContext::Instance();
pir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
pir::Dialect *test_dialect = ctx->GetOrRegisterDialect<test::TestDialect>();
EXPECT_EQ(test_dialect != nullptr, true);

// (2) Get registered operations.
pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(Operation1::name());
pir::OpInfo op2_info = ctx->GetRegisteredOpInfo(Operation2::name());
pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(test::Operation1::name());
pir::OpInfo op2_info = ctx->GetRegisteredOpInfo(test::Operation2::name());

pir::Operation *op1 =
pir::Operation::Create({},
Expand All @@ -225,15 +60,9 @@ TEST(op_test, region_test) {
{pir::Float32Type::get(ctx)},
op1_info);
pir::Operation *op1_2 =
pir::Operation::Create({},
CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"}),
{pir::Float32Type::get(ctx)},
op1_info);
pir::Operation::Create({}, {}, {pir::Float32Type::get(ctx)}, op1_info);

pir::OperationArgument argument(op2_info);
argument.attributes = CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"});
argument.output_types = {pir::Float32Type::get(ctx)};
argument.num_regions = 1;

Expand Down Expand Up @@ -279,3 +108,23 @@ TEST(op_test, module_op_death) {
program.module_op()->set_attribute("program",
pir::PointerAttribute::get(ctx, &program));
}

TEST(op_test, trait_and_interface) {
pir::IrContext ctx;
ctx.GetOrRegisterDialect<test::TestDialect>();
pir::Program program(&ctx);
auto block = program.block();
pir::Builder builder(&ctx, block);
auto op1 = builder.Build<test::Operation1>();
auto op2 = builder.Build<test::Operation2>();

EXPECT_EQ(op1->HasTrait<test::ReadOnlyTrait>(), false);
EXPECT_EQ(op1->HasInterface<test::InferShapeInterface>(), false);
EXPECT_EQ(op2->HasTrait<test::ReadOnlyTrait>(), true);
EXPECT_EQ(op2->HasInterface<test::InferShapeInterface>(), true);

pir::OperationArgument argument(&ctx, "test.region");
argument.num_regions = 2u;
auto p3 = builder.Build(argument);
EXPECT_THROW(builder.Build(argument), pir::IrNotMetException);
}
2 changes: 1 addition & 1 deletion test/cpp/pir/tools/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cc_library(
test_dialect
SRCS test_dialect.cc test_op.cc
SRCS test_dialect.cc test_op.cc test_trait.cc test_interface.cc
DEPS pir)
19 changes: 18 additions & 1 deletion test/cpp/pir/tools/test_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test/cpp/pir/tools/test_dialect.h"
#include "paddle/pir/core/ir_printer.h"
#include "test/cpp/pir/tools/test_op.h"
namespace test {
void TestDialect::initialize() { RegisterOps<RegionOp, BranchOp>(); }

TestDialect::TestDialect(pir::IrContext *context)
: pir::Dialect(name(), context, pir::TypeId::get<TestDialect>()) {
initialize();
}
void TestDialect::initialize() {
RegisterOps<RegionOp, BranchOp, Operation1, Operation2>();
}

void TestDialect::PrintOperation(pir::Operation *op,
pir::IrPrinter &printer) const {
printer.PrintOpResult(op);
printer.os << " =";

printer.os << " \"" << op->name() << "\"";
printer.PrintOpOperands(op);
}
} // namespace test
IR_DEFINE_EXPLICIT_TYPE_ID(test::TestDialect)
7 changes: 3 additions & 4 deletions test/cpp/pir/tools/test_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
namespace test {
class TestDialect : public pir::Dialect {
public:
explicit TestDialect(pir::IrContext *context)
: pir::Dialect(name(), context, pir::TypeId::get<TestDialect>()) {
initialize();
}
explicit TestDialect(pir::IrContext *context);
static const char *name() { return "test"; }
void PrintOperation(pir::Operation *op,
pir::IrPrinter &printer) const override;

private:
void initialize();
Expand Down
Loading

0 comments on commit debe468

Please sign in to comment.