From 1c72398ccec43b4333c275123516cd9e96e933cb Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 2 Feb 2019 16:16:53 -0800 Subject: [PATCH 1/5] [EXPR] Expression-template based pattern matching. --- src/arithmetic/pattern_match.h | 664 ++++++++++++++++++++++++++++++++ src/pass/inject_copy_intrin.cc | 40 +- tests/cpp/pattern_match_test.cc | 112 ++++++ 3 files changed, 788 insertions(+), 28 deletions(-) create mode 100644 src/arithmetic/pattern_match.h create mode 100644 tests/cpp/pattern_match_test.cc diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h new file mode 100644 index 000000000000..e8c1260ba8bd --- /dev/null +++ b/src/arithmetic/pattern_match.h @@ -0,0 +1,664 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/arithmetic/pattern_match.h + * + * \brief Internal tool for expression-template based pattern matching. + * + * It helps to simplify pattern matching and rewrites. + * All the patterns are generated via expression template during compile time, + * so the result code should be as efficient as manually write pattern match code + * using as cast and checks. + * + * The code below gives shows how to use the pattern matcher. + * + * \code + * + * // max(x + z, y + z) => max(x, y) + z + * arith::PVar x, y, z; + * + * // The following code tries to match the declared pattern. + * // Match will fill the result of match into PVar if successful. + * // Note that z occurs twice in the pattern, + * // an equality check to ensure each occurance of z is equivalent + * // to each other. + * if (max(x + z, y + z).Match(expr)) { + * // Eval evaluates a pattern with the current matched value. + * return (max(x, y) + z).Eval(); + * } + * \endcode + */ +#ifndef TVM_ARITHMETIC_PATTERN_MATCH_H_ +#define TVM_ARITHMETIC_PATTERN_MATCH_H_ + +#include +#include + +namespace tvm { +namespace arith { +/*! + * \brief Base cass of all the patterns. + * + * There are two major member functions supported by each pattern. + * - Match: checks if value matches the pattern. + * - Eval: construct a new value based on matched values in PVar. + * + * We use curiously recurring template pattern. + * \tparam SubType The type if the child class. + */ +template +class Pattern { + public: + /*! + * \brief Check if value matches the current pattern. + * + * This call also populates the PVars with matched value. + * The values in PVars are valid until the next call to Match. + * + * \return whether value matches the pattern. + */ + template + bool Match(const NodeType& value) const { + self().InitMatch_(); + return self().Match_(value); + } + /*! \return subtype instance of current class. */ + const SubType& self() const { + return *static_cast(this); + } +}; + +/*! + * \brief Default deep equality checker + * \tparam T the comparison point. + */ +template +class PEqualChecker { + public: + bool operator()(const T& lhs, const T& rhs) const { + return lhs == rhs; + } +}; + +template<> +class PEqualChecker { + public: + bool operator()(const Expr& lhs, const Expr& rhs) const { + if (lhs.same_as(rhs)) return true; + return ir::Equal(lhs, rhs); + } +}; + +/*! + * \brief Pattern variable container. + * + * PVar is used as a "hole" in the pattern that can be matched. + * + * \tparam T the type of the hole. + */ +template +class PVar : public Pattern > { + public: + void InitMatch_() const { + filled_ = false; + } + + bool Match_(const T& value) const { + if (!filled_) { + value_ = value; + filled_ = true; + return true; + } else { + return PEqualChecker()(value_, value); + } + } + + T Eval() const { + CHECK(filled_); + return value_; + } + + private: + /*! \brief The matched value */ + mutable T value_; + /*! \brief whether the variable has been filled */ + mutable bool filled_{false}; +}; + +/*! + * \brief Constant Pattern variable container. + * + * \tparam T the type of the hole. + */ +template +class PConst : public Pattern > { + public: + PConst(T value) // NOLINT(*) + : value_(value) {} + + void InitMatch_() const {} + + bool Match_(const T& value) const { + return PEqualChecker()(value_, value); + } + + T Eval() const { + return value_; + } + private: + const T value_; +}; + +/*! + * \brief Pattern binary expression. + * \tparam NodeType The AST node type. + * \tparam TA The pattern type of the first operand. + * \tparam TB The pattern type of the second operand. + */ +template +class PBinaryExpr : + public Pattern > { + public: + PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {} + + void InitMatch_() const { + a_.InitMatch_(); + b_.InitMatch_(); + } + + bool Match_(const NodeRef& node) const { + if (const NodeType* ptr = node.as()) { + if (!a_.Match_(ptr->a)) return false; + if (!b_.Match_(ptr->b)) return false; + return true; + } else { + return false; + } + } + + Expr Eval() const { + return NodeType::make(a_.Eval(), b_.Eval()); + } + + private: + const TA& a_; + const TB& b_; +}; + + +#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \ + template \ + inline PBinaryExpr \ + FuncName(const Pattern& a, const Pattern& b) { \ + return PBinaryExpr(a.self(), b.self()); \ + } + +// arithmetic expressions +TVM_PATTERN_BINARY_OP(operator+, ir::Add); +TVM_PATTERN_BINARY_OP(operator-, ir::Sub); +TVM_PATTERN_BINARY_OP(operator*, ir::Mul); +TVM_PATTERN_BINARY_OP(operator/, ir::Div); +TVM_PATTERN_BINARY_OP(operator%, ir::Mod); +TVM_PATTERN_BINARY_OP(min, ir::Min); +TVM_PATTERN_BINARY_OP(max, ir::Max); + +// logical expressions +TVM_PATTERN_BINARY_OP(operator>, ir::GT); +TVM_PATTERN_BINARY_OP(operator>=, ir::GE); +TVM_PATTERN_BINARY_OP(operator<, ir::LT); +TVM_PATTERN_BINARY_OP(operator<=, ir::LE); +TVM_PATTERN_BINARY_OP(operator==, ir::EQ); +TVM_PATTERN_BINARY_OP(operator!=, ir::NE); +TVM_PATTERN_BINARY_OP(operator&&, ir::And); +TVM_PATTERN_BINARY_OP(operator||, ir::Or); + +/*! + * \brief Pattern not expression. + * \tparam TCond The pattern type of the condition. + * \tparam TA The pattern type of the true operand. + * \tparam TB The pattern type of the false operand. + */ +template +class PNotExpr : public Pattern > { + public: + explicit PNotExpr(const TA& value) + : value_(value) {} + + void InitMatch_() const { + value_.InitMatch_(); + } + + bool Match_(const NodeRef& node) const { + if (const ir::Not* ptr = node.as()) { + if (!value_.Match_(ptr->a)) return false; + return true; + } else { + return false; + } + } + + Expr Eval() const { + return ir::Not::make(value_.Eval()); + } + + private: + const TA& value_; +}; + +template +inline PNotExpr operator!(const TA& value) { + return PNotExpr(value.self()); +} + +// select +/*! + * \brief Pattern select expression. + * \tparam TCond The pattern type of the condition. + * \tparam TA The pattern type of the true operand. + * \tparam TB The pattern type of the false operand. + */ +template +class PSelectExpr : + public Pattern > { + public: + PSelectExpr(const TCond& condition, + const TA& true_value, + const TB& false_value) + : condition_(condition), + true_value_(true_value), + false_value_(false_value) {} + + void InitMatch_() const { + condition_.InitMatch_(); + true_value_.InitMatch_(); + false_value_.InitMatch_(); + } + + bool Match_(const NodeRef& node) const { + if (const ir::Select* ptr = node.as()) { + if (!condition_.Match_(ptr->condition)) return false; + if (!true_value_.Match_(ptr->true_value)) return false; + if (!false_value_.Match_(ptr->false_value)) return false; + return true; + } else { + return false; + } + } + + Expr Eval() const { + return ir::Select::make( + condition_.Eval(), true_value_.Eval(), false_value_.Eval()); + } + + private: + const TCond& condition_; + const TA& true_value_; + const TB& false_value_; +}; + +/*! + * \brief Construct a select pattern. + * + * \param condition The condition expression. + * \param true_value The value when condition is true. + * \param true_value The value when condition is false. + * + * \return The result pattern. + * + * \tparam TCond The pattern type of the condition. + * \tparam TA The pattern type of the true operand. + * \tparam TB The pattern type of the false operand. + */ +template +inline PSelectExpr +select(const Pattern& condition, + const Pattern& true_value, + const Pattern& false_value) { + return PSelectExpr( + condition.self(), true_value.self(), false_value.self()); +} + +/*! + * \brief Pattern cast expression. + * \tparam DType The Pattern type of dtype. + * \tparam TA The pattern type of the first operand. + */ +template +class PCastExpr : + public Pattern > { + public: + PCastExpr(const DType& dtype, const TA& value) + : dtype_(dtype), value_(value) { + } + + void InitMatch_() const { + dtype_.InitMatch_(); + value_.InitMatch_(); + } + + bool Match_(const NodeRef& node) const { + if (const ir::Cast* ptr = node.as()) { + if (!dtype_.Match_(ptr->type)) return false; + if (!value_.Match_(ptr->value)) return false; + return true; + } else { + return false; + } + } + + Expr Eval() const { + return ir::Cast::make(dtype_.Eval(), value_.Eval()); + } + + private: + const DType& dtype_; + const TA& value_; +}; + +/*! + * \brief Construct a cast pattern. + * + * \param dtype The target data type, can be PVar or PConst. + * \param value The input type. + * + * \return The result pattern. + * + * \tparam DType The pattern type of type. + * \tparam TA The pattern type of the true operand. + */ +template +inline PCastExpr +cast(const Pattern& dtype, const Pattern& value) { + return PCastExpr(dtype.self(), value.self()); +} + +/*! + * \brief Pattern ramp expression. + * \tparam TBase The pattern type of the base. + * \tparam TStride The pattern type of the stride. + * \tparam TLanes The pattern type of the lanes. + */ +template +class PRampExpr : + public Pattern > { + public: + PRampExpr(const TBase& base, + const TStride& stride, + const TLanes& lanes) + : base_(base), stride_(stride), lanes_(lanes) { + } + + void InitMatch_() const { + base_.InitMatch_(); + stride_.InitMatch_(); + lanes_.InitMatch_(); + } + + bool Match_(const NodeRef& node) const { + if (const ir::Ramp* ptr = node.as()) { + if (!base_.Match_(ptr->base)) return false; + if (!stride_.Match_(ptr->stride)) return false; + if (!lanes_.Match_(ptr->lanes)) return false; + return true; + } else { + return false; + } + } + + Expr Eval() const { + return ir::Ramp::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); + } + + private: + const TBase& base_; + const TStride& stride_; + const TLanes& lanes_; +}; + +/*! + * \brief Construct a ramp pattern. + * + * \param base The base pattern. + * \param stride The stride pattern. + * \param lanes The lanes pattern. + * + * \return The result pattern. + * + * \tparam TBase The pattern type of the base. + * \tparam TStride The pattern type of the stride. + * \tparam TLanes The pattern type of the lanes. + */ +template +inline PRampExpr +ramp(const Pattern& base, + const Pattern& stride, + const Pattern& lanes) { + return PRampExpr( + base.self(), stride.self(), lanes.self()); +} + +/*! + * \brief Pattern broadcast expression. + * \tparam TA The pattern type of the value. + * \tparam TLanes The pattern type of the lanes. + */ +template +class PBroadcastExpr : + public Pattern > { + public: + PBroadcastExpr(const TA& value, + const TLanes& lanes) + : value_(value), lanes_(lanes) { + } + + void InitMatch_() const { + value_.InitMatch_(); + lanes_.InitMatch_(); + } + + bool Match_(const NodeRef& node) const { + if (const ir::Broadcast* ptr = node.as()) { + if (!value_.Match_(ptr->value)) return false; + if (!lanes_.Match_(ptr->lanes)) return false; + return true; + } else { + return false; + } + } + + Expr Eval() const { + return ir::Broadcast::make(value_.Eval(), lanes_.Eval()); + } + + private: + const TA& value_; + const TLanes& lanes_; +}; + +/*! + * \brief Construct a broadcast pattern. + * + * \param value The value pattern. + * \param lanes The lanes pattern. + * + * \return The result pattern. + * + * \tparam TA The pattern type of the value. + * \tparam TLanes The pattern type of the lanes. + */ +template +inline PBroadcastExpr +broadcast(const Pattern& value, const Pattern& lanes) { + return PBroadcastExpr(value.self(), lanes.self()); +} + +// internal namespace +namespace detail { +// implementation details for CallExpr +template +struct tuple_for_each_dispatcher { + template + static void run(F& f, const TTuple& tuple) { // NOLINT(*) + f(I, std::get(tuple)); + tuple_for_each_dispatcher< + (I + 1) == std::tuple_size::value, (I + 1), F> + ::run(f, tuple); + } +}; + +template +struct tuple_for_each_dispatcher { + template + static void run(F& f, const TTuple& tuple) {} // NOLINT(*) +}; + +template +inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*) + tuple_for_each_dispatcher::value == 0, 0, F> + ::run(f, tuple); +} + +struct PCallExprInitMatchFunctor { + template + void operator()(size_t i, const T& pattern) const { + pattern.InitMatch_(); + } +}; + +struct PCallExprMatchFunctor { + const ir::Call* call_; + bool matched_{true}; + + explicit PCallExprMatchFunctor(const ir::Call* call) + : call_(call) {} + + template + void operator()(size_t i, const T& pattern) { + matched_ = matched_ && pattern.Match_(call_->args[i]); + } +}; + +struct PCallExprEvalArgsFunctor { + Array args_; + + template + void operator()(size_t i, const T& pattern) { + args_.push_back(pattern.Eval()); + } +}; +} // namespace detail + +/*! + * \brief Pattern CallExpr expression. + * \tparam Op The operator functor class. + * \tparam TArgs The arguments. + * \note Op functor contains the name of the function and + * the implementation of Eval. + */ +template +class PCallExpr : + public Pattern > { + public: + explicit PCallExpr(const TArgs&... args) + : args_(args...) { + } + + void InitMatch_() const { + detail::PCallExprInitMatchFunctor finit; + detail::tuple_for_each(finit, args_); + } + + bool Match_(const NodeRef& node) const { + if (const ir::Call* ptr = node.as()) { + if (ptr->args.size() != sizeof...(TArgs)) return false; + if (ptr->name != Op::kName) return false; + detail::PCallExprMatchFunctor fmatch(ptr); + detail::tuple_for_each(fmatch, args_); + return fmatch.matched_; + } else { + return false; + } + } + + Expr Eval() const { + detail::PCallExprEvalArgsFunctor feval_args; + detail::tuple_for_each(feval_args, args_); + return Op::Eval(feval_args.args_); + } + + private: + const std::tuple args_; +}; + +// arithemetic intrinsics +#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static Expr Eval(Array args) { \ + return ir::Call::make(args[0].type(), kName, args, \ + ir::Call::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr \ + FuncName(const Pattern& a, const Pattern& b) { \ + return PCallExpr(a.self(), b.self()); \ + } + +TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left"); +TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right"); +TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and"); +TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or"); +TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); + +// unary intrinsics +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static Expr Eval(Array args) { \ + return ir::Call::make(args[0].type(), kName, args, \ + ir::Call::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr \ + FuncName(const Pattern& a) { \ + return PCallExpr(a.self()); \ + } + +TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); + +// if_then_else +struct PIfThenElseOp { + static Expr Eval(Array args) { + return ir::Call::make( + args[1].type(), kName, args, + ir::Call::PureIntrinsic); + } + static constexpr const char* kName = "tvm_if_then_else"; +}; + +/*! + * \brief Construct a if_then_else pattern. + * + * \param cond The condition expression. + * \param true_value The value when condition is true. + * \param true_value The value when condition is false. + * + * \return The result pattern. + * + * \tparam TCond The pattern type of the condition. + * \tparam TA The pattern type of the true operand. + * \tparam TB The pattern type of the false operand. + */ +template +inline PCallExpr +if_then_else(const Pattern& cond, + const Pattern& true_value, + const Pattern& false_value) { + return PCallExpr( + cond.self(), true_value.self(), false_value.self()); +} + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITHMETIC_PATTERN_MATCH_H_ diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index 65942fbee92c..7ca1d133bd2d 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -7,6 +7,7 @@ #include #include #include +#include "../arithmetic/pattern_match.h" namespace tvm { namespace ir { @@ -35,27 +36,8 @@ class CopyIntrinInjector : public IRMutator { } private: - bool MatchCondition(Expr expr, - Expr* cond, - Expr* true_value, - Expr* false_value) { - if (const auto* op = expr.as