diff --git a/paddle/ir/CMakeLists.txt b/paddle/ir/CMakeLists.txt index 03db95e07332a7..c5524ee38754b1 100644 --- a/paddle/ir/CMakeLists.txt +++ b/paddle/ir/CMakeLists.txt @@ -4,3 +4,4 @@ endif() add_subdirectory(core) add_subdirectory(pass) +add_subdirectory(pattern_rewrite) diff --git a/paddle/ir/pattern_rewrite/CMakeLists.txt b/paddle/ir/pattern_rewrite/CMakeLists.txt new file mode 100644 index 00000000000000..9d5dba05eebc58 --- /dev/null +++ b/paddle/ir/pattern_rewrite/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB PATTERN_SRCS "*.cc") + +cc_library( + pattern_rewrite + SRCS ${PATTERN_SRCS} + DEPS new_ir) diff --git a/paddle/ir/pattern_rewrite/pattern_match.cc b/paddle/ir/pattern_rewrite/pattern_match.cc new file mode 100644 index 00000000000000..728e7eb708d0dc --- /dev/null +++ b/paddle/ir/pattern_rewrite/pattern_match.cc @@ -0,0 +1,144 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/pattern_rewrite/pattern_match.h" +#include +#include +#include "paddle/ir/core/operation.h" + +namespace ir { + +//===----------------------------------------------------------------------===// +// Pattern +//===----------------------------------------------------------------------===// + +// Pattern::Pattern(const void* root_val, +// RootKind root_kind, +// const std::vector& generated_names, +// PatternBenefit benefit, +// ir::IrContext* context) +// : benefit_(benefit), context_(context), generated_names_(generated_names) +// {} + +Pattern::Pattern(const std::string& root_name, + PatternBenefit benefit, + IrContext* context, + const std::vector& generated_names) + : op_name_(root_name), + root_kind_(RootKind::OperationName), + benefit_(benefit), + context_(context), + generated_names_(generated_names) {} + +Pattern::Pattern(MatchAnyOpTypeTag tag, + PatternBenefit benefit, + ir::IrContext* context, + const std::vector& generated_names) + : root_kind_(RootKind::Any), + benefit_(benefit), + context_(context), + generated_names_(generated_names) {} + +Pattern::Pattern(MatchInterfaceOpTypeTag tag, + ir::TypeId interface_id, + PatternBenefit benefit, + ir::IrContext* context, + const std::vector& generated_names) + : interface_id_(interface_id), + root_kind_(RootKind::InterfaceId), + benefit_(benefit), + context_(context), + generated_names_(generated_names) {} + +Pattern::Pattern(MatchTraitOpTypeTag tag, + ir::TypeId trait_id, + PatternBenefit benefit, + ir::IrContext* context, + const std::vector& generated_names) + : trait_id_(trait_id), + root_kind_(RootKind::TraitId), + benefit_(benefit), + context_(context), + generated_names_(generated_names) {} + +RewritePattern::~RewritePattern() = default; + +//===----------------------------------------------------------------------===// +// RewriterBase +//===----------------------------------------------------------------------===// + +RewriterBase::~RewriterBase() = default; + +// TODO(wilber): value support replace method. +// void RewriterBase::ReplaceOpWithIf(Operation* op, +// ValueRange new_values, +// bool* all_uses_replaced, +// std::function functor) { +// // assert(op->num_results() == new_values.size() && "incorrect number of +// values to replace operation"); NotifyRootReplaced(op, new_values); bool +// replace_all_uses = true; for (uint32_t i = 0; i < op->num_results(); ++i) { +// // op->GetResultByIndex(0) +// } +// } +// void RewriterBase::ReplaceOpWithIf(Operation* op, +// ValueRange new_values, +// std::function functor) { +// ReplaceOpWithIf(op, new_values, nullptr, functor); +// } + +// TODO(wilber): support erase. +// void ReplaceOp(Operation* op, ValueRange new_values) { +// NotifyRootReplaced(op, new_values); +// assert(op->num_results() == new_values.size() && "incorrect # of +// replacement values"); op->ReplaceAllUsesWith(new_values); +// NotifyOperationRemoved(op); +// op->erase(); +// } +void RewriterBase::EraseOp(Operation* op) { + // assert(op->use_empty() && "expected 'op' to have no uses"); + // NotifyOperationRemoved(op); + // op->erase(); +} + +void RewriterBase::ReplaceAllUsesWith(Value from, Value to) { + // from. + // for (mlir::OpOperand& operand : llvm::make_early_inc_range(from.getUses())) + // { + // mlir::Operation* op = operand.getOwner(); + // UpdateRootInPlace(op, [&]() { operand.set(to); }); + // } +} + +// TODO(wilber): iterator maybe should support modify inplace. +void RewriterBase::ReplaceUseIf(Value from, + Value to, + std::function functor) { + // for (auto it = from.begin(); it != from.end(); ++it) { + // // TODO: need a lvalue. + // if (functor(it.get())) { + // UpdateRootInplace(it.owner(), [&](){it.get().set(to)}); + // } + // } +} + +void RewriterBase::ReplaceOpWithResultsOfAnotherOp(Operation* op, + Operation* new_op) { + assert(op->num_results() == new_op->num_results() && + "replacement op doesn't match results of original op"); + // TODO(wilber): Op support results method. + // if (op->num_results() == 1) return ReplaceOp(op, + // new_op->GetResultByIndex(0)); return ReplaceOp(op, new_op->GetResults()); +} + +} // namespace ir diff --git a/paddle/ir/pattern_rewrite/pattern_match.h b/paddle/ir/pattern_rewrite/pattern_match.h new file mode 100644 index 00000000000000..0017afea612363 --- /dev/null +++ b/paddle/ir/pattern_rewrite/pattern_match.h @@ -0,0 +1,356 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/operation.h" +#include "paddle/ir/core/type_id.h" +#include "paddle/ir/core/type_name.h" +#include "paddle/ir/core/value.h" +namespace ir { + +/// The design is mainly from MLIR, very thanks to the greate project. + +/// This class reprensents the benefit of a pattern. The most common +/// unit to use is the `numver of operations` in the pattern. +class PatternBenefit { + public: + PatternBenefit(unsigned val) : val_(val) {} // NOLINT + + unsigned benefit() { return val_; } + + bool operator==(const PatternBenefit& rhs) const { return val_ == rhs.val_; } + bool operator!=(const PatternBenefit& rhs) const { return !(*this == rhs); } + bool operator<(const PatternBenefit& rhs) const { return val_ < rhs.val_; } + bool operator>(const PatternBenefit& rhs) const { return rhs < *this; } + bool operator<=(const PatternBenefit& rhs) const { return !(*this > rhs); } + bool operator>=(const PatternBenefit& rhs) const { return !(*this <= rhs); } + + private: + unsigned int val_{0}; +}; + +/// This class contains all of the data related to a Pattern, but not contains +/// any methods for the matching. This class is used to interface with the +/// metadata of a pattern, such as benefit or root operation. +class Pattern { + enum class RootKind { Any, OperationName, InterfaceId, TraitId }; + + public: + PatternBenefit benefit() const { return benefit_; } + + IrContext* context() const { return context_; } + + std::string debug_name() const { return debug_name_; } + + void SetDebugName(const std::string& name) { debug_name_ = name; } + + const std::vector& debug_labels() const { return debug_labels_; } + + void AddDebugLabels(const std::vector& labels) { + debug_labels_.insert(debug_labels_.end(), labels.begin(), labels.end()); + } + + void AddDebugLabels(const std::string& label) { + debug_labels_.push_back(label); + } + + protected: + struct MatchAnyOpTypeTag {}; + struct MatchInterfaceOpTypeTag {}; + struct MatchTraitOpTypeTag {}; + + Pattern(const std::string& root_name, + PatternBenefit benefit, + ir::IrContext* context, + const std::vector& generated_names = {}); + + Pattern(MatchAnyOpTypeTag tag, + PatternBenefit benefit, + ir::IrContext* context, + const std::vector& generated_names = {}); + + Pattern(MatchInterfaceOpTypeTag tag, + ir::TypeId interface_id, + PatternBenefit benefit, + ir::IrContext* context, + const std::vector& generated_names = {}); + + Pattern(MatchTraitOpTypeTag tag, + ir::TypeId trait_id, + PatternBenefit benefit, + ir::IrContext* context, + const std::vector& generated_names = {}); + + private: + // TODO(wilber): How to uniform variables and constructor. + // Pattern(const void* root_val, + // RootKind root_kind, + // const std::vector& generated_names, + // PatternBenefit benefit, + // ir::IrContext* context); + std::string op_name_; + ir::TypeId interface_id_; + ir::TypeId trait_id_; + RootKind root_kind_; + + const PatternBenefit benefit_; + ir::IrContext* context_; + std::vector generated_names_; + + std::string debug_name_; + std::vector debug_labels_; +}; + +class PatternRewriter; + +class RewritePattern : public Pattern { + public: + virtual ~RewritePattern(); + + virtual void Rewrite(ir::Operation* op, + PatternRewriter& rewriter) const { // NOLINT + throw( + "need to implement either MatchAndRewrite or one of the rewrite " + "functions."); + } + + virtual bool Match(ir::Operation* op) const { + throw("need to implement either MatchAndRewrite or Match."); + return false; + } + + virtual bool MatchAndRewrite(ir::Operation* op, + PatternRewriter& rewriter) const { // NOLINT + if (Match(op)) { + Rewrite(op, rewriter); + return true; + } + return false; + } + + virtual void Initialize() {} + + template + static std::unique_ptr Create(Args&&... args) { + std::unique_ptr pattern = + std::make_unique(std::forward(args)...); + pattern->Initialize(); + + if (pattern->debug_name().empty()) + pattern->SetDebugName(get_type_name()); + return pattern; + } + + protected: + using Pattern::Pattern; +}; + +namespace detail { +/// A wrapper around PatternWrite that allows for matching and rewriting +/// against an instance of a derived operation class or Interface. +template +struct OpOrInterfaceRewritePatternBase : public RewritePattern { + using RewritePattern::RewritePattern; + + void Rewrite(Operation* op, + PatternRewriter& rewriter) const final { // NOLINT + Rewrite(op->dyn_cast(), rewriter); + } + + bool Match(Operation* op) const final { + return Match(op->dyn_cast()); + } + bool MatchAndRewrite(Operation* op, + PatternRewriter& rewriter) const final { // NOLINT + return MatchAndRewrite(op->dyn_cast(), rewriter); + } + + virtual void Rewrite(SourceOp op, + PatternRewriter& rewriter) const { // NOLINT + throw("must override Rewrite or MatchAndRewrite"); + } + virtual bool Match(SourceOp op) const { + throw("must override Match or MatchAndRewrite"); + } + virtual bool MatchAndRewrite(SourceOp op, + PatternRewriter& rewriter) const { // NOLINT + if (Match(op)) { + Rewrite(op, rewriter); + return true; + } + return false; + } +}; +} // namespace detail + +/// OpRewritePattern is a wrapper around RewritePattern that allows for +/// matching and rewriting against an instance of a derived operation +/// class as opposed to a raw Operation. +template +struct OpRewritePattern + : public detail::OpOrInterfaceRewritePatternBase { + OpRewritePattern(ir::IrContext* context, + PatternBenefit benefit = 1, + const std::vector& generated_names = {}) + : detail::OpOrInterfaceRewritePatternBase( + "NeedToFix", // TODO(wilber): Need to fix. SourceOp maybe should + // have a getOperationName static method. + benefit, + context, + generated_names) {} +}; + +// TODO(wilber): Support OpInterfaceRewritePattern and OpTraitRewritePattern. +// ... + +/// This class provides a series of interfaces for modifying IR and tracking IR +/// changes. This class provides a unified API for IR modification. +/// +class RewriterBase { // maybe should inherit OpBuilder. + public: + // TODO(wilber): Supplementary methods of block and region. + + // TODO(wilber): Support ValueRange. + // virtual void ReplaceOpWithIf(Operation* op, + // ValueRange new_values, + // bool* all_uses_replaced, + // std::function functor); + // void ReplaceOpWithIf(Operation* op, + // ValueRange new_values, + // std::function functor); + // virtual void ReplaceOp(Operation* op, ValueRange new_values); + + // virtual void ReplaceOpWithNewOp() + + virtual void EraseOp(Operation* op); + + virtual void StartRootUpdate(Operation* op) {} + virtual void FinalizeRootUpdate(Operation* op) {} + virtual void CancleRootUpdate(Operation* op) {} + + template + void UpdateRootInplace(Operation* root, CallableT&& callable) { + StartRootUpdate(root); + callable(); + FinalizeRootUpdate(root); + } + + void ReplaceAllUsesWith(Value from, Value to); + + void ReplaceUseIf(Value from, + Value to, + std::function functor); + + protected: + explicit RewriterBase(IrContext* ctx) : ctx_(ctx) {} + + virtual ~RewriterBase(); + + // virtual void NotifyRootReplaced(Operation* op, ValueRange replacement) {} + + virtual void NotifyOperationRemoved(Operation* op) {} + + // virtual bool NotifyMatchFailure() + + private: + void operator=(const RewriterBase&) = delete; + RewriterBase(const RewriterBase&) = delete; + + void ReplaceOpWithResultsOfAnotherOp(Operation* op, Operation* new_op); + + private: + IrContext* ctx_; +}; + +class PatternRewriter : public RewriterBase { + public: + using RewriterBase::RewriterBase; +}; + +/// A pattern collection, easy to add patterns. +class RewritePatternSet { + using NativePatternListT = std::vector>; + + public: + explicit RewritePatternSet(IrContext* context) : context_(context) {} + + RewritePatternSet(IrContext* context, std::unique_ptr pattern) + : context_(context) { + native_patterns_.emplace_back(std::move(pattern)); + } + + IrContext* context() const { return context_; } + + NativePatternListT& native_patterns() { return native_patterns_; } + + void Clear() { native_patterns_.clear(); } + + // 'add' methods for adding patterns to the set. + template > + RewritePatternSet& Add(ConstructorArg&& arg, ConstructorArgs&&... args) { + std::initializer_list{ + (AddImpl({}, + std::forward(arg), + std::forward(args)...), + 0)...}; + return *this; + } + + template > + RewritePatternSet& AddWithLabel(const std::vector& debug_labels, + ConstructorArg&& arg, + ConstructorArgs&&... args) { + std::initializer_list{ + (AddImpl(debug_labels, + std::forward(arg), + std::forward(args)...), + 0)...}; + return *this; + } + + RewritePatternSet& Add(std::unique_ptr pattern) { + native_patterns_.emplace_back(std::move(pattern)); + return *this; + } + + private: + template + std::enable_if_t::value> AddImpl( + const std::vector& debug_labels, Args&&... args) { + std::unique_ptr pattern = + RewritePattern::Create(std::forward(args)...); + pattern->AddDebugLabels(debug_labels); + native_patterns_.emplace_back(std::move(pattern)); + } + + private: + IrContext* const context_; + NativePatternListT native_patterns_; +}; +} // namespace ir diff --git a/test/cpp/ir/CMakeLists.txt b/test/cpp/ir/CMakeLists.txt index 03db95e07332a7..c5524ee38754b1 100644 --- a/test/cpp/ir/CMakeLists.txt +++ b/test/cpp/ir/CMakeLists.txt @@ -4,3 +4,4 @@ endif() add_subdirectory(core) add_subdirectory(pass) +add_subdirectory(pattern_rewrite) diff --git a/test/cpp/ir/pattern_rewrite/CMakeLists.txt b/test/cpp/ir/pattern_rewrite/CMakeLists.txt new file mode 100644 index 00000000000000..67cd3c8c0fd809 --- /dev/null +++ b/test/cpp/ir/pattern_rewrite/CMakeLists.txt @@ -0,0 +1,10 @@ +cc_test_old( + pattern_rewrite_test + SRCS + pattern_rewrite_test.cc + DEPS + new_pass + pattern_rewrite + pd_dialect + phi + gtest) diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc new file mode 100644 index 00000000000000..de8f25809acc4d --- /dev/null +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/builtin_dialect.h" +#include "paddle/ir/core/dialect.h" +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/pattern_rewrite/pattern_match.h" + +TEST(PatternBenefit, PatternBenefit) { + ir::PatternBenefit benefit1(1); + EXPECT_EQ(benefit1.benefit(), 1U); + ir::PatternBenefit benefit2(2); + EXPECT_EQ(benefit2.benefit(), 2U); + + EXPECT_TRUE(benefit2 > benefit1); + EXPECT_TRUE(benefit2 >= benefit1); + EXPECT_TRUE(benefit1 < benefit2); + EXPECT_TRUE(benefit1 <= benefit2); + EXPECT_TRUE(benefit1 != benefit2); + ir::PatternBenefit benefit3(2); + EXPECT_TRUE(benefit2 == benefit3); +} + +// Define op1. +class Operation1 : public ir::Op { + public: + using Op::Op; + static const char *name() { return "test.Operation1"; } + static constexpr uint32_t attributes_num = 2; + static const char *attributes_name[attributes_num]; + static void Verify(const std::vector &inputs, + const std::vector &outputs, + const ir::AttributeMap &attributes) { + if (attributes.count("op2_attr1") == 0 || + (!attributes.at("op2_attr1").isa())) { + throw("Type of attribute: parameter_name is not right."); + } + if (attributes.count("op2_attr2") == 0 || + (!attributes.at("op2_attr2").isa())) { + throw("Type of attribute: parameter_name is not right."); + } + } + static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } +}; +const char *Operation1::attributes_name[attributes_num] = {"op2_attr1", + "op2_attr2"}; + +// Define a dialect, op1 and op2 will be registered by this dialect. +class TestDialect : public ir::Dialect { + public: + explicit TestDialect(ir::IrContext *context) + : ir::Dialect(name(), context, ir::TypeId::get()) { + initialize(); + } + static const char *name() { return "test"; } + + private: + void initialize() { RegisterOps(); } +}; + +// TODO(wilber): Add logical when ir support erase, replace or update. +class TestPatternRewrite : public ir::OpRewritePattern { + public: + using ir::OpRewritePattern::OpRewritePattern; + + void Rewrite(Operation1 op, ir::PatternRewriter &rewriter) const override {} + bool Match(Operation1 op) const override { return false; } +}; +class TestPatternRewrite2 : public ir::OpRewritePattern { + public: + using ir::OpRewritePattern::OpRewritePattern; + bool MatchAndRewrite( + Operation1 op, + ir::PatternRewriter &rewriter) const override { // NOLINT + return false; + } +}; + +TEST(RewritePattern, OpRewritePattern) { + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + auto *test_dialect = ctx->GetOrRegisterDialect(); + test_dialect->RegisterOp(); + + ir::RewritePatternSet ps(ctx); + ps.Add(ctx, 1); + EXPECT_EQ(ps.native_patterns().size(), 1U); + EXPECT_TRUE(ps.native_patterns().back()->debug_labels().empty()); + EXPECT_EQ(ps.native_patterns().back()->benefit(), 1U); + ps.AddWithLabel({"TestPatternRewrite2"}, ctx, 2); + EXPECT_EQ(ps.native_patterns().size(), 2U); + EXPECT_EQ(ps.native_patterns().back()->debug_labels()[0], + "TestPatternRewrite2"); + EXPECT_EQ(ps.native_patterns().back()->benefit(), 2U); + + ps.Clear(); + ps.Add(ctx, 2); + EXPECT_EQ(ps.native_patterns().size(), 2U); + EXPECT_EQ(ps.native_patterns()[0]->benefit(), 2U); + EXPECT_EQ(ps.native_patterns()[1]->benefit(), 2U); +}