diff --git a/paddle/ir/core/builtin_attribute.cc b/paddle/ir/core/builtin_attribute.cc index 06c0f347530cf..d6c2b3f829daf 100644 --- a/paddle/ir/core/builtin_attribute.cc +++ b/paddle/ir/core/builtin_attribute.cc @@ -35,6 +35,8 @@ std::vector ArrayAttribute::data() const { void* PointerAttribute::data() const { return storage()->GetAsKey(); } +Type TypeAttribute::data() const { return storage()->GetAsKey(); } + } // namespace ir IR_DEFINE_EXPLICIT_TYPE_ID(ir::StrAttribute) @@ -45,3 +47,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Attribute) IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Attribute) IR_DEFINE_EXPLICIT_TYPE_ID(ir::ArrayAttribute) IR_DEFINE_EXPLICIT_TYPE_ID(ir::PointerAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(ir::TypeAttribute) diff --git a/paddle/ir/core/builtin_attribute.h b/paddle/ir/core/builtin_attribute.h index 472a5ae8c156e..8d8efbc4d79f7 100644 --- a/paddle/ir/core/builtin_attribute.h +++ b/paddle/ir/core/builtin_attribute.h @@ -103,6 +103,15 @@ class IR_API PointerAttribute : public Attribute { void* data() const; }; +class IR_API TypeAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage); + + Type data() const; +}; + } // namespace ir IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::StrAttribute) @@ -113,3 +122,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Attribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Attribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ArrayAttribute) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::PointerAttribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::TypeAttribute) diff --git a/paddle/ir/core/builtin_attribute_storage.h b/paddle/ir/core/builtin_attribute_storage.h index dce78a0563519..891a0691186f5 100644 --- a/paddle/ir/core/builtin_attribute_storage.h +++ b/paddle/ir/core/builtin_attribute_storage.h @@ -20,6 +20,7 @@ #include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute_base.h" +#include "paddle/ir/core/type.h" #include "paddle/ir/core/utils.h" namespace ir { @@ -131,4 +132,25 @@ struct ArrayAttributeStorage : public AttributeStorage { size_t length_ = 0; }; +struct TypeAttributeStorage : public AttributeStorage { + using ParamKey = Type; + + explicit TypeAttributeStorage(const ParamKey &key) : value_(key) {} + + static TypeAttributeStorage *Construct(ParamKey key) { + return new TypeAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey &key) { + return std::hash()(key); + } + + bool operator==(const ParamKey &key) const { return value_ == key; } + + ParamKey GetAsKey() const { return value_; } + + private: + Type value_; +}; + } // namespace ir diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/ir/core/builtin_dialect.cc index 2dc4438564b03..a5e9605c2835e 100644 --- a/paddle/ir/core/builtin_dialect.cc +++ b/paddle/ir/core/builtin_dialect.cc @@ -46,7 +46,8 @@ void BuiltinDialect::initialize() { PointerAttribute, Int32Attribute, Int64Attribute, - ArrayAttribute>(); + ArrayAttribute, + TypeAttribute>(); RegisterOpsPrintAttribute(v); }, [this]() { this->os << ","; }); os << "]"; + } else if (auto type = attr.dyn_cast()) { + os << type.data(); } else { auto& dialect = attr.dialect(); dialect.PrintAttribute(attr, os); diff --git a/paddle/ir/transforms/dce.cc b/paddle/ir/transforms/dce.cc index 31d8a1951fbdd..94613fc017a8b 100644 --- a/paddle/ir/transforms/dce.cc +++ b/paddle/ir/transforms/dce.cc @@ -22,13 +22,13 @@ namespace { // TODO(wilber): After support SideEffectTrait, Only NoSideEffectTrait op can be // removed by dce pass. // Now just a naive implementation. -class DCEPass : public ir::Pass { +class DcePass : public ir::Pass { public: - DCEPass() : ir::Pass("DCEPass", 0) {} + DcePass() : ir::Pass("DcePass", 0) {} void Run(ir::Operation *op) override { auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "DCEPass should run on module op."); + IR_ENFORCE(module_op, "DcePass should run on module op."); auto *block = module_op.block(); std::vector erased_op; for (auto it = block->begin(); it != block->end(); ++it) { @@ -39,6 +39,7 @@ class DCEPass : public ir::Pass { for (uint32_t i = 0; i < (*it)->num_results(); ++i) { use_empty &= (*it)->result(i).use_empty(); } + // TODO(wilber): Support Terminator trait. if (use_empty && (*it)->name() != "pd.fetch") { erased_op.push_back(**it); } @@ -56,6 +57,6 @@ class DCEPass : public ir::Pass { namespace ir { -std::unique_ptr CreateDCEPass() { return std::make_unique(); } +std::unique_ptr CreateDcePass() { return std::make_unique(); } } // namespace ir diff --git a/paddle/ir/transforms/dce.h b/paddle/ir/transforms/dce.h index 061fc04ceb9e2..6e51b1b5b1dbd 100644 --- a/paddle/ir/transforms/dce.h +++ b/paddle/ir/transforms/dce.h @@ -20,6 +20,6 @@ namespace ir { class Pass; -IR_API std::unique_ptr CreateDCEPass(); +IR_API std::unique_ptr CreateDcePass(); } // namespace ir diff --git a/test/cpp/ir/core/ir_attribute_test.cc b/test/cpp/ir/core/ir_attribute_test.cc index 5c53e58a8b90e..291b64a7233cb 100644 --- a/test/cpp/ir/core/ir_attribute_test.cc +++ b/test/cpp/ir/core/ir_attribute_test.cc @@ -19,6 +19,7 @@ #include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_dialect.h" +#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/ir_context.h" @@ -63,4 +64,10 @@ TEST(attribute_test, built_in_attribute) { string_attr_1.dyn_cast(); EXPECT_EQ(string_attr_cast_1.isa(), true); EXPECT_EQ(string_attr_cast_1.size() == 8, 1); + + ir::Int32Type i32_type = ir::Int32Type::get(ctx); + ir::Attribute type_attr = ir::TypeAttribute::get(ctx, i32_type); + EXPECT_TRUE(type_attr.isa()); + EXPECT_EQ(type_attr.dyn_cast().data().type_id(), + i32_type.type_id()); } diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index 8a8a73b093459..9b4e33348817c 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -429,7 +429,7 @@ TEST(pattern_rewrite, Patterns) { ir::PassManager pm(ctx); pm.AddPass(std::make_unique()); - pm.AddPass(ir::CreateDCEPass()); + pm.AddPass(ir::CreateDcePass()); program.Print(std::cout); std::cout << std::endl; pm.Run(&program);