From 68e64ee8db4dff9a786bf5a3bec8156a81d67644 Mon Sep 17 00:00:00 2001 From: kfeng123 <446100240@qq.com> Date: Tue, 18 Jul 2023 22:37:27 +0800 Subject: [PATCH] fix bug of dataflow pattern print --- include/tvm/relay/dataflow_pattern.h | 35 +++ src/relay/ir/dataflow_pattern.cc | 304 +++++++++++++++++++++------ 2 files changed, 275 insertions(+), 64 deletions(-) diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 46abee5d444f..5493e96b0d31 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -537,6 +537,41 @@ DFPattern IsTuple(const Array& fields); /*! \brief Syntatic Sugar for creating a TupleGetItemPattern*/ DFPattern IsTupleGetItem(const DFPattern tuple, int index = -1); +/*! \brief A printer class to print pattern. */ +class DFPatternPrinter : public ReprPrinter { + public: + std::stringstream string_stream{}; + + std::unordered_map, ObjectPtrHash, + ObjectPtrEqual> memo_{}; + std::vector recursed_patterns{}; + + DFPatternPrinter(std::ostream& stream) // NOLINT(*) + : ReprPrinter(stream) {} + TVM_DLL void Print(const ObjectRef& node); + using FType = NodeFunctor; + TVM_DLL static FType& vtable(); +}; + +inline std::ostream& operator<<(std::ostream& os, + const DFPattern& n) { // NOLINT(*) + std::stringstream string_stream{}, tmp_stream{}; + DFPatternPrinter printer{tmp_stream}; + printer.Print(n); + string_stream << "Main pattern is:" << std::endl; + string_stream << printer.string_stream.str(); + string_stream << std::endl; + string_stream << "Auxiliary patterns are:"; + for (const DFPattern& pat : printer.recursed_patterns) { + string_stream << std::endl; + string_stream << printer.memo_[pat].second; + } + os << string_stream.str(); + return os; +} + +String PrettyPrint(const DFPattern& pattern); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_DATAFLOW_PATTERN_H_ diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 1f5dba6aca80..4e9f78f63402 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -27,6 +27,39 @@ namespace tvm { namespace relay { +DFPatternPrinter::FType& DFPatternPrinter::vtable() { + static FType inst; + return inst; +} + +String PrettyPrint(const DFPattern& pattern) { + std::stringstream string_stream{}; + string_stream << pattern; + return string_stream.str(); +} + +void DFPatternPrinter::Print(const ObjectRef& node) { + ICHECK(node.as()); + DFPattern pat = Downcast(node); + static const FType& f = vtable(); + string_stream.str(""); + if (!node.defined()) { + string_stream << "(nullptr)"; + } else if (memo_.find(pat) != memo_.end()) { + string_stream << "(invoke pattern id " << memo_[pat].first << ")"; + recursed_patterns.push_back(pat); + } else { + if (f.can_dispatch(node)) { + memo_.insert({pat, {memo_.size(), ""}}); + f(node, this); + memo_[pat].second = string_stream.str(); + } else { + // default value, output type key and addr. + string_stream << node->GetTypeKey() << "(" << node.get() << ")"; + } + } +} + ExprPattern::ExprPattern(Expr expr) { ObjectPtr n = make_object(); n->expr = std::move(expr); @@ -39,10 +72,11 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ExprPattern").set_body_typed([](Expr return ExprPattern(e); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->Print(node->expr); +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + ExprPattern pattern = Downcast(ref); + p->string_stream.str(""); + p->string_stream << pattern->expr; }); VarPattern::VarPattern(String name_hint) { @@ -57,10 +91,11 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern").set_body_typed([](Strin return VarPattern(name_hint); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "VarPattern(" << node->name_hint() << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + VarPattern pattern = Downcast(ref); + p->string_stream.str(""); + p->string_stream << "VarPattern(" << pattern->name_hint() << ")"; }); TVM_REGISTER_NODE_TYPE(ConstantPatternNode); @@ -70,9 +105,10 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ConstantPattern").set_body_typed([]( return c; }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "ConstantPattern()"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + p->string_stream.str(""); + p->string_stream << "ConstantPattern()"; }); CallPattern::CallPattern(DFPattern op, Array args) { @@ -87,10 +123,29 @@ TVM_REGISTER_NODE_TYPE(CallPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern") .set_body_typed([](DFPattern op, Array args) { return CallPattern(op, args); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "CallPatternNode(" << node->op << ", " << node->args << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + CallPattern pattern = Downcast(ref); + + p->Print(pattern->op); + std::string op_pattern_string{p->string_stream.str()}; + + std::vector args_pattern_string{}; + for (const DFPattern& arg : pattern->args) { + p->Print(arg); + args_pattern_string.push_back(p->string_stream.str()); + } + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "CallPatternNode(" << op_pattern_string << ", ["; + for (size_t i = 0; i < args_pattern_string.size(); ++i) { + if (i != 0) { + p->string_stream << ", "; + } + p->string_stream << args_pattern_string[i]; + } + p->string_stream << "])"; }); FunctionPattern::FunctionPattern(Array params, DFPattern body) { @@ -106,10 +161,31 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.FunctionPattern") return FunctionPattern(params, body); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + FunctionPattern pattern = Downcast(ref); + + std::vector params_pattern_string{}; + for (const DFPattern& param : pattern->params) { + p->Print(param); + params_pattern_string.push_back(p->string_stream.str()); + } + + p->Print(pattern->body); + std::string body_pattern_string{p->string_stream.str()}; + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + + p->string_stream << "FunctionPatternNode(["; + for (size_t i = 0; i < params_pattern_string.size(); ++i) { + if (i != 0) { + p->string_stream << ", "; + } + p->string_stream << params_pattern_string[i]; + } + p->string_stream << "]"; + p->string_stream << ", " << body_pattern_string << ")"; }); LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) { @@ -127,10 +203,22 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern") return LetPattern(var, value, body); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "LetPatternNode(" << node->var << ", " << node->value << ", " << node->body +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + LetPattern pattern = Downcast(ref); + + p->Print(pattern->var); + std::string var_pattern_string{p->string_stream.str()}; + + p->Print(pattern->value); + std::string value_pattern_string{p->string_stream.str()}; + + p->Print(pattern->body); + std::string body_pattern_string{p->string_stream.str()}; + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "LetPatternNode(" << var_pattern_string << ", " << value_pattern_string << ", " << body_pattern_string << ")"; }); @@ -149,11 +237,23 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.IfPattern") return IfPattern(cond, true_branch, false_branch); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IfPattern(" << node->cond << ", " << node->true_branch << ", " - << node->false_branch << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + IfPattern pattern = Downcast(ref); + + p->Print(pattern->cond); + std::string cond_pattern_string{p->string_stream.str()}; + + p->Print(pattern->true_branch); + std::string true_branch_pattern_string{p->string_stream.str()}; + + p->Print(pattern->false_branch); + std::string false_branch_pattern_string{p->string_stream.str()}; + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "IfPattern(" << cond_pattern_string << ", " << true_branch_pattern_string << ", " + << false_branch_pattern_string << ")"; }); TuplePattern::TuplePattern(tvm::Array fields) { @@ -167,10 +267,28 @@ TVM_REGISTER_NODE_TYPE(TuplePatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TuplePattern") .set_body_typed([](tvm::Array fields) { return TuplePattern(fields); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TuplePattern(" << node->fields << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + TuplePattern pattern = Downcast(ref); + + std::vector fields_pattern_string{}; + for (const DFPattern& field : pattern->fields) { + p->Print(field); + fields_pattern_string.push_back(p->string_stream.str()); + } + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "TuplePattern("; + p->string_stream << "["; + for (size_t i = 0; i < fields_pattern_string.size(); ++i) { + if (i != 0) { + p->string_stream << ", "; + } + p->string_stream << fields_pattern_string[i]; + } + p->string_stream << "]"; + p->string_stream << ")"; }); TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { @@ -185,10 +303,18 @@ TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TupleGetItemPattern") .set_body_typed([](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleGetItemPatternNode(" << node->tuple << ", " << node->index << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + TupleGetItemPattern pattern = Downcast(ref); + + p->Print(pattern->tuple); + std::string tuple_pattern_string{p->string_stream.str()}; + + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "TupleGetItemPatternNode("; + p->string_stream << tuple_pattern_string << ", " << pattern->index << ")"; }); AltPattern::AltPattern(DFPattern left, DFPattern right) { @@ -203,10 +329,20 @@ TVM_REGISTER_NODE_TYPE(AltPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AltPattern") .set_body_typed([](DFPattern left, DFPattern right) { return AltPattern(left, right); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "AltPattern(" << node->left << " | " << node->right << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + AltPattern pattern = Downcast(ref); + + p->Print(pattern->left); + std::string left_pattern_string{p->string_stream.str()}; + + p->Print(pattern->right); + std::string right_pattern_string{p->string_stream.str()}; + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "AltPattern(" << left_pattern_string << " | " + << right_pattern_string << ")"; }); TVM_REGISTER_NODE_TYPE(WildcardPatternNode); @@ -216,9 +352,10 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]( return w; }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "*"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + p->string_stream.str(""); + p->string_stream << "*"; }); TypePattern::TypePattern(DFPattern pattern, Type type) { @@ -233,10 +370,17 @@ TVM_REGISTER_NODE_TYPE(TypePatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TypePattern") .set_body_typed([](DFPattern pattern, Type type) { return TypePattern(pattern, type); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + TypePattern pattern = Downcast(ref); + + p->Print(pattern->pattern); + std::string pattern_pattern_string{p->string_stream.str()}; + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "TypePattern(" << pattern_pattern_string + << " has type " << pattern->type << ")"; }); ShapePattern::ShapePattern(DFPattern pattern, Array shape) { @@ -253,10 +397,15 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern") return ShapePattern(pattern, shape); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + ShapePattern pattern = Downcast(ref); + + p->Print(pattern->pattern); + std::string pattern_pattern_string{p->string_stream.str()}; + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; }); DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { @@ -273,10 +422,17 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern") return DataTypePattern(pattern, dtype); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + DataTypePattern pattern = Downcast(ref); + + p->Print(pattern->pattern); + std::string pattern_pattern_string{p->string_stream.str()}; + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "DataTypePattern(" << pattern_pattern_string + << " has dtype " << pattern->dtype << ")"; }); AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) { @@ -291,10 +447,17 @@ TVM_REGISTER_NODE_TYPE(AttrPatternNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern") .set_body_typed([](DFPattern pattern, DictAttrs attrs) { return AttrPattern(pattern, attrs); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + AttrPattern pattern = Downcast(ref); + + p->Print(pattern->pattern); + std::string pattern_pattern_string{p->string_stream.str()}; + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "AttrPattern(" << pattern_pattern_string + << " has attributes " << pattern->attrs << ")"; }); DominatorPattern::DominatorPattern(DFPattern parent, DFPattern path, DFPattern child) { @@ -313,11 +476,24 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DominatorPattern") return DominatorPattern(parent, path, child); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "DominatorPattern(" << node->parent << ", " << node->path << ", " << node->child - << ")"; +TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, DFPatternPrinter* p) { + DominatorPattern pattern = Downcast(ref); + + p->Print(pattern->parent); + std::string parent_pattern_string{p->string_stream.str()}; + + p->Print(pattern->path); + std::string path_pattern_string{p->string_stream.str()}; + + p->Print(pattern->child); + std::string child_pattern_string{p->string_stream.str()}; + + p->string_stream.str(""); + p->string_stream << "(id " << p->memo_[pattern].first << "): "; + p->string_stream << "DominatorPattern(" << parent_pattern_string << ", " + << path_pattern_string << ", " << child_pattern_string + << ")"; }); // Syntatic Sugar