From ae7e0a17633a76db570e43238669baad0c5d8741 Mon Sep 17 00:00:00 2001 From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Date: Sat, 9 May 2020 23:35:58 +0800 Subject: [PATCH] [TIR][Printer] text format printer considering future parsing use (#5483) --- include/tvm/ir/module.h | 4 +- src/printer/meta_data.h | 9 + src/printer/relay_text_printer.cc | 1318 ++++++++--------- src/printer/text_printer.cc | 104 ++ src/printer/text_printer.h | 404 +++++ src/printer/tir_text_printer.cc | 597 ++++++++ .../unittest/test_arith_deduce_bound.py | 20 +- tests/python/unittest/test_te_schedule.py | 4 +- tests/python/unittest/test_tir_nodes.py | 44 +- 9 files changed, 1765 insertions(+), 739 deletions(-) create mode 100644 src/printer/text_printer.cc create mode 100644 src/printer/text_printer.h create mode 100644 src/printer/tir_text_printer.cc diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index d113860ddbce..ae78383878e3 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -363,7 +363,7 @@ TVM_DLL String PrettyPrint(const ObjectRef& node); * \return The text representation. */ TVM_DLL String AsText(const ObjectRef& node, - bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); + bool show_meta_data = true, + runtime::TypedPackedFunc annotate = nullptr); } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h index d3906926363c..8bf58ecc0aad 100644 --- a/src/printer/meta_data.h +++ b/src/printer/meta_data.h @@ -108,6 +108,15 @@ class TextMetaDataContext { return meta_repr_[node]; } + /*! + * \brief Test whether a node has been put in meta + * \param node The query node + * \return whether the node has been put in meta + */ + bool InMeta(const ObjectRef& node) { + return meta_repr_.find(node) != meta_repr_.end(); + } + /*! * \brief Print a key value pair */ diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 2e675c8ed8f4..9e6abeed3155 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -18,7 +18,7 @@ */ /*! - * \file text_format_printer.cc + * \file relay_text_printer.cc * \brief Printer to print out the IR text format * that can be parsed by a parser. * @@ -41,785 +41,735 @@ #include "meta_data.h" #include "../relay/analysis/dependency_graph.h" #include "../ir/attr_functor.h" +#include "text_printer.h" namespace tvm { namespace relay { -class RelayTextPrinter : - public ExprFunctor, - public PatternFunctor, - public TypeFunctor, - public AttrFunctor { - public: - explicit RelayTextPrinter(bool show_meta_data, - runtime::TypedPackedFunc annotate) - : show_meta_data_(show_meta_data), - annotate_(annotate) {} - - /*! - * \brief Print additional info about expr in comment. - * \param expr The expression. - */ - Doc PrintOptionalInfo(const Expr& expr) { - Doc doc; - // default annotations - if (annotate_ == nullptr) { - if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { - doc << " /* ty=" << Print(expr->checked_type()) << " */"; - } - } else { - std::string annotated_expr = annotate_(expr); - if (annotated_expr != "") { - doc << annotated_expr; - } +/*! + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ +Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { + Doc doc; + // default annotations + if (annotate_ == nullptr) { + if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { + doc << " /* ty=" << Print(expr->checked_type()) << " */"; + } + } else { + std::string annotated_expr = annotate_(expr); + if (annotated_expr != "") { + doc << annotated_expr; } - - return doc; } - // indent a new body - Doc PrintBody(const ObjectRef& node, int indent = 2) { - Doc doc; - Doc body; - doc << "{"; - doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine(); - doc << "}"; - return doc; - } + return doc; +} - // create a new scope by creating a new printer object. This allows temp var - // numbers to be reused and prevents hoisted vars from escaping too far - Doc PrintScope(const ObjectRef& node) { - // print in a new scope - doc_stack_.push_back(Doc()); - // must print first so doc_stack_.back() reference doesn't become stale - Doc doc = Print(node, false, true); - doc = doc_stack_.back() << doc; - doc_stack_.pop_back(); - return doc; +// indent a new body +Doc RelayTextPrinter::PrintBody(const ObjectRef& node, int indent) { + Doc doc; + Doc body; + doc << "{"; + doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine(); + doc << "}"; + return doc; +} + +// create a new scope by creating a new printer object. This allows temp var +// numbers to be reused and prevents hoisted vars from escaping too far +Doc RelayTextPrinter::PrintScope(const ObjectRef& node) { + // print in a new scope + doc_stack_.push_back(Doc()); + // must print first so doc_stack_.back() reference doesn't become stale + Doc doc = Print(node, false, true); + doc = doc_stack_.back() << doc; + doc_stack_.pop_back(); + return doc; +} + +Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) { + if (node->IsInstance() && + !node->IsInstance()) { + // Temporarily skip non-relay functions. + // TODO(tvm-team) enhance the code to work for all functions + } else if (node.as()) { + Expr expr = Downcast(node); + dg_ = DependencyGraph::Create(&arena_, expr); } - Doc PrintFinal(const ObjectRef& node) { - if (node->IsInstance() && - !node->IsInstance()) { - // Temporarily skip non-relay functions. - // TODO(tvm-team) enhance the code to work for all functions - } else if (node.as()) { - Expr expr = Downcast(node); - dg_ = DependencyGraph::Create(&arena_, expr); - } + Doc doc; + doc << PrintScope(node); + return doc; +} - Doc doc; - doc << PrintScope(node); - if (!meta_.empty()) { - doc << Doc::NewLine(); - if (show_meta_data_) { - // append meta data in the end. - doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); - } else { - doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; - } - } - return doc; +Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { + bool is_non_relay_func = + node->IsInstance() && + !node->IsInstance(); + if (node.as() && !is_non_relay_func) { + return PrintExpr(Downcast(node), meta, try_inline); + } else if (node.as()) { + return PrintType(Downcast(node), meta); + } else if (node.as()) { + return PrintPattern(Downcast(node), meta); + } else if (node.as()) { + return PrintMod(Downcast(node)); + } else { + // default module. + std::ostringstream os; + os << node; + return Doc::RawText(os.str()); } +} - std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); - std::vector PrintFuncAttrs(const Attrs& attrs); - - Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) { - bool is_non_relay_func = - node->IsInstance() && - !node->IsInstance(); - if (node.as() && !is_non_relay_func) { - return PrintExpr(Downcast(node), meta, try_inline); - } else if (node.as()) { - return PrintType(Downcast(node), meta); - } else if (node.as()) { - return PrintPattern(Downcast(node), meta); - } else if (node.as()) { - return PrintMod(Downcast(node)); - } else { - // default module. +Doc RelayTextPrinter::TempVar(int n) { + Doc doc; + return doc << "%" << n; +} + +Doc RelayTextPrinter::AllocTemp() { + return TempVar(temp_var_counter_++); +} + +/*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ +Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) { + std::string unique_prefix = prefix; + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (true) { std::ostringstream os; - os << node; - return Doc::RawText(os.str()); + os << prefix << (++it->second); + std::string name = os.str(); + if (name_alloc_map_.count(name) == 0) { + unique_prefix = name; + break; + } } } + name_alloc_map_[unique_prefix] = 0; + return Doc::Text(unique_prefix); +} - Doc TempVar(int n) { - Doc doc; - return doc << "%" << n; - } - - Doc AllocTemp() { - return TempVar(temp_var_counter_++); - } - - /*! - * \brief get a unique name with the corresponding prefix - * \param prefix The prefix of the name - * \return The returned name. - */ - Doc GetUniqueName(const std::string& prefix) { - std::string unique_prefix = prefix; - auto it = name_alloc_map_.find(prefix); - if (it != name_alloc_map_.end()) { - while (true) { - std::ostringstream os; - os << prefix << (++it->second); - std::string name = os.str(); - if (name_alloc_map_.count(name) == 0) { - unique_prefix = name; - break; - } - } - } - name_alloc_map_[unique_prefix] = 0; - return Doc::Text(unique_prefix); - } - - Doc Print(Kind k) { - switch (k) { - case kType: - return Doc::Text("Type"); - case kShapeVar: - return Doc::Text("Shape"); - case kBaseType: - return Doc::Text("BaseType"); - case kConstraint: - return Doc::Text("Constraint"); - case kAdtHandle: - return Doc::Text("AdtHandle"); - case kTypeData: - return Doc::Text("TypeData"); - default: - LOG(ERROR) << "Unknown Kind"; - throw; - } +Doc RelayTextPrinter::Print(Kind k) { + switch (k) { + case kType: + return Doc::Text("Type"); + case kShapeVar: + return Doc::Text("Shape"); + case kBaseType: + return Doc::Text("BaseType"); + case kConstraint: + return Doc::Text("Constraint"); + case kAdtHandle: + return Doc::Text("AdtHandle"); + case kTypeData: + return Doc::Text("TypeData"); + default: + LOG(ERROR) << "Unknown Kind"; + throw; } - /*! - * \brief Allocate name to a type variable. - * \param var The input type variable. - * \return The corresponding name. - */ - Doc AllocTypeVar(const TypeVar& var) { - if (memo_type_.count(var)) { - Doc val = memo_type_[var]; - val << "-malformed-ir"; - return val; - } - std::string name = var->name_hint; - if (name.length() == 0 || !std::isalpha(name[0])) { - name = "t" + name; - } - Doc val = GetUniqueName(name); - memo_type_[var] = val; - if (var->kind != kType) { - val << ": " << Print(var->kind); - } +} +/*! + * \brief Allocate name to a type variable. + * \param var The input type variable. + * \return The corresponding name. + */ +Doc RelayTextPrinter::AllocTypeVar(const TypeVar& var) { + if (memo_type_.count(var)) { + Doc val = memo_type_[var]; + val << "-malformed-ir"; return val; } + std::string name = var->name_hint; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "t" + name; + } + Doc val = GetUniqueName(name); + memo_type_[var] = val; + if (var->kind != kType) { + val << ": " << Print(var->kind); + } + return val; +} - /*! - * \brief Allocate name to a variable. - * \param var The input variable. - * \return The corresponding name. - */ - Doc AllocVar(const Var& var) { - // still print if ir is malformed, but show the error. - if (memo_.count(var)) { - Doc val = memo_[var]; - val << "-malformed-ir"; - return val; - } - std::string name = var->name_hint(); - // always make sure first name is alpha - if (name.length() == 0 || !std::isalpha(name[0])) { - name = "v" + name; - } - Doc val = GetUniqueName("%" + name); - memo_[var] = val; - if (var->type_annotation.defined()) { - val << ": " << Print(var->type_annotation); - } +/*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ +Doc RelayTextPrinter::AllocVar(const Var& var) { + // still print if ir is malformed, but show the error. + if (memo_.count(var)) { + Doc val = memo_[var]; + val << "-malformed-ir"; return val; } - - bool IsUnique(const Expr& expr) { - auto it = dg_.expr_node.find(expr); - if (it == dg_.expr_node.end()) { - return true; - } else { - return !(it->second->parents.head && it->second->parents.head->next); - } + std::string name = var->name_hint(); + // always make sure first name is alpha + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "v" + name; } + Doc val = GetUniqueName("%" + name); + memo_[var] = val; + if (var->type_annotation.defined()) { + val << ": " << Print(var->type_annotation); + } + return val; +} - bool AlwaysInline(const Expr& expr) { - return expr.as() || expr.as() || expr.as() || - expr.as() || expr.as(); +bool RelayTextPrinter::IsUnique(const Expr& expr) { + auto it = dg_.expr_node.find(expr); + if (it == dg_.expr_node.end()) { + return true; + } else { + return !(it->second->parents.head && it->second->parents.head->next); } +} - //------------------------------------ - // Overload of Expr printing functions - //------------------------------------ - Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) { - // Exploit memoization to print GNF. - // The first time we visit an expression, we need to allocate a temp var - // for it. Every subsequent time we can just use its assigned variable. - // This works since hashing uses pointer equality. +bool RelayTextPrinter::AlwaysInline(const Expr& expr) { + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as(); +} - // determine whether to inline - bool inline_expr = AlwaysInline(expr); - if (try_inline) { - inline_expr |= IsUnique(expr); - } +//------------------------------------ +// Overload of Expr printing functions +//------------------------------------ +Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { + // Exploit memoization to print GNF. + // The first time we visit an expression, we need to allocate a temp var + // for it. Every subsequent time we can just use its assigned variable. + // This works since hashing uses pointer equality. + + // determine whether to inline + bool inline_expr = AlwaysInline(expr); + if (try_inline) { + inline_expr |= IsUnique(expr); + } + + auto it = memo_.find(expr); + if (it != memo_.end()) return it->second; + + Doc printed_expr; + if (meta) { + printed_expr = meta_->GetMetaNode(GetRef(expr.get())); + } else if (!inline_expr && expr.as()) { + // wrap GNFed let in brackets + Doc body; + printed_expr << "("; + printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine(); + printed_expr << ")"; + } else { + printed_expr = VisitExpr(expr); + } - auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - - Doc printed_expr; - if (meta) { - printed_expr = meta_.GetMetaNode(GetRef(expr.get())); - } else if (!inline_expr && expr.as()) { - // wrap GNFed let in brackets - Doc body; - printed_expr << "("; - printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine(); - printed_expr << ")"; - } else { - printed_expr = VisitExpr(expr); - } + printed_expr << PrintOptionalInfo(expr); - printed_expr << PrintOptionalInfo(expr); - - // add expr to doc - if (expr.as()) { - // This is our first time visiting the var and we hit the VarNode case - // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); - // Memoization is done in AllocVar. - return memo_[expr]; - } else if (inline_expr) { - memo_[expr] = printed_expr; - return printed_expr; - } else { - Doc temp_var = AllocTemp(); - memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); - return temp_var; - } + // add expr to doc + if (expr.as()) { + // This is our first time visiting the var and we hit the VarNode case + // in the visitor. Thus the variable is free. + doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); + // Memoization is done in AllocVar. + return memo_[expr]; + } else if (inline_expr) { + memo_[expr] = printed_expr; + return printed_expr; + } else { + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); + return temp_var; } +} + +// Should only be triggered when op is a free variable being visited for the +// first time. +Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { + return AllocVar(GetRef(op)); +} - // Should only be triggered when op is a free variable being visited for the - // first time. - Doc VisitExpr_(const VarNode* op) final { - return AllocVar(GetRef(op)); +/*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. + */ +template +Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) { + std::ostringstream os; + if (dtype == DataType::Int(32)) { + os << value; + } else if (dtype == DataType::Float(32)) { + os << value << 'f'; + } else if (dtype == DataType::Float(64)) { + os << value; + } else if (dtype == DataType::Bool()) { + return Doc::PyBoolLiteral(value != 0); + } else { + os << value; } + return Doc::Text(os.str()); +} - /*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param value The value to be printed. - */ - template - static Doc ScalarLiteral(DataType dtype, const T& value) { +Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) { + // Print out simple scalars directly. + if (op->is_scalar()) { std::ostringstream os; + DataType dtype = DataType(op->data->dtype); + CHECK_EQ(op->data->ctx.device_type, kDLCPU); if (dtype == DataType::Int(32)) { - os << value; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Int(64)) { + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(32)) { - os << value << 'f'; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(64)) { - os << value; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Bool()) { - return Doc::PyBoolLiteral(value != 0); - } else { - os << value; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } - return Doc::Text(os.str()); } + // default fall-back, record it as meta node. + Doc doc; + return doc << Print(GetRef(op), true); +} - Doc VisitExpr_(const ConstantNode* op) final { - // Print out simple scalars directly. - if (op->is_scalar()) { - std::ostringstream os; - DataType dtype = DataType(op->data->dtype); - CHECK_EQ(op->data->ctx.device_type, kDLCPU); - if (dtype == DataType::Int(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Int(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Bool()) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } - } - // default fall-back, record it as meta node. - Doc doc; - return doc << Print(GetRef(op), true); +Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) { + std::vector fields; + for (Expr field : op->fields) { + fields.push_back(Print(field)); } + Doc doc; + doc << "(" << Doc::Concat(fields); + // conform to python tuple format (1,) + if (op->fields.size() == 1) { + doc << ","; + } + return doc << ")"; +} - Doc VisitExpr_(const TupleNode* op) final { - std::vector fields; - for (Expr field : op->fields) { - fields.push_back(Print(field)); - } - Doc doc; - doc << "(" << Doc::Concat(fields); - // conform to python tuple format (1,) - if (op->fields.size() == 1) { - doc << ","; +Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) { + Doc doc; + return doc << Print(op->tuple) << "." << op->index; +} + +Doc RelayTextPrinter::VisitExpr_(const IfNode* op) { + Doc doc; + doc << "if (" << Print(op->cond) << ") "; + doc << PrintBody(op->true_branch); + doc << " else "; + doc << PrintBody(op->false_branch); + return doc; +} + +Doc RelayTextPrinter::VisitExpr_(const LetNode* op) { + Doc doc; + doc + << "let " + << AllocVar(op->var) + << " = " + << Print(op->value, false, true) + << ";" + << Doc::NewLine(); + // we use a scope here so GNF hoisting doesn't escape too far + // and nested, unique lets are not hoisted + doc << PrintScope(op->body); + return doc; +} + +Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { + Doc doc; + doc << prefix; + if (fn->type_params.size() > 0) { + doc << "["; + std::vector type_params; + for (const TypeVar& tv : fn->type_params) { + type_params.push_back(Doc::Text(tv->name_hint)); } - return doc << ")"; + doc << Doc::Concat(type_params); + doc << "]"; } - - Doc VisitExpr_(const TupleGetItemNode* op) final { - Doc doc; - return doc << Print(op->tuple) << "." << op->index; + doc << "("; + std::vector params; + for (Var param : fn->params) { + params.push_back(AllocVar(param)); } - - Doc VisitExpr_(const IfNode* op) final { - Doc doc; - doc << "if (" << Print(op->cond) << ") "; - doc << PrintBody(op->true_branch); - doc << " else "; - doc << PrintBody(op->false_branch); - return doc; + for (const Doc& d : PrintFuncAttrs(fn->attrs)) { + params.push_back(d); } - - Doc VisitExpr_(const LetNode* op) final { - Doc doc; - doc - << "let " - << AllocVar(op->var) - << " = " - << Print(op->value, false, true) - << ";" - << Doc::NewLine(); - // we use a scope here so GNF hoisting doesn't escape too far - // and nested, unique lets are not hoisted - doc << PrintScope(op->body); - return doc; + doc << Doc::Concat(params) << ") "; + if (fn->ret_type.defined()) { + doc << "-> " << Print(fn->ret_type) << " "; } + doc << PrintBody(fn->body); + return doc; +} - Doc PrintFunc(const Doc& prefix, const relay::Function& fn) { +Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) { + if (auto* n = base_func.as()) { + return PrintFunc(prefix, GetRef(n)); + } else if (auto* n = base_func.as()) { + std::ostringstream os; + os << GetRef(n); + return Doc::RawText(os.str()); + } else { + // def @xyz = meta['ExternalFunc'][id] Doc doc; - doc << prefix; - if (fn->type_params.size() > 0) { - doc << "["; - std::vector type_params; - for (const TypeVar& tv : fn->type_params) { - type_params.push_back(Doc::Text(tv->name_hint)); - } - doc << Doc::Concat(type_params); - doc << "]"; - } - doc << "("; - std::vector params; - for (Var param : fn->params) { - params.push_back(AllocVar(param)); - } - for (const Doc& d : PrintFuncAttrs(fn->attrs)) { - params.push_back(d); - } - doc << Doc::Concat(params) << ") "; - if (fn->ret_type.defined()) { - doc << "-> " << Print(fn->ret_type) << " "; - } - doc << PrintBody(fn->body); + doc << prefix << " = " << meta_->GetMetaNode(base_func); return doc; } +} - Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) { - if (auto* n = base_func.as()) { - return PrintFunc(prefix, GetRef(n)); - } else if (auto* n = base_func.as()) { - std::ostringstream os; - os << GetRef(n); - return Doc::RawText(os.str()); - } else { - // def @xyz = meta['ExternalFunc'][id] - Doc doc; - doc << prefix << " = " << meta_.GetMetaNode(base_func); - return doc; +Doc RelayTextPrinter::PrintMod(const IRModule& mod) { + Doc doc; + int counter = 0; + // type definitions + for (const auto& kv : mod->type_definitions) { + if (counter++ != 0) { + doc << Doc::NewLine(); } + doc << Print(kv.second); + doc << Doc::NewLine(); } - - Doc PrintMod(const IRModule& mod) { - Doc doc; - int counter = 0; - // type definitions - for (const auto& kv : mod->type_definitions) { - if (counter++ != 0) { - doc << Doc::NewLine(); - } - doc << Print(kv.second); - doc << Doc::NewLine(); + // functions + for (const auto& kv : mod->functions) { + if (kv.second.as()) { + dg_ = DependencyGraph::Create(&arena_, kv.second); } - // functions - for (const auto& kv : mod->functions) { - if (kv.second.as()) { - dg_ = DependencyGraph::Create(&arena_, kv.second); - } - if (counter++ != 0) { - doc << Doc::NewLine(); - } - std::ostringstream os; - os << "def @" << kv.first->name_hint; - doc << PrintFunc(Doc::Text(os.str()), kv.second); + if (counter++ != 0) { doc << Doc::NewLine(); } - return doc; - } - - Doc VisitExpr_(const FunctionNode* op) final { - return PrintFunc(Doc::Text("fn "), GetRef(op)); + std::ostringstream os; + os << "def @" << kv.first->name_hint; + doc << PrintFunc(Doc::Text(os.str()), kv.second); + doc << Doc::NewLine(); } + return doc; +} - Doc VisitExpr_(const GlobalVarNode* op) final { - return Doc::Text('@' + op->name_hint); - } +Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { + return PrintFunc(Doc::Text("fn "), GetRef(op)); +} - Doc VisitExpr_(const OpNode* op) final { - return Doc::Text(op->name); - } +Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { + return Doc::Text('@' + op->name_hint); +} - Doc VisitExpr_(const CallNode* op) final { - Doc doc; - // visit args first so they are lifted before the op - // this places op closer to its call site - std::vector args; - for (const Expr& arg : op->args) { - args.push_back(Print(arg)); - } - for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { - args.push_back(d); - } - const auto* cons_node = op->op.as(); - if (cons_node) { - doc << cons_node->name_hint; - } else { - doc << Print(op->op); - } +Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { + return Doc::Text(op->name); +} - if (cons_node && cons_node->inputs.size() == 0) { - // don't print as a call if it's a 0-arity cons - return doc; - } else { - return doc << "(" << Doc::Concat(args) << ")"; - } +Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { + Doc doc; + // visit args first so they are lifted before the op + // this places op closer to its call site + std::vector args; + for (const Expr& arg : op->args) { + args.push_back(Print(arg)); + } + for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { + args.push_back(d); + } + const auto* cons_node = op->op.as(); + if (cons_node) { + doc << cons_node->name_hint; + } else { + doc << Print(op->op); } - Doc VisitExpr_(const RefCreateNode* op) final { - Doc doc; - return doc << "ref(" << Print(op->value) << ")"; + if (cons_node && cons_node->inputs.size() == 0) { + // don't print as a call if it's a 0-arity cons + return doc; + } else { + return doc << "(" << Doc::Concat(args) << ")"; } +} - Doc VisitExpr_(const RefReadNode* op) final { - Doc doc; - return doc << Print(op->ref) << "^"; - } +Doc RelayTextPrinter::VisitExpr_(const RefCreateNode* op) { + Doc doc; + return doc << "ref(" << Print(op->value) << ")"; +} - Doc VisitExpr_(const RefWriteNode* op) final { - Doc doc; - return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; - } +Doc RelayTextPrinter::VisitExpr_(const RefReadNode* op) { + Doc doc; + return doc << Print(op->ref) << "^"; +} - Doc VisitExpr_(const MatchNode* op) final { - // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. - Doc doc; - Doc body; - doc << "match"; - if (!op->complete) { - doc << "?"; - } - doc << " (" << Print(op->data) << ") {"; - std::vector clause_docs; - for (const auto& clause : op->clauses) { - Doc clause_doc; - clause_doc << PrintPattern(clause->lhs, false) << " => "; - Doc rhs_doc = PrintScope(clause->rhs); - if (clause->rhs.as()) { - // only add braces if there are multiple lines on the rhs - rhs_doc = Doc::Brace("{", rhs_doc, "}"); - } - clause_doc << rhs_doc << ","; - clause_docs.push_back(clause_doc); - } - doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine())) - << Doc::NewLine() << "}"; - return doc; - } +Doc RelayTextPrinter::VisitExpr_(const RefWriteNode* op) { + Doc doc; + return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; +} - Doc PrintPattern(const Pattern& pattern, bool meta) { - auto it = memo_pattern_.find(pattern); - if (it != memo_pattern_.end()) return it->second; - Doc printed_pattern; - if (meta) { - printed_pattern = meta_.GetMetaNode(GetRef(pattern.get())); - } else { - printed_pattern = VisitPattern(pattern); - } - memo_pattern_[pattern] = printed_pattern; - return printed_pattern; - } +Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) { + // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. + Doc doc; + Doc body; + doc << "match"; + if (!op->complete) { + doc << "?"; + } + doc << " (" << Print(op->data) << ") {"; + std::vector clause_docs; + for (const auto& clause : op->clauses) { + Doc clause_doc; + clause_doc << PrintPattern(clause->lhs, false) << " => "; + Doc rhs_doc = PrintScope(clause->rhs); + if (clause->rhs.as()) { + // only add braces if there are multiple lines on the rhs + rhs_doc = Doc::Brace("{", rhs_doc, "}"); + } + clause_doc << rhs_doc << ","; + clause_docs.push_back(clause_doc); + } + doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine())) + << Doc::NewLine() << "}"; + return doc; +} - Doc VisitPattern_(const PatternConstructorNode* p) final { - Doc doc; - doc << p->constructor->name_hint; - if (!p->patterns.empty()) { - doc << "("; - std::vector pats; - for (const auto& pat : p->patterns) { - pats.push_back(Print(pat)); - } - doc << Doc::Concat(pats) << ")"; - } - return doc; +Doc RelayTextPrinter::PrintPattern(const Pattern& pattern, bool meta) { + auto it = memo_pattern_.find(pattern); + if (it != memo_pattern_.end()) return it->second; + Doc printed_pattern; + if (meta) { + printed_pattern = meta_->GetMetaNode(GetRef(pattern.get())); + } else { + printed_pattern = VisitPattern(pattern); } + memo_pattern_[pattern] = printed_pattern; + return printed_pattern; +} - Doc VisitPattern_(const PatternTupleNode* pt) final { - Doc doc; +Doc RelayTextPrinter::VisitPattern_(const PatternConstructorNode* p) { + Doc doc; + doc << p->constructor->name_hint; + if (!p->patterns.empty()) { doc << "("; std::vector pats; - for (const auto& pat : pt->patterns) { + for (const auto& pat : p->patterns) { pats.push_back(Print(pat)); } doc << Doc::Concat(pats) << ")"; - return doc; } + return doc; +} - Doc VisitPattern_(const PatternWildcardNode* pw) final { - return Doc::Text("_"); +Doc RelayTextPrinter::VisitPattern_(const PatternTupleNode* pt) { + Doc doc; + doc << "("; + std::vector pats; + for (const auto& pat : pt->patterns) { + pats.push_back(Print(pat)); } + doc << Doc::Concat(pats) << ")"; + return doc; +} - Doc VisitPattern_(const PatternVarNode* pv) final { - return AllocVar(pv->var); - } +Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) { + return Doc::Text("_"); +} - Doc VisitExpr_(const ConstructorNode* n) final { - Doc doc; - doc << n->name_hint; - if (in_adt_def_ && n->inputs.size() != 0) { - doc << "("; - std::vector inputs; - for (Type input : n->inputs) { - inputs.push_back(Print(input)); - } - doc << Doc::Concat(inputs) << ")"; - } - return doc; - } +Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) { + return AllocVar(pv->var); +} - //------------------------------------ - // Overload of Type printing functions - //------------------------------------ - Doc PrintType(const Type& type, bool meta) { - auto it = memo_type_.find(type); - if (it != memo_type_.end()) return it->second; - Doc printed_type; - if (meta) { - printed_type = meta_.GetMetaNode(GetRef(type.get())); - } else { - printed_type = VisitType(type); +Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) { + Doc doc; + doc << n->name_hint; + if (in_adt_def_ && n->inputs.size() != 0) { + doc << "("; + std::vector inputs; + for (Type input : n->inputs) { + inputs.push_back(Print(input)); } - memo_type_[type] = printed_type; - return printed_type; + doc << Doc::Concat(inputs) << ")"; } + return doc; +} - Doc VisitTypeDefault_(const Object* node) final { - // by default always print as meta data - return Print(GetRef(node), true); +//------------------------------------ +// Overload of Type printing functions +//------------------------------------ +Doc RelayTextPrinter::PrintType(const Type& type, bool meta) { + auto it = memo_type_.find(type); + if (it != memo_type_.end()) return it->second; + Doc printed_type; + if (meta) { + printed_type = meta_->GetMetaNode(GetRef(type.get())); + } else { + printed_type = VisitType(type); } + memo_type_[type] = printed_type; + return printed_type; +} - Doc VisitType_(const TypeVarNode* node) final { - return Doc::Text(node->name_hint); - } +Doc RelayTextPrinter::VisitTypeDefault_(const Object* node) { + // by default always print as meta data + return Print(GetRef(node), true); +} - Doc VisitType_(const GlobalTypeVarNode* node) final { - return Doc::Text(node->name_hint); - } +Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) { + return Doc::Text(node->name_hint); +} - Doc VisitType_(const TypeCallNode* node) final { - Doc doc = PrintType(node->func, false); - std::vector args; - for (const Type& t : node->args) { - args.push_back(PrintType(t, false)); - } - doc << "["; - doc << Doc::Concat(args); - doc << "]"; - return doc; - } +Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) { + return Doc::Text(node->name_hint); +} - Doc PrintDType(DataType dtype) { - return Doc::Text(runtime::DLDataType2String(dtype)); +Doc RelayTextPrinter::VisitType_(const TypeCallNode* node) { + Doc doc = PrintType(node->func, false); + std::vector args; + for (const Type& t : node->args) { + args.push_back(PrintType(t, false)); } + doc << "["; + doc << Doc::Concat(args); + doc << "]"; + return doc; +} - Doc VisitType_(const TensorTypeNode* node) final { - // scalar type - if (node->shape.size() == 0) { - return PrintDType(node->dtype); - } - Doc doc; - doc << "Tensor[("; - std::vector shapes; - for (ObjectRef shape : node->shape) { - shapes.push_back(PrintAttr(shape)); - } - doc << Doc::Concat(shapes); - return doc << "), " << PrintDType(node->dtype) << "]"; +Doc RelayTextPrinter::PrintDType(DataType dtype) { + return Doc::Text(runtime::DLDataType2String(dtype)); +} + +Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) { + // scalar type + if (node->shape.size() == 0) { + return PrintDType(node->dtype); } + Doc doc; + doc << "Tensor[("; + std::vector shapes; + for (ObjectRef shape : node->shape) { + shapes.push_back(PrintAttr(shape)); + } + doc << Doc::Concat(shapes); + return doc << "), " << PrintDType(node->dtype) << "]"; +} - Doc VisitType_(const TupleTypeNode* node) final { - std::vector fields; - for (Type field : node->fields) { - fields.push_back(Print(field)); - } - Doc doc; - doc << "(" << Doc::Concat(fields); - // conform to python tuple format (1,) - if (node->fields.size() == 1) { - doc << ","; - } - return doc << ")"; +Doc RelayTextPrinter::VisitType_(const TupleTypeNode* node) { + std::vector fields; + for (Type field : node->fields) { + fields.push_back(Print(field)); + } + Doc doc; + doc << "(" << Doc::Concat(fields); + // conform to python tuple format (1,) + if (node->fields.size() == 1) { + doc << ","; } + return doc << ")"; +} - Doc VisitType_(const FuncTypeNode* node) final { - Doc doc; - doc << "fn "; - if (node->type_params.size() != 0) { - doc << "["; - std::vector type_params; - for (Type type_param : node->type_params) { - type_params.push_back(Print(type_param)); - } - doc << Doc::Concat(type_params); - doc << "]"; - } - std::vector arg_types; - for (Type arg_type : node->arg_types) { - arg_types.push_back(Print(arg_type)); +Doc RelayTextPrinter::VisitType_(const FuncTypeNode* node) { + Doc doc; + doc << "fn "; + if (node->type_params.size() != 0) { + doc << "["; + std::vector type_params; + for (Type type_param : node->type_params) { + type_params.push_back(Print(type_param)); } - return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type); + doc << Doc::Concat(type_params); + doc << "]"; } - - Doc VisitType_(const RelayRefTypeNode* node) final { - Doc doc; - return doc << "ref(" << Print(node->value) << ")"; + std::vector arg_types; + for (Type arg_type : node->arg_types) { + arg_types.push_back(Print(arg_type)); } + return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type); +} - Doc VisitType_(const TypeDataNode* node) final { - in_adt_def_ = true; - Doc doc; - doc << "type " << Print(node->header); - - // type vars - if (node->type_vars.size() != 0) { - doc << "["; - std::vector type_vars; - for (Type type_var : node->type_vars) { - type_vars.push_back(Print(type_var)); - } - doc << Doc::Concat(type_vars) << "]"; - } - doc << " "; +Doc RelayTextPrinter::VisitType_(const RelayRefTypeNode* node) { + Doc doc; + return doc << "ref(" << Print(node->value) << ")"; +} - std::vector constructor_docs; - for (Constructor constructor : node->constructors) { - constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true)); - } - Doc separator; - separator << "," << Doc::NewLine(); - Doc adt_body; - adt_body << Doc::Concat(constructor_docs, separator); - // add trailing comma if there are any constructors - if (!constructor_docs.empty()) { - adt_body << ","; +Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) { + in_adt_def_ = true; + Doc doc; + doc << "type " << Print(node->header); + + // type vars + if (node->type_vars.size() != 0) { + doc << "["; + std::vector type_vars; + for (Type type_var : node->type_vars) { + type_vars.push_back(Print(type_var)); } - doc << Doc::Brace("{", adt_body, "}"); - in_adt_def_ = false; - return doc; + doc << Doc::Concat(type_vars) << "]"; } + doc << " "; - //------------------------------------ - // Overload of Attr printing functions - //------------------------------------ + std::vector constructor_docs; + for (Constructor constructor : node->constructors) { + constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true)); + } + Doc separator; + separator << "," << Doc::NewLine(); + Doc adt_body; + adt_body << Doc::Concat(constructor_docs, separator); + // add trailing comma if there are any constructors + if (!constructor_docs.empty()) { + adt_body << ","; + } + doc << Doc::Brace("{", adt_body, "}"); + in_adt_def_ = false; + return doc; +} - Doc PrintAttr(const ObjectRef& value, bool meta = false) { - if (value.defined()) { - Doc printed_attr; - if (value.as()) { - printed_attr << "?"; - } else if (meta) { - printed_attr = meta_.GetMetaNode(Downcast(value)); - } else { - printed_attr = VisitAttr(value); - } - return printed_attr; +//------------------------------------ +// Overload of Attr printing functions +//------------------------------------ + +Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { + if (value.defined()) { + Doc printed_attr; + if (value.as()) { + printed_attr << "?"; + } else if (meta) { + printed_attr = meta_->GetMetaNode(Downcast(value)); } else { - return Doc::Text("None"); + printed_attr = VisitAttr(value); } + return printed_attr; + } else { + return Doc::Text("None"); } +} - Doc VisitAttrDefault_(const Object* op) final { - return PrintAttr(GetRef(op), true); - } +Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { + return PrintAttr(GetRef(op), true); +} - Doc VisitAttr_(const ArrayNode* op) final { - Doc doc; - doc << "["; - std::vector arr_vals; - for (auto val : op->data) { - arr_vals.push_back(PrintAttr(val)); - } - doc << Doc::Concat(arr_vals); - doc << "]"; - return doc; - } +Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { + Doc doc; + doc << "["; + std::vector arr_vals; + for (auto val : op->data) { + arr_vals.push_back(PrintAttr(val)); + } + doc << Doc::Concat(arr_vals); + doc << "]"; + return doc; +} - Doc VisitAttr_(const tir::IntImmNode* op) final { - return ScalarLiteral(op->dtype, op->value); - } +Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) { + return ScalarLiteral(op->dtype, op->value); +} - Doc VisitAttr_(const tir::FloatImmNode* op) final { - return ScalarLiteral(op->dtype, op->value); - } +Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) { + return ScalarLiteral(op->dtype, op->value); +} - Doc VisitAttr_(const tir::StringImmNode* op) final { - return Doc::StrLiteral(op->value); - } +Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { + return Doc::StrLiteral(op->value); +} - private: - /*! \brief Whether to print meta data. */ - bool show_meta_data_; - /*! \brief additional comment function */ - runtime::TypedPackedFunc annotate_; - /*! \brief Stack of docs to implement scoped GNFing. */ - std::vector doc_stack_{}; - /*! \brief Map from Expr to Doc */ - std::unordered_map memo_; - /*! \brief Map from Type to Doc */ - std::unordered_map memo_type_; - /*! \brief Map from Type to Doc */ - std::unordered_map memo_pattern_; - /*! \brief name allocation map */ - std::unordered_map name_alloc_map_; - /*! \brief meta data context */ - TextMetaDataContext meta_; - /*! \brief counter of temporary variable */ - size_t temp_var_counter_{0}; - /*! \brief whether the printer is currently in an ADT definition */ - bool in_adt_def_; - /*! \brief arena for dependency graph */ - support::Arena arena_; - /*! \brief dependency graph of the expr */ - DependencyGraph dg_; - class AttrPrinter; - friend class AttrPrinter; -}; /*! * \brief Attribute printer which prints the attributes in the call. @@ -883,7 +833,7 @@ std::vector RelayTextPrinter::PrintCallAttrs( if (op_node && (attrs->type_index() != op_node->attrs_type_index)) { // fallback Doc doc; - doc << meta_.GetMetaNode(attrs); + doc << meta_->GetMetaNode(attrs); docs.push_back(doc); return docs; } else { @@ -905,44 +855,6 @@ std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { } return docs; } -} // namespace relay -static const char* kSemVer = "v0.0.4"; - -// TODO(tvm-team): split into files, related: arith/analyzer.h -// -// - text_printer.h (common header) -// - text_printer.cc (prints modules dispatch into relay and tir files) -// - type_text_printer.cc(specific printing logics for types, -// can also consider put under type_text_printer) -// - Implements AsText -// - relay_text_printer.cc (specific printing logics for relay) -// - tir_text_printer.cc (specific printing logics for TIR) -String PrettyPrint(const ObjectRef& node) { - Doc doc; - doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node); - return doc.str(); -} - -String AsText(const ObjectRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - Doc doc; - doc << kSemVer << Doc::NewLine(); - runtime::TypedPackedFunc ftyped = nullptr; - if (annotate != nullptr) { - ftyped = runtime::TypedPackedFunc( - [&annotate](const ObjectRef& expr) -> std::string { - return annotate(expr); - }); - } - doc << relay::RelayTextPrinter(show_meta_data, ftyped).PrintFinal(node); - return doc.str(); -} - -TVM_REGISTER_GLOBAL("ir.PrettyPrint") -.set_body_typed(PrettyPrint); - -TVM_REGISTER_GLOBAL("ir.AsText") -.set_body_typed(AsText); +} // namespace relay } // namespace tvm diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc new file mode 100644 index 000000000000..592aabe0d8fd --- /dev/null +++ b/src/printer/text_printer.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file text_printer.cc + * \brief Printer to print out the unified IR text format + * that can be parsed by a parser. + */ + +#include +#include +#include "text_printer.h" + +namespace tvm { + +static const char* kSemVer = "v0.0.4"; + +// TODO(tvm-team): split into files, related: arith/analyzer.h +// +// - text_printer.h (common header) +// - text_printer.cc (prints modules dispatch into relay and tir files) +// - type_text_printer.cc(specific printing logics for types, +// can also consider put under type_text_printer) +// - Implements AsText +// - relay_text_printer.cc (specific printing logics for relay) +// - tir_text_printer.cc (specific printing logics for TIR) + +Doc TextPrinter::PrintMod(const IRModule& mod) { + Doc doc; + int counter = 0; + // type definitions + for (const auto& kv : mod->type_definitions) { + if (counter++ != 0) { + doc << Doc::NewLine(); + } + doc << relay_text_printer_.Print(kv.second); + doc << Doc::NewLine(); + } + // functions + for (const auto& kv : mod->functions) { + if (kv.second.as()) { + relay_text_printer_.dg_ = + relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second); + } + if (counter++ != 0) { + doc << Doc::NewLine(); + } + if (kv.second.as()) { + std::ostringstream os; + os << "def @" << kv.first->name_hint; + doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second); + } else if (kv.second.as()) { + doc << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); + } + doc << Doc::NewLine(); + } + return doc; +} + +String PrettyPrint(const ObjectRef& node) { + Doc doc; + doc << TextPrinter(false, nullptr).PrintFinal(node); + return doc.str(); +} + +String AsText(const ObjectRef& node, + bool show_meta_data, + runtime::TypedPackedFunc annotate) { + Doc doc; + doc << kSemVer << Doc::NewLine(); + runtime::TypedPackedFunc ftyped = nullptr; + if (annotate != nullptr) { + ftyped = runtime::TypedPackedFunc( + [&annotate](const ObjectRef& expr) -> std::string { + return annotate(expr); + }); + } + doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node); + return doc.str(); +} + +TVM_REGISTER_GLOBAL("ir.PrettyPrint") +.set_body_typed(PrettyPrint); + +TVM_REGISTER_GLOBAL("ir.AsText") +.set_body_typed(AsText); + +} // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h new file mode 100644 index 000000000000..63767afe6eee --- /dev/null +++ b/src/printer/text_printer.h @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file text_printer.h + * \brief Printer to print out the unified IR text format + * that can be parsed by a parser. + */ + +#ifndef TVM_PRINTER_TEXT_PRINTER_H_ +#define TVM_PRINTER_TEXT_PRINTER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../relay/analysis/dependency_graph.h" +#include "../ir/attr_functor.h" + +#include "doc.h" +#include "meta_data.h" +#include "text_printer.h" + +namespace tvm { +class TextPrinter; +} // namespace tvm + +namespace tvm { +namespace relay { + +class RelayTextPrinter : + public ExprFunctor, + public PatternFunctor, + public TypeFunctor, + public AttrFunctor { + public: + explicit RelayTextPrinter(bool show_meta_data, + TextMetaDataContext* meta, + runtime::TypedPackedFunc annotate) + : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {} + + /*! + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ + Doc PrintOptionalInfo(const Expr& expr); + // indent a new body + Doc PrintBody(const ObjectRef& node, int indent = 2); + // create a new scope by creating a new printer object. This allows temp var + // numbers to be reused and prevents hoisted vars from escaping too far + Doc PrintScope(const ObjectRef& node); + Doc PrintFinal(const ObjectRef& node); + std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); + std::vector PrintFuncAttrs(const Attrs& attrs); + + Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); + + Doc TempVar(int n); + Doc AllocTemp(); + /*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ + Doc GetUniqueName(const std::string& prefix); + Doc Print(Kind k); + /*! + * \brief Allocate name to a type variable. + * \param var The input type variable. + * \return The corresponding name. + */ + Doc AllocTypeVar(const TypeVar& var); + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ + Doc AllocVar(const Var& var); + bool IsUnique(const Expr& expr); + bool AlwaysInline(const Expr& expr); + + Doc PrintFunc(const Doc& prefix, const relay::Function& fn); + Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func); + Doc PrintMod(const IRModule& mod); + + //------------------------------------ + // Overload of Expr printing functions + //------------------------------------ + Doc PrintExpr(const Expr& expr, bool meta, bool try_inline); + // Should only be triggered when op is a free variable being visited for the + // first time. + Doc VisitExpr_(const VarNode* op) final; + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. + */ + template + static Doc ScalarLiteral(DataType dtype, const T& value); + Doc VisitExpr_(const ConstantNode* op) final; + Doc VisitExpr_(const TupleNode* op) final; + Doc VisitExpr_(const TupleGetItemNode* op) final; + Doc VisitExpr_(const IfNode* op) final; + Doc VisitExpr_(const LetNode* op) final; + Doc VisitExpr_(const FunctionNode* op) final; + Doc VisitExpr_(const GlobalVarNode* op) final; + Doc VisitExpr_(const OpNode* op) final; + Doc VisitExpr_(const CallNode* op) final; + Doc VisitExpr_(const RefCreateNode* op) final; + Doc VisitExpr_(const RefReadNode* op) final; + Doc VisitExpr_(const RefWriteNode* op) final; + Doc VisitExpr_(const MatchNode* op) final; + Doc PrintPattern(const Pattern& pattern, bool meta); + Doc VisitPattern_(const PatternConstructorNode* p) final; + Doc VisitPattern_(const PatternTupleNode* pt) final; + Doc VisitPattern_(const PatternWildcardNode* pw) final; + Doc VisitPattern_(const PatternVarNode* pv) final; + Doc VisitExpr_(const ConstructorNode* n) final; + //------------------------------------ + // Overload of Type printing functions + //------------------------------------ + Doc PrintType(const Type& type, bool meta); + Doc VisitTypeDefault_(const Object* node) final; + Doc VisitType_(const TypeVarNode* node) final; + Doc VisitType_(const GlobalTypeVarNode* node); + Doc VisitType_(const TypeCallNode* node) final; + Doc PrintDType(DataType dtype); + Doc VisitType_(const TensorTypeNode* node) final; + Doc VisitType_(const TupleTypeNode* node) final; + Doc VisitType_(const FuncTypeNode* node) final; + Doc VisitType_(const RelayRefTypeNode* node) final; + Doc VisitType_(const TypeDataNode* node) final; + //------------------------------------ + // Overload of Attr printing functions + //------------------------------------ + Doc PrintAttr(const ObjectRef& value, bool meta = false); + Doc VisitAttrDefault_(const Object* op) final; + Doc VisitAttr_(const ArrayNode* op) final; + Doc VisitAttr_(const tir::IntImmNode* op) final; + Doc VisitAttr_(const tir::FloatImmNode* op) final; + Doc VisitAttr_(const tir::StringImmNode* op) final; + + private: + /*! \brief Whether to print meta data. */ + bool show_meta_data_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + /*! \brief Stack of docs to implement scoped GNFing. */ + std::vector doc_stack_{}; + /*! \brief Map from Expr to Doc */ + std::unordered_map memo_; + /*! \brief Map from Type to Doc */ + std::unordered_map memo_type_; + /*! \brief Map from Type to Doc */ + std::unordered_map memo_pattern_; + /*! \brief name allocation map */ + std::unordered_map name_alloc_map_; + /*! \brief meta data context */ + TextMetaDataContext* meta_; + /*! \brief counter of temporary variable */ + size_t temp_var_counter_{0}; + /*! \brief whether the printer is currently in an ADT definition */ + bool in_adt_def_; + /*! \brief arena for dependency graph */ + support::Arena arena_; + /*! \brief dependency graph of the expr */ + DependencyGraph dg_; + class AttrPrinter; + friend class AttrPrinter; + friend class tvm::TextPrinter; +}; + +} // namespace relay +} // namespace tvm + +namespace tvm { +namespace tir { + +/*! + * \brief Meta node collector + * If we decide to put some node into meta, then all the sub-nodes inside + * it need to be put in meta as well, since when parsing we need to know + * whether two refs are the same + */ +class MetaCollector : public StmtExprVisitor { + public: + explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {} + + void Collect(const ObjectRef& n) { + // these nodes can be print directly(StringLiteral or use identifier to identify) + if (!n.defined() || n.as() || n.as() || n.as() + || n.as() || n.as() || n.as()) { + return; + } + if (n->IsInstance()) { + VisitStmt(Downcast(n)); + } else if (n->IsInstance()) { + VisitExpr(Downcast(n)); + } + } + + void VisitStmt(const Stmt& n) override { + meta_->GetMetaNode(n); + StmtVisitor::VisitStmt(n); + } + + void VisitExpr(const PrimExpr& n) override { + meta_->GetMetaNode(n); + ExprVisitor::VisitExpr(n); + } + + private: + TextMetaDataContext* meta_; +}; + +class TIRTextPrinter : public StmtFunctor, + public ExprFunctor, + public TypeFunctor { + public: + explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta) + : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {} + + /*! \brief Print the node */ + Doc Print(const ObjectRef& node); + + private: + /*! \brief whether show meta data */ + bool show_meta_; + /*! \brief meta data context */ + TextMetaDataContext* meta_; + /*! \brief meta collector */ + MetaCollector meta_collector_; + /*! \brief Map from Var to Doc */ + std::unordered_map memo_var_; + /*! \brief Map from Buffer to Doc */ + std::unordered_map memo_buf_; + /*! \brief name allocation map */ + std::unordered_map name_alloc_map_; + + friend class tvm::TextPrinter; + + Doc VisitExpr_(const IntImmNode* op) override; + Doc VisitExpr_(const FloatImmNode* op) override; + Doc VisitExpr_(const StringImmNode* op) override; + Doc VisitExpr_(const CastNode* op) override; + Doc VisitExpr_(const VarNode* op) override; + Doc VisitExpr_(const AddNode* op) override; + Doc VisitExpr_(const SubNode* op) override; + Doc VisitExpr_(const MulNode* op) override; + Doc VisitExpr_(const DivNode* op) override; + Doc VisitExpr_(const ModNode* op) override; + Doc VisitExpr_(const FloorDivNode* op) override; + Doc VisitExpr_(const FloorModNode* op) override; + Doc VisitExpr_(const MinNode* op) override; + Doc VisitExpr_(const MaxNode* op) override; + Doc VisitExpr_(const EQNode* op) override; + Doc VisitExpr_(const NENode* op) override; + Doc VisitExpr_(const LTNode* op) override; + Doc VisitExpr_(const LENode* op) override; + Doc VisitExpr_(const GTNode* op) override; + Doc VisitExpr_(const GENode* op) override; + Doc VisitExpr_(const AndNode* op) override; + Doc VisitExpr_(const OrNode* op) override; + Doc VisitExpr_(const NotNode* op) override; + Doc VisitExpr_(const SelectNode* op) override; + Doc VisitExpr_(const BufferLoadNode* op) override; + Doc VisitExpr_(const LoadNode* op) override; + Doc VisitExpr_(const RampNode* op) override; + Doc VisitExpr_(const BroadcastNode* op) override; + Doc VisitExpr_(const LetNode* op) override; + Doc VisitExpr_(const CallNode* op) override; + Doc VisitExpr_(const ShuffleNode* op) override; + Doc VisitExpr_(const ReduceNode* op) override; + Doc VisitExprDefault_(const Object* op) override; + + Doc VisitStmt_(const LetStmtNode* op) override; + Doc VisitStmt_(const AttrStmtNode* op) override; + Doc VisitStmt_(const AssertStmtNode* op) override; + Doc VisitStmt_(const StoreNode* op) override; + Doc VisitStmt_(const BufferStoreNode* op) override; + Doc VisitStmt_(const BufferRealizeNode* op) override; + Doc VisitStmt_(const AllocateNode* op) override; + Doc VisitStmt_(const FreeNode* op) override; + Doc VisitStmt_(const IfThenElseNode* op) override; + Doc VisitStmt_(const SeqStmtNode* op) override; + Doc VisitStmt_(const EvaluateNode* op) override; + Doc VisitStmt_(const ForNode* op) override; + Doc VisitStmt_(const PrefetchNode* op) override; + Doc VisitStmtDefault_(const Object* op) override; + + Doc VisitType_(const PrimTypeNode* node) override; + Doc VisitType_(const PointerTypeNode* node) override; + Doc VisitType_(const TupleTypeNode* node) override; + + Doc PrintIRModule(const IRModule& module); + Doc PrintPrimFunc(const PrimFunc& primFunc); + Doc PrintArray(const ArrayNode* op); + Doc PrintIterVar(const IterVarNode* op); + Doc PrintRange(const RangeNode* op); + Doc PrintBuffer(const BufferNode* op); + Doc PrintString(const StringObj* op) { + return Doc::StrLiteral(op->data); + } + + /*! + * \brief special method to print out data type + * \param dtype The data type + */ + static Doc PrintDType(DataType dtype); + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param data The pointer to hold the data. + */ + template + static Doc PrintConstScalar(DataType dtype, const T& data); + Doc GetUniqueName(std::string prefix); + Doc AllocVar(const Var& var); + Doc AllocBuf(const Buffer& buffer); + /*! + * \brief special method to render vectors of docs with a separator + * \param vec vector of docs + * \param sep separator + */ + static Doc PrintSep(const std::vector& vec, const Doc& sep); + Doc PrintBody(const Stmt& body, bool indent = true); +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { + +class TextPrinter { + public: + explicit TextPrinter(bool show_meta_data, + const runtime::TypedPackedFunc& annotate) + : show_meta_data_(show_meta_data), annotate_(annotate), + relay_text_printer_(show_meta_data, &meta_, annotate), + tir_text_printer_(show_meta_data, &meta_) {} + + /*! \brief whether show meta data */ + bool show_meta_data_; + /*! \brief meta data context */ + TextMetaDataContext meta_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + /*! \brief Relay Text Printer */ + relay::RelayTextPrinter relay_text_printer_; + /*! \brief TIR Text Printer */ + tir::TIRTextPrinter tir_text_printer_; + + Doc PrintFinal(const ObjectRef& node) { + Doc doc; + if (node->IsInstance()) { + doc << PrintMod(Downcast(node)); + } else if (node->IsInstance() || node->IsInstance() + || node->IsInstance()) { + doc << tir_text_printer_.Print(node); + } else { + doc << relay_text_printer_.PrintFinal(node); + } + if (!meta_.empty()) { + doc << Doc::NewLine(); + if (show_meta_data_) { + // append meta data in the end. + doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); + } else { + doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; + } + } + return doc; + } + + Doc PrintMod(const IRModule& mod); +}; +} // namespace tvm + +#endif // TVM_PRINTER_TEXT_PRINTER_H_ diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc new file mode 100644 index 000000000000..a5754d7004b9 --- /dev/null +++ b/src/printer/tir_text_printer.cc @@ -0,0 +1,597 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir_text_printer.cc + * \brief Printer to print out the IR text format + * that can be parsed by a parser. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "doc.h" +#include "meta_data.h" +#include "text_printer.h" + +namespace tvm { +namespace tir { + +Doc TIRTextPrinter::Print(const ObjectRef& node) { + if (!node.defined()) return Doc::Text("(nullptr)"); + if (node->IsInstance()) { + return VisitStmt(Downcast(node)); + } else if (node->IsInstance()) { + return Doc::Text("?"); + } else if (node->IsInstance()) { + return VisitExpr(Downcast(node)); + } else if (node->IsInstance()) { + return VisitType(Downcast(node)); + } else if (node->IsInstance()) { + return PrintPrimFunc(Downcast(node)); + } else if (node->IsInstance()) { + return PrintIRModule(Downcast(node)); + } else if (node->IsInstance()) { + return PrintArray(node.as()); + } else if (node->IsInstance()) { + return PrintIterVar(node.as()); + } else if (node->IsInstance()) { + return PrintRange(node.as()); + } else if (node->IsInstance()) { + return PrintBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintString(node.as()); + } else { + return this->meta_->GetMetaNode(node); + } +} + +Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) { + const auto* op = primFunc.operator->(); + const auto& signature = op->func_type_annotation(); + // collect Meta in DictAttr + for (const auto& it : primFunc->attrs->dict) { + meta_collector_.Collect(it.second); + } + // collect buffers in buffer_map + memo_var_.clear(); + memo_buf_.clear(); + for (const auto& it : op->buffer_map) { + memo_buf_[it.second] = AllocBuf(it.second); + } + // print PrimFunc + Doc doc; + doc << "primfn" << "("; + // print params and its type annotation + std::vector params; + for (const auto& param : op->params) { + params.push_back(Print(param)); + } + Doc sep; + doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")"; + // print return type + doc << " -> " << Print(signature->ret_type); + // print attr + Doc attr_doc; + std::vector attr_docs; + for (const auto& it : op->attrs->dict) { + attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); + } + attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}"; + doc << Doc::Indent(2, attr_doc); + // print all the buffers in the tree + Doc buffer_doc; + std::vector buffer_docs; + for (const auto& it : memo_buf_) { + const auto& buf = it.first; + buffer_docs.push_back(Print(buf) + << Doc::Text(": Buffer(") << Print(buf->data) << ", " + << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " + << Print(buf->strides)); + if (!is_zero(buf->elem_offset)) { + buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset); + } + if (buf->scope != "global") { + buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope); + } + if (buf->data_alignment != 128) { + buffer_docs.back() << ", align=" << buf->data_alignment; + } + if (buf->offset_factor != 1) { + buffer_docs.back() << ", offset_factor=" << buf->offset_factor; + } + if (buf->buffer_type != 1) { + buffer_docs.back() << ", type=" << Doc::StrLiteral("auto"); + } + buffer_docs.back() << ")"; + } + buffer_doc << Doc::NewLine() << "buffers = {"; + buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine())); + doc << Doc::Indent(2, buffer_doc) << "}"; + // print buffer_map + std::vector buffer_map_doc; + for (const auto& it : op->buffer_map) { + buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second)); + } + doc << Doc::Indent(2, Doc::NewLine() + << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); + doc << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::PrintIRModule(const IRModule& module) { + const auto* op = module.operator->(); + Doc doc; + + Doc body; + body << Doc::NewLine(); + std::vector functions; + for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { + if ((*it).second.as()) { + functions.push_back(Print((*it).second)); + } + } + body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine()); + doc << Doc::Indent(0, body); + return doc; +} + +Doc TIRTextPrinter::PrintArray(const ArrayNode* op) { + Doc doc; + doc << '['; + for (size_t i = 0; i < op->data.size(); ++i) { + if (i != 0) { + doc << ", "; + } + doc << Print(op->data[i]); + } + doc << ']'; + return doc; +} + +Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) { + Doc doc; + doc << "IterVar(" << Print(op->var); + if (op->dom.defined()) { + doc << ", [" << Print(op->dom) << "], "; + } else { + doc << ", " << Print(op->dom) << ", "; + } + doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "; + doc << Doc::StrLiteral(op->thread_tag) << ")"; + return doc; +} + +Doc TIRTextPrinter::PrintRange(const RangeNode* op) { + return Print(op->min) << ":" << Print(op->min + op->extent); +} + +Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) { + const Buffer& buffer = GetRef(op); + CHECK_GT(memo_buf_.count(buffer), 0); + return meta_->InMeta(buffer) ? meta_->GetMetaNode(buffer) : memo_buf_[buffer]; +} + +Doc TIRTextPrinter::VisitExprDefault_(const Object* op) { + return this->meta_->GetMetaNode(GetRef(op)); +} + +Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) { + return this->meta_->GetMetaNode(GetRef(op)); +} + +Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) { + return PrintConstScalar(op->dtype, op->value); +} + +Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) { + return PrintConstScalar(op->dtype, op->value); +} + +Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); } + +Doc TIRTextPrinter::VisitExpr_(const CastNode* op) { + Doc doc; + doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const VarNode* op) { + const Var& var = GetRef(op); + return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef(op)); +} + +#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \ + Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \ + Doc doc; \ + doc << "(" << Print(op->a) << OpString; \ + doc << Print(op->b) << ")"; \ + return doc; \ + } + +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " && ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " || ") + +Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) { + Doc doc; + doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) { + Doc doc; + doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const MinNode* op) { + Doc doc; + doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) { + Doc doc; + doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const NotNode* op) { + Doc doc; + doc << "!" << Print(op->a); + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) { + Doc doc; + doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " + << Print(op->false_value); + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) { + Doc doc; + doc << Print(op->buffer) << Print(op->indices); + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) { + Doc doc; + doc << "(" << PrintDType(op->dtype) << "*)" + << Print(op->buffer_var) << "[" << Print(op->index) << "])"; + if (!is_one(op->predicate)) { + doc << " if " << Print(op->predicate); + } + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const RampNode* op) { + Doc doc; + doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) { + Doc doc; + doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const LetNode* op) { + Doc doc; + doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body); + return doc; +} + +inline const char* CallType2String(CallNode::CallType t) { + switch (t) { + case CallNode::Extern:return "extern"; + case CallNode::ExternCPlusPlus:return "extern_cpp"; + case CallNode::PureExtern:return "pure_extern"; + case CallNode::Halide:return "halide"; + case CallNode::Intrinsic:return "intrin"; + case CallNode::PureIntrinsic:return "pure_intrin"; + } + LOG(FATAL) << "Unknown CallType"; + return "Unknown"; +} + +Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { + Doc doc; + doc << "@" << Doc::Text(op->name) << "("; + std::vector args; + for (const auto& arg : op->args) { + args.push_back(Print(arg)); + } + doc << PrintSep(args, Doc::Text(", ")) + << ", dtype=" << PrintDType(op->dtype) + << ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) + << ", index=" << op->value_index << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) { + Doc doc; + doc << "shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) { + Doc doc; + doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis) + << ", " << op->value_index << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) { + Doc doc; + doc << "let " << Print(op->var) << " = " << Print(op->value) << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) { + Doc doc; + meta_collector_.Collect(op->node); + doc << "attr [" << Print(op->node) << "] " << Doc::StrLiteral(op->attr_key) << " = " + << Print(op->value); + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) { + Doc doc; + doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" + << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const StoreNode* op) { + Doc doc; + doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value); + if (!is_one(op->predicate)) { + doc << " if " << Print(op->predicate); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) { + Doc doc; + doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { + Doc doc; + doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", " + << Print(op->condition) << PrintBody(op->body) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { + Doc doc; + doc << "allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " + << Print(op->extents) << ")"; + if (!is_one(op->condition)) { + doc << " if " << Print(op->condition); + } + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const FreeNode* op) { + Doc doc; + doc << "free(" << Print(op->buffer_var) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { + Doc doc; + doc << "if " << Print(op->condition) << PrintBody(op->then_case); + if (!is_one(op->condition) && op->else_case.defined()) { + doc << " else" << PrintBody(op->else_case); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) { + std::vector stmts; + Doc seq_doc, doc; + for (Stmt stmt : op->seq) { + seq_doc << Doc::NewLine() << Print(stmt); + } + doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { + Doc doc; + doc << Print(op->value); + return doc; +} + +inline const char* ForType2String(ForType t) { + switch (t) { + case ForType::Serial:return "serial"; + case ForType::Parallel:return "parallel"; + case ForType::Vectorized:return "vectorized"; + case ForType::Unrolled:return "unroll"; + } + LOG(FATAL) << "Unknown ForType"; + return "Unknown"; +} + +Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { + Doc doc; + doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " + << Print(op->min + op->extent) << ")"; + if (op->for_type != ForType::Serial) { + doc << " " << Doc::StrLiteral(ForType2String(op->for_type)); + } + doc << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) { + Doc doc; + doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) { + Doc doc; + doc << PrintDType(node->dtype); + return doc; +} + +Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) { + Doc doc; + doc << "Pointer(" << Print(node->element_type) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) { + std::vector fields; + for (Type field : node->fields) { + fields.push_back(Print(field)); + } + Doc doc; + doc << "(" << Doc::Concat(fields); + // conform to python tuple format (1,) + if (node->fields.size() == 1) { + doc << ","; + } + return doc << ")"; +} + +Doc TIRTextPrinter::PrintDType(DataType dtype) { + return Doc::Text(runtime::DLDataType2String(dtype)); +} + +template +Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) { + Doc doc; + std::ostringstream os; + os << data; + if (dtype == DataType::Int(32)) { + doc << Doc::Text(os.str()); + } else { + if (dtype.bits() == 1 && dtype.lanes() == 1 && dtype.code() == kDLUInt) { + doc << ((data == 1) ? "True" : "False"); + return doc; + } + doc << Doc::Text(os.str()); + switch (dtype.code()) { + case kDLInt: doc << "i"; break; + case kDLUInt: doc << "u"; break; + case kDLFloat: doc << "f"; break; + } + doc << Doc::Text(std::to_string(dtype.bits())); + if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes())); + } + return doc; +} + +Doc TIRTextPrinter::GetUniqueName(std::string prefix) { + // std::replace(prefix.begin(), prefix.end(), '.', '_'); + std::string unique_prefix = prefix; + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (name_alloc_map_.count( + unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {} + } + name_alloc_map_[unique_prefix] = 0; + return Doc::Text(unique_prefix); +} + +Doc TIRTextPrinter::AllocVar(const Var& var) { + const auto& it = memo_var_.find(var); + if (it != memo_var_.end()) { + return it->second; + } + std::string name = var->name_hint; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "v" + name; + } + Doc val = GetUniqueName(name); + memo_var_[var] = val; + return val << ": " << Print(GetType(var)); +} + +Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) { + const auto& it = memo_buf_.find(buffer); + if (it != memo_buf_.end()) { + return it->second; + } + std::string name = buffer->name; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "buf_" + name; + } + Doc val = GetUniqueName(name); + memo_buf_[buffer] = val; + return val; +} + +Doc TIRTextPrinter::PrintSep(const std::vector& vec, const Doc& sep) { + Doc seq; + if (vec.size() != 0) { + seq = vec[0]; + for (size_t i = 1; i < vec.size(); i++) { + seq << sep << vec[i]; + } + } + return seq; +} + +Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { + Doc doc; + if (body->IsInstance()) return Print(body); + doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}"; + return doc; +} + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 6efb67b19bad..372f0e9ce727 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -64,14 +64,14 @@ def test_deduce(): e2 = (tvm.te.max(5, a * 4) < 0) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max_value) == "neg_inf" - assert str(res2.min_value) == "pos_inf" + assert str(res2.max_value) == "neg_inf: handle" + assert str(res2.min_value) == "pos_inf: handle" # expression containing variable a is on rhs e2 = (zero < tvm.te.max(5, a * 4)) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max_value) == "neg_inf" - assert str(res2.min_value) == "pos_inf" + assert str(res2.max_value) == "neg_inf: handle" + assert str(res2.min_value) == "pos_inf: handle" e3 = (-b)+a*c-d res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) @@ -88,8 +88,8 @@ def test_deduce(): # Unsatisfiable `EQ`, variable as one of the Operand res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s}) - assert str(res5.max_value) == "neg_inf" - assert str(res5.min_value) == "pos_inf" + assert str(res5.max_value) == "neg_inf: handle" + assert str(res5.min_value) == "pos_inf: handle" # variable `a` on the RHS side res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {}) @@ -111,13 +111,13 @@ def test_deduce(): # Unsatisfiable Mul in `EQ` e5 = (4 * a == b) res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {}) - assert str(res9.max_value) == "neg_inf" - assert str(res9.min_value) == "pos_inf" + assert str(res9.max_value) == "neg_inf: handle" + assert str(res9.min_value) == "pos_inf: handle" # Unsatisfiable Mul in `EQ` res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0) - assert str(res10.max_value) == "neg_inf" - assert str(res10.min_value) == "pos_inf" + assert str(res10.max_value) == "neg_inf: handle" + assert str(res10.min_value) == "pos_inf: handle" def test_check(): diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 9e4d45e9efaa..9b8d4061afb4 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -286,8 +286,8 @@ def intrin_func(ins, outs, sp): stmt = tvm.lower(s, [A, C])["main"].body assert isinstance(stmt.body.body, tvm.tir.Evaluate) assert len(stmt.body.body.value.args) == 5 - assert str(stmt.body.body.value.args[3]) == "(i*i)" - assert str(stmt.body.body.value.args[4]) == "(i + j)" + assert str(stmt.body.body.value.args[3]) == "(i: int32*i)" + assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)" if __name__ == "__main__": test_singleton() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 468ab1dbad6a..36c9c764f6ab 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -103,7 +103,7 @@ def test_basic(): a = te.var('a') b = te.var('b') c = a + b - assert str(c) == '(%s + %s)' % (a.name, b.name) + assert str(c) == '(%s: int32 + %s: int32)' % (a.name, b.name) def test_stmt(): @@ -138,11 +138,11 @@ def test_any(): assert False except ValueError: pass - assert str(tvm.tir.any(x < y)) == '(%s < %s)' % (x.name, y.name) - assert str(tvm.tir.any(x < y, x > z)) == '((%s < %s) || (%s > %s))' % ( + assert str(tvm.tir.any(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name) + assert str(tvm.tir.any(x < y, x > z)) == '((%s: int32 < %s: int32) || (%s > %s: int32))' % ( x.name, y.name, x.name, z.name) assert str(tvm.tir.any(x < y, y > z + 1, x < z * 2)) == \ - '(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % ( + '(((%s: int32 < %s: int32) || (%s > (%s: int32 + 1))) || (%s < (%s*2)))' % ( x.name, y.name, y.name, z.name, x.name, z.name) @@ -160,29 +160,29 @@ def test_all(): assert False except ValueError: pass - assert str(tvm.tir.all(x < y)) == '(%s < %s)' % (x.name, y.name) - assert str(tvm.tir.all(x < y, x > z)) == '((%s < %s) && (%s > %s))' % ( + assert str(tvm.tir.all(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name) + assert str(tvm.tir.all(x < y, x > z)) == '((%s: int32 < %s: int32) && (%s > %s: int32))' % ( x.name, y.name, x.name, z.name) assert str(tvm.tir.all(x < y, y > z + 1, x < z * 2)) == \ - '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % ( + '(((%s: int32 < %s: int32) && (%s > (%s: int32 + 1))) && (%s < (%s*2)))' % ( x.name, y.name, y.name, z.name, x.name, z.name) def test_bitwise(): x = te.var('x') y = te.var('y') - assert str(x << y) == 'shift_left(x, y)' - assert str(x >> y) == 'shift_right(x, y)' - assert str(x & y) == 'bitwise_and(x, y)' - assert str(x | y) == 'bitwise_or(x, y)' - assert str(x ^ y) == 'bitwise_xor(x, y)' - assert str(10 & x) == 'bitwise_and(10, x)' - assert str(10 | x) == 'bitwise_or(10, x)' - assert str(10 ^ x) == 'bitwise_xor(10, x)' - assert str(10 >> x) == 'shift_right(10, x)' - assert str(10 << x) == 'shift_left(10, x)' - assert str(10 % x) == 'floormod(10, x)' - assert str(~x) == 'bitwise_not(x)' + assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin", index=0)' + assert str(10 % x) == 'floormod(10, x: int32)' + assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin", index=0)' assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -239,12 +239,12 @@ def test_divide_by_zero(): def test_isnan(): x = te.var('x', 'float32') - assert str(tvm.tir.isnan(x)) == 'isnan(x)' + assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin", index=0)' assert str(tvm.tir.isnan(x).dtype) == 'bool' y = te.var('y', 'float16') - assert str(tvm.tir.isnan(y)) == 'isnan(float32(y))' + assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin", index=0)' z = te.var('z', 'int32') - assert str(tvm.tir.isnan(z)) == '(bool)0' + assert str(tvm.tir.isnan(z)) == 'False' k = te.var('k', 'int8x2') assert str(tvm.tir.isnan(k).dtype) == 'uint1x2'