Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Support TypeAttribute. #54984

Merged
merged 10 commits into from
Jul 2, 2023
3 changes: 3 additions & 0 deletions paddle/ir/core/builtin_attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ std::vector<Attribute> ArrayAttribute::data() const {

void* PointerAttribute::data() const { return storage()->GetAsKey(); }

Type TypeAttribute::data() const { return storage()->GetAsKey(); }

} // namespace ir

IR_DEFINE_EXPLICIT_TYPE_ID(ir::StrAttribute)
Expand All @@ -45,3 +47,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Attribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Attribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::ArrayAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::PointerAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(ir::TypeAttribute)
10 changes: 10 additions & 0 deletions paddle/ir/core/builtin_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ class IR_API PointerAttribute : public Attribute {
void* data() const;
};

class IR_API TypeAttribute : public Attribute {
public:
using Attribute::Attribute;

DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage);

Type data() const;
};

} // namespace ir

IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::StrAttribute)
Expand All @@ -113,3 +122,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Attribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Attribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ArrayAttribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::PointerAttribute)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::TypeAttribute)
22 changes: 22 additions & 0 deletions paddle/ir/core/builtin_attribute_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/utils.h"

namespace ir {
Expand Down Expand Up @@ -131,4 +132,25 @@ struct ArrayAttributeStorage : public AttributeStorage {
size_t length_ = 0;
};

struct TypeAttributeStorage : public AttributeStorage {
using ParamKey = Type;

explicit TypeAttributeStorage(const ParamKey &key) : value_(key) {}

static TypeAttributeStorage *Construct(ParamKey key) {
return new TypeAttributeStorage(key);
}

static std::size_t HashValue(const ParamKey &key) {
return std::hash<Type>()(key);
}

bool operator==(const ParamKey &key) const { return value_ == key; }

ParamKey GetAsKey() const { return value_; }

private:
Type value_;
};

} // namespace ir
3 changes: 2 additions & 1 deletion paddle/ir/core/builtin_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ void BuiltinDialect::initialize() {
PointerAttribute,
Int32Attribute,
Int64Attribute,
ArrayAttribute>();
ArrayAttribute,
TypeAttribute>();

RegisterOps<ModuleOp,
GetParameterOp,
Expand Down
2 changes: 2 additions & 0 deletions paddle/ir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ void BasicIrPrinter::PrintAttribute(const Attribute& attr) {
[this](Attribute v) { this->PrintAttribute(v); },
[this]() { this->os << ","; });
os << "]";
} else if (auto type = attr.dyn_cast<TypeAttribute>()) {
os << type.GetValue();
} else {
auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os);
Expand Down
9 changes: 5 additions & 4 deletions paddle/ir/transforms/dce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ namespace {
// TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be
// removed by dce pass.
// Now just a naive implementation.
class DCEPass : public ir::Pass {
class DcePass : public ir::Pass {
public:
DCEPass() : ir::Pass("DCEPass", 0) {}
DcePass() : ir::Pass("DcePass", 0) {}

void Run(ir::Operation *op) override {
auto module_op = op->dyn_cast<ir::ModuleOp>();
IR_ENFORCE(module_op, "DCEPass should run on module op.");
IR_ENFORCE(module_op, "DcePass should run on module op.");
auto *block = module_op.block();
std::vector<ir::Operation> erased_op;
for (auto it = block->begin(); it != block->end(); ++it) {
Expand All @@ -39,6 +39,7 @@ class DCEPass : public ir::Pass {
for (uint32_t i = 0; i < (*it)->num_results(); ++i) {
use_empty &= (*it)->result(i).use_empty();
}
// TODO(wilber): Support Terminator trait.
if (use_empty && (*it)->name() != "pd.fetch") {
erased_op.push_back(**it);
}
Expand All @@ -56,6 +57,6 @@ class DCEPass : public ir::Pass {

namespace ir {

std::unique_ptr<Pass> CreateDCEPass() { return std::make_unique<DCEPass>(); }
std::unique_ptr<Pass> CreateDcePass() { return std::make_unique<DcePass>(); }

} // namespace ir
2 changes: 1 addition & 1 deletion paddle/ir/transforms/dce.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
namespace ir {
class Pass;

IR_API std::unique_ptr<Pass> CreateDCEPass();
IR_API std::unique_ptr<Pass> CreateDcePass();

} // namespace ir
7 changes: 7 additions & 0 deletions test/cpp/ir/core/ir_attribute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"

Expand Down Expand Up @@ -63,4 +64,10 @@ TEST(attribute_test, built_in_attribute) {
string_attr_1.dyn_cast<ir::StrAttribute>();
EXPECT_EQ(string_attr_cast_1.isa<ir::StrAttribute>(), true);
EXPECT_EQ(string_attr_cast_1.size() == 8, 1);

ir::Int32Type i32_type = ir::Int32Type::get(ctx);
ir::Attribute type_attr = ir::TypeAttribute::get(ctx, i32_type);
EXPECT_TRUE(type_attr.isa<ir::TypeAttribute>());
EXPECT_EQ(type_attr.dyn_cast<ir::TypeAttribute>().GetValue().type_id(),
i32_type.type_id());
}
2 changes: 1 addition & 1 deletion test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ TEST(PatternRewrite, GreedyPatternRewriteDriver) {

ir::PassManager pm(ctx);
pm.AddPass(std::make_unique<TestPass>());
pm.AddPass(ir::CreateDCEPass());
pm.AddPass(ir::CreateDcePass());
std::stringstream o1, o2;
program.Print(o1);
LOG(INFO) << o1.str();
Expand Down