diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 06a1aa1ac9ef..4513022687f8 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -551,10 +551,9 @@ inline const TTypeNode* ExprNode::type_as() const { * additional comment block to an expr. * \return The text representation. */ -std::string RelayPrint( - const NodeRef& node, - bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); +std::string RelayPrint(const NodeRef& node, + bool show_meta_data = true, + runtime::TypedPackedFunc annotate = nullptr); } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 9fdffab4e62e..de7e2ae24959 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -512,6 +512,9 @@ def make_parser(data): def fromtext(data, source_name=None): # type: (str, str) -> Union[expr.Expr, module.Module] """Parse a Relay program.""" + if data == "": + raise ParseError("Cannot parse the empty string.") + global __source_name_counter__ if source_name is None: diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index e0491d62f552..548c0e35a342 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -53,7 +53,7 @@ def astext(self, show_meta_data=True, annotate=None): Note ---- - The metadata section is necessary to fully parse the text format. + The meta data section is necessary to fully parse the text format. However, it can contain dumps that are big (e.g constant weights), so it can be helpful to skip printing the meta data section. diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 12b6ec8ca8e2..f3f8fea97412 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -905,3 +905,35 @@ def eliminate_common_subexpr(expr, fskip=None): The output expression. """ return _ir_pass.eliminate_common_subexpr(expr, fskip) + + +def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True): + """ + THIS SHOULD BE USED ONLY FOR DEBUGGING, NOT AS AN INTERCHANGE FORMAT! + USE `.astext()` INSTEAD! + + A version of the pretty printer intended for debugging passes. Contains + advanced printing options. + + Parameters + ---------- + ast : Union[relay.Expr, relay.Module, relay.Type] + The relay fragment to be turned into text. + + show_meta_data : bool + Whether to include meta data section in the text + if there is meta data. + + annotate: Optional[relay.Expr->str] + Optional annotate function to provide additional + information in the comment block. + + gnf : bool + Whether to print in GNF. If it is disabled, pointers are left implicit. + + Returns + ------- + text : str + A text representation of `ast`. + """ + return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf) diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc new file mode 100644 index 000000000000..ef98bd8ed1ed --- /dev/null +++ b/src/relay/ir/doc.cc @@ -0,0 +1,98 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file src/tvm/relay/doc.cc + * \brief Doc ADT used for pretty printing. + * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. + */ +#include +#include +#include "doc.h" + +namespace tvm { +namespace relay { + +// Text constructor +DocAtom Text(const std::string& str) { + return std::make_shared(str); +} + +// Line constructor +DocAtom Line(int indent = 0) { + return std::make_shared(indent); +} + +Doc::Doc(const std::string& str) { + if (str == "\n") { + this->stream_ = {Line()}; + } else { + this->stream_ = {Text(str)}; + } +} + +// DSL function implementations + +Doc& Doc::operator<<(const Doc& right) { + assert(this != &right); + this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end()); + return *this; +} + +Doc& Doc::operator<<(const std::string& right) { + return *this << Doc(right); +} + +Doc Indent(int indent, const Doc& doc) { + Doc ret; + for (auto atom : doc.stream_) { + if (auto text = std::dynamic_pointer_cast(atom)) { + ret.stream_.push_back(text); + } else if (auto line = std::dynamic_pointer_cast(atom)) { + ret.stream_.push_back(Line(indent + line->indent)); + } else {assert(false);} + } + return ret; +} + +std::string Doc::str() { + std::ostringstream os; + for (auto atom : this->stream_) { + if (auto text = std::dynamic_pointer_cast(atom)) { + os << text->str; + } else if (auto line = std::dynamic_pointer_cast(atom)) { + os << "\n" << std::string(line->indent, ' '); + } else {assert(false);} + } + return os.str(); +} + +Doc PrintVec(const std::vector& vec, const Doc& sep) { + Doc seq; + if (vec.size() != 0) { + seq = vec[0]; + for (size_t i = 1; i < vec.size(); i++) { + seq << sep << vec[i]; + } + } + return seq; +} + +Doc PrintBool(bool value) { + if (value) { + return Doc("True"); + } else { + return Doc("False"); + } +} + +Doc PrintDType(DataType dtype) { + return Doc(runtime::TVMType2String(Type2TVMType(dtype))); +} + +Doc PrintString(const std::string& value) { + // TODO(M.K.): add escape. + Doc doc; + return doc << "\"" << value << "\""; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h new file mode 100644 index 000000000000..b9a82555c479 --- /dev/null +++ b/src/relay/ir/doc.h @@ -0,0 +1,99 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/relay/doc.h + * \brief Doc ADT used for pretty printing. + * Based on Section 1 of + * https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf, but with + * a vector instead of an implicitly linked list. + */ +#ifndef TVM_RELAY_IR_DOC_H_ +#define TVM_RELAY_IR_DOC_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +// Doc Atom ADT +struct DocAtomNode { + virtual ~DocAtomNode() = default; +}; + +using DocAtom = std::shared_ptr; + +struct TextNode : DocAtomNode { + std::string str; + + explicit TextNode(const std::string& str) : str(str) {} +}; + +struct LineNode : DocAtomNode { + int indent; + + explicit LineNode(int indent) : indent(indent) {} +}; + +// Doc is a stream-like interface +class Doc { + public: + Doc() {} + explicit Doc(const std::string& str); + + // Append right to this. + Doc& operator<<(const Doc& right); + // Like above, but automatically lifts string to a Doc. + Doc& operator<<(const std::string& right); + // Like above, but converts right to a string first. + template + Doc& operator<<(const T& right) { + std::ostringstream os; + os << right; + return *this << os.str(); + } + + // Indent a doc stream. + friend Doc Indent(int indent, const Doc& doc); + + // Wadler's `layout` + std::string str(); + + private: + std::vector stream_; +}; + +// DSL functions + +// Render vectors of docs with a separator. e.g. PrintVec([1, 2, 3], f) -> 1f2f3 +Doc PrintVec(const std::vector& vec, const Doc& sep = Doc(", ")); +// Print a constant bool value. +Doc PrintBool(bool value); +// Print a data type. +Doc PrintDType(DataType dtype); +// Print a string. +Doc PrintString(const std::string& value); +/*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param data The pointer to hold the data. + */ +template +Doc PrintConstScalar(DataType dtype, const T* data) { + std::ostringstream os; + if (dtype == Int(32)) { + os << data[0]; + } else if (dtype == Float(32)) { + os << data[0] << 'f'; + } else if (dtype == Bool()) { + return PrintBool(data[0] != 0); + } else { + os << dtype << "(" << data[0] << ")"; + } + return Doc(os.str()); +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_IR_DOC_H_ diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc new file mode 100644 index 000000000000..a030a056f7cd --- /dev/null +++ b/src/relay/ir/pretty_printer.cc @@ -0,0 +1,716 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file pretty_printer.cc + * \brief Pretty printer for Relay programs + * Supports ANF, GNF, and metadata. + */ +#include +#include +#include +#include "doc.h" +#include "type_functor.h" +#include "../../lang/attr_functor.h" + +namespace tvm { +namespace relay { + +/*! + * \brief Meta data context for PrettyPrinter. + * + * This is an important part to enable bi-directional serializability. + * We use tvm's Node system to build the current IR. + * It can be hard to design a text format for all the possible nodes + * as the set of nodes can grow when we do more extensions. + * + * Instead of trying to design readable text format for every node, + * we support a meta data section in the text format. + * We allow the text format to refer to a node in the meta data section. + * + * The meta data section is a json serialized string of an Map>. + * Each element in the meta data section can be referenced by the text format. + * Each meta data node is printed in the following format. + * + * meta[type-key-of-node>][] + * + * Specifically, consider the following IR(constructed by python). + * + * \code + * + * n = tvm.var("n") + * x = tvm.relay.var("x", shape=(n, 1)) + * f = tvm.relay.Function([x], x) + * print(f.astext()) + * + * \endcode + * + * The corresponding text format is shown in the following code block. + * + * \code + * + * fn (%x: Tensor[(meta[Variable][0],), float32]) { + * %x + * } + * # Meta data section is a json-serialized string + * # of the following array. + * # [tvm.var("n")] + * + * \endcode + * + * Note that we store tvm.var("n") in the meta data section. + * Since it is stored in the index-0 in the meta data section, + * we print it as meta[Variable][0]. + * + * The text parser can recover this object by loading from the corresponding + * location in the meta data section. + * + * This is is a design trade-off. + * It allows us to embedded any meta data in the text format, + * while still being able to tweak the text part of the printed IR easily. + */ +class TextMetaDataContext { + public: + /*! + * \brief Get text representation of meta node. + * \param node The node to be converted to meta node. + * \return A string representation of the meta node. + */ + Doc GetMetaNode(const NodeRef& node) { + auto it = meta_repr_.find(node); + if (it != meta_repr_.end()) { + return it->second; + } + Array& mvector = + meta_data_[node->type_key()]; + int64_t index = static_cast(mvector.size()); + mvector.push_back(node); + Doc doc; + doc << "meta[" << node->type_key() << "][" << index << "]"; + meta_repr_[node] = doc; + return meta_repr_[node]; + } + /*! + * \brief Get the metadata section in json format. + * \return the meta data string. + */ + std::string GetMetaSection() const { + if (meta_data_.size() == 0) return std::string(); + return SaveJSON(Map( + meta_data_.begin(), meta_data_.end())); + } + + /*! \return whether the meta data context is empty. */ + bool empty() const { + return meta_data_.empty(); + } + + private: + /*! \brief additional metadata stored in TVM json format */ + std::unordered_map > meta_data_; + /*! \brief map from meta data into its string representation */ + std::unordered_map meta_repr_; +}; + +class PrettyPrinter : + public ExprFunctor, + public PatternFunctor, + public TypeFunctor, + public AttrFunctor { + public: + explicit PrettyPrinter(bool GNF, + bool show_meta_data, + runtime::TypedPackedFunc annotate) : + GNF_(GNF), + show_meta_data_(show_meta_data), + annotate_(annotate) {} + + /*! + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ + Doc PrintOptionalInfo(const Expr& expr) { + Doc doc; + // additional information in comment. + if (annotate_ != nullptr) { + return doc << " // " << annotate_(expr); + } else if (expr->checked_type_.defined()) { + doc << " // ty="; + return doc << Print(expr->checked_type()); + } else { + return doc; + } + } + + // indent a new body + // TODO(jmp): indent should be an instance variable of the printer + Doc PrintBody(const NodeRef& node, int indent = 2) { + Doc doc; + Doc body; + doc << "{"; + doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n"; + doc << "}"; + return doc; + } + + // create a new scope by creating a new printer object. This allows temp var + // numbers to be reused and prevents hoisted vars from escaping too far + Doc PrintScope(const NodeRef& node) { + // print in a new scope + doc_stack_.push_back(Doc()); + // must print first so doc_stack_.back() reference doesn't become stale + Doc doc = Print(node); + doc = doc_stack_.back() << doc; + doc_stack_.pop_back(); + return doc; + } + + Doc PrintFinal(const NodeRef& node) { + Doc doc; + doc << PrintScope(node); + if (!meta_.empty()) { + if (show_meta_data_) { + std::string meta_json = meta_.GetMetaSection(); + // append meta data in the end. + doc << "\n" << "/* meta data */" << "\n" << meta_json; + } else { + doc << "\n" + << "// meta data omitted. you can use show_meta_data=True to include meta data"; + } + } + return doc; + } + + Doc PrintAttrs(const Attrs& attrs, const Expr& op); + + Doc Print(const NodeRef& node, bool meta = false) { + if (node.as_derived()) { + return PrintExpr(Downcast(node), meta); + } else if (node.as_derived()) { + return PrintType(Downcast(node), meta); + } else if (node.as_derived()) { + return PrintMod(Downcast(node)); + } else { + Doc doc; + return doc << node; + } + } + + Doc TempVar(int n) { + Doc doc; + return doc << "%" << n; + } + + Doc AllocTemp() { + return TempVar(temp_var_counter_++); + } + + /*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ + Doc GetUniqueName(const std::string& prefix) { + std::string unique_prefix = prefix; + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (true) { + std::ostringstream os; + os << prefix << (++it->second); + std::string name = os.str(); + if (name_alloc_map_.count(name) == 0) { + unique_prefix = name; + break; + } + } + } + name_alloc_map_[unique_prefix] = 0; + return Doc(unique_prefix); + } + + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ + Doc AllocVar(const Var& var) { + std::string name = var->name_hint(); + // always make sure first name is alpha + if (name.length() != 0 && !std::isalpha(name[0])) { + name = "v" + name; + } + Doc val = GetUniqueName("%" + name); + // still print if ir is malformed, but show the error. + if (memo_.count(var)) { + val << "-malformed-ir"; + } + memo_[var] = val; + if (var->type_annotation.defined()) { + val << ": " << Print(var->type_annotation); + } + return val; + } + + inline bool IsAtomicExpr(const Expr& expr) { + return expr.as() || expr.as() || + expr.as() || expr.as(); + } + + //------------------------------------ + // Overload of Expr printing functions + //------------------------------------ + Doc PrintExpr(const Expr& expr, bool meta) { + // Exploit memoization to print GNF. + // The first time we visit an expression, we need to allocate a temp var + // for it. Every subsequent time we can just use its assigned variable. + // This works since hashing uses pointer equality. + auto it = memo_.find(expr); + if (it != memo_.end()) return it->second; + Doc printed_expr; + if (meta) { + printed_expr = meta_.GetMetaNode(GetRef(expr.get())); + } else if (GNF_ && expr.as()) { + // wrap GNFed let in brackets + Doc body; + printed_expr << "{"; + printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n"; + printed_expr << "}"; + } else { + printed_expr = VisitExpr(expr); + } + // we choose to inline atomic exprs + if (GNF_ && !IsAtomicExpr(expr)) { + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc_stack_.back() << temp_var << " = " << printed_expr; + if (expr.as()) { + doc_stack_.back() << PrintOptionalInfo(expr); + } + doc_stack_.back() << "\n"; + return temp_var; + } else if (expr.as()) { + // This is our first time visiting the var and we hit the VarNode case + // in the visitor. Thus the variable is free. + doc_stack_.back() << "free_var " << printed_expr << "\n"; + // Memoization is done in AllocVar. + return memo_[expr]; + } else { + memo_[expr] = printed_expr; + if (GNF_ && expr.as()) { + printed_expr << PrintOptionalInfo(expr); + } + return printed_expr; + } + } + + // Should only be triggered when op is a free variable being visited for the + // first time. + Doc VisitExpr_(const VarNode* op) final { + return AllocVar(GetRef(op)); + } + + Doc VisitExpr_(const ConstantNode* op) final { + // Print out simple scalars directly. + if (op->is_scalar()) { + std::ostringstream os; + DataType dtype = TVMType2Type(op->data->dtype); + CHECK_EQ(op->data->ctx.device_type, kDLCPU); + if (dtype == Int(32)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Int(64)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Float(32)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Float(64)) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } else if (dtype == Bool()) { + return PrintConstScalar(dtype, static_cast(op->data->data)); + } + } + // default fall-back, record it as meta node. + Doc doc; + return doc << Print(GetRef(op), true) + << PrintOptionalInfo(GetRef(op)); + } + + Doc VisitExpr_(const TupleNode* op) final { + std::vector fields; + for (Expr field : op->fields) { + fields.push_back(Print(field)); + } + Doc doc; + doc << "(" << PrintVec(fields); + // conform to python tuple format (1,) + if (op->fields.size() == 1) { + doc << ","; + } + return doc << ")"; + } + + Doc VisitExpr_(const TupleGetItemNode* op) final { + Doc doc; + return doc << Print(op->tuple) << "." << op->index; + } + + Doc VisitExpr_(const IfNode* op) final { + Doc doc; + doc << "if (" << Print(op->cond) << ") "; + doc << PrintBody(op->true_branch); + doc << " else "; + doc << PrintBody(op->false_branch); + return doc; + } + + Doc VisitExpr_(const LetNode* op) final { + Doc doc; + doc << "let " << AllocVar(op->var) << " = " << Print(op->value) << "\n"; + // we use a scope here so GNF hoisting doesn't escape too far + doc << PrintScope(op->body); + return doc; + } + + Doc PrintFunc(const Doc& prefix, const Function& fn) { + // TODO(tqchen, M.K.) support generic function + // Possibly through meta data + CHECK_EQ(fn->type_params.size(), 0U) + << "generic fn not yet supported"; + Doc doc; + doc << prefix << "("; + std::vector params; + for (Var param : fn->params) { + params.push_back(AllocVar(param)); + } + doc << PrintVec(params) << PrintAttrs(fn->attrs, fn); + doc << ") "; + if (fn->ret_type.defined()) { + doc << "-> " << Print(fn->ret_type) << " "; + } + doc << PrintBody(fn->body); + return doc; + } + + Doc PrintMod(const Module& mod) { + Doc doc; + int counter = 0; + for (const auto& kv : mod->functions) { + std::ostringstream os; + if (counter++ != 0) { + doc << "\n"; + } + os << "def @" << kv.first->name_hint; + doc << PrintFunc(Doc(os.str()), kv.second); + doc << "\n"; + } + return doc; + } + + Doc VisitExpr_(const FunctionNode* op) final { + return PrintFunc(Doc("fn "), GetRef(op)); + } + + Doc VisitExpr_(const GlobalVarNode* op) final { + return Doc('@' + op->name_hint); + } + + Doc VisitExpr_(const OpNode* op) final { + return Doc(op->name); + } + + Doc VisitExpr_(const CallNode* op) final { + Doc doc; + doc << Print(op->op); + std::vector args; + for (Expr arg : op->args) { + args.push_back(Print(arg)); + } + return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs, op->op) << ")"; + } + + Doc VisitExpr_(const RefCreateNode* op) final { + Doc doc; + return doc << "ref(" << Print(op->value) << ")"; + } + + Doc VisitExpr_(const RefReadNode* op) final { + Doc doc; + return doc << Print(op->ref) << "^"; + } + + Doc VisitExpr_(const RefWriteNode* op) final { + Doc doc; + return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; + } + + Doc VisitExpr_(const MatchNode* op) final { + // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. + Doc doc; + Doc body; + doc << "match " << Print(op->data) << " "; + doc << "{"; + std::vector clauses; + for (const auto& clause : op->clauses) { + Doc clause_doc; + clauses.push_back(clause_doc << Print(clause->lhs) << " -> " + << Print(clause->rhs)); + } + doc << Indent(2, body << "\n" << PrintVec(clauses, Doc("\n"))) << "\n"; + doc << "}"; + return doc; + } + + Doc VisitPattern_(const PatternConstructorNode* p) final { + Doc doc; + doc << p->constructor->name_hint << "("; + std::vector pats; + for (const auto& pat : p->patterns) { + pats.push_back(Print(pat)); + } + return doc << PrintVec(pats) << ")"; + } + + Doc VisitPattern_(const PatternVarNode* pv) final { + return AllocVar(pv->var); + } + + Doc VisitExpr_(const ConstructorNode* n) final { + return Doc(n->name_hint); + } + + //------------------------------------ + // Overload of Type printing functions + //------------------------------------ + Doc PrintType(const Type& type, bool meta) { + auto it = memo_type_.find(type); + if (it != memo_type_.end()) return it->second; + Doc printed_type; + if (meta) { + printed_type = meta_.GetMetaNode(GetRef(type.get())); + } else { + printed_type = VisitType(type); + } + memo_type_[type] = printed_type; + return printed_type; + } + + Doc VisitTypeDefault_(const Node* node) final { + // by default always print as meta data + return Print(GetRef(node), true); + } + + Doc VisitType_(const TensorTypeNode* node) final { + // scalar type + if (node->shape.size() == 0) { + return PrintDType(node->dtype); + } + Doc doc; + doc << "Tensor[("; + std::vector shapes; + for (NodeRef shape : node->shape) { + shapes.push_back(PrintAttr(shape)); + } + doc << PrintVec(shapes); + // conform to python tuple format (1,) + if (node->shape.size() == 1) { + doc << ","; + } + return doc << "), " << PrintDType(node->dtype) << "]"; + } + + Doc VisitType_(const TupleTypeNode* node) final { + std::vector fields; + for (Type field : node->fields) { + fields.push_back(Print(field)); + } + Doc doc; + doc << "(" << PrintVec(fields); + // conform to python tuple format (1,) + if (node->fields.size() == 1) { + doc << ","; + } + return doc << ")"; + } + + Doc VisitType_(const FuncTypeNode* node) final { + Doc doc; + std::vector arg_types; + for (Type arg_type : node->arg_types) { + arg_types.push_back(Print(arg_type)); + } + return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); + } + + Doc VisitType_(const RefTypeNode* node) final { + Doc doc; + return doc << "ref(" << Print(node->value) << ")"; + } + + //------------------------------------ + // Overload of Attr printing functions + //------------------------------------ + + Doc PrintAttr(const NodeRef& value, bool meta = false) { + if (value.defined()) { + Doc printed_attr; + if (meta) { + printed_attr = meta_.GetMetaNode(value); + } else { + printed_attr = VisitAttr(value); + } + return printed_attr; + } else { + return Doc("None"); + } + } + + Doc VisitAttrDefault_(const Node* op) final { + return PrintAttr(GetRef(op), true); + } + + Doc VisitAttr_(const ArrayNode* op) final { + Doc doc; + doc << "["; + std::vector arr_vals; + for (NodePtr val : op->data) { + arr_vals.push_back(PrintAttr(NodeRef(val))); + } + doc << PrintVec(arr_vals); + doc << "]"; + return doc; + } + + Doc VisitAttr_(const ir::IntImm* op) final { + return PrintConstScalar(op->type, &(op->value)); + } + + Doc VisitAttr_(const ir::UIntImm* op) final { + return PrintConstScalar(op->type, &(op->value)); + } + + Doc VisitAttr_(const ir::FloatImm* op) final { + return PrintConstScalar(op->type, &(op->value)); + } + + Doc VisitAttr_(const ir::StringImm* op) final { + return PrintString(op->value); + } + + private: + /*! \brief Whether to use GNF. */ + bool GNF_; + /*! \brief Whether to print meta data. */ + bool show_meta_data_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + /*! \brief Stack of docs to implement scoped GNFing. */ + std::vector doc_stack_{}; + /*! \brief Map from Expr to Doc */ + std::unordered_map memo_; + /*! \brief Map from Type to Doc */ + std::unordered_map memo_type_; + /*! \brief name allocation map */ + std::unordered_map name_alloc_map_; + /*! \brief meta data context */ + TextMetaDataContext meta_; + /*! \brief counter of temporary variable */ + size_t temp_var_counter_{0}; + class AttrPrinter; + friend class AttrPrinter; +}; + +/*! + * \brief Attribute printer which prints the attributes in the call. + */ +class PrettyPrinter::AttrPrinter : public AttrVisitor { + public: + AttrPrinter(Doc& doc, PrettyPrinter* parent) : doc_(doc), parent_(parent) {} + + template + Doc PrintKV(const char* key, const T& value) { + Doc doc; + return doc << ", " << key << "=" << value; + } + + void Visit(const char* key, double* value) final { + doc_ << PrintKV(key, value[0]); + } + void Visit(const char* key, int64_t* value) final { + doc_ << PrintKV(key, value[0]); + } + void Visit(const char* key, uint64_t* value) final { + doc_ << PrintKV(key, value[0]); + } + void Visit(const char* key, int* value) final { + doc_ << PrintKV(key, value[0]); + } + void Visit(const char* key, bool* value) final { + doc_ << PrintKV(key, PrintBool(value[0])); + } + void Visit(const char* key, std::string* value) final { + doc_ << PrintKV(key, PrintString(value[0])); + } + void Visit(const char* key, void** value) final { + LOG(FATAL) << "do not allow void as argument"; + } + void Visit(const char* key, DataType* value) final { + doc_ << PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(value[0])))); + } + void Visit(const char* key, NodeRef* value) final { + doc_ << PrintKV(key, parent_->PrintAttr(value[0])); + } + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "do not allow NDarray as argument"; + } + + private: + Doc& doc_; + PrettyPrinter* parent_; +}; + +Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { + Doc doc; + if (!attrs.defined()) return doc; + const auto* op_node = op.as(); + if (op_node && (attrs->type_index() != op_node->attrs_type_index)) { + // fallback + return doc << ", " << meta_.GetMetaNode(attrs); + } else { + AttrPrinter printer(doc, this); + const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); + return doc; + } +} + +std::string PrettyPrint_(const NodeRef& node, + bool show_meta_data, + runtime::TypedPackedFunc annotate, + bool gnf) { + Doc doc; + doc << "v0.0.1" << "\n" + << PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node); + return doc.str(); +} + +std::string RelayPrint(const NodeRef& node, + bool show_meta_data, + runtime::TypedPackedFunc annotate) { + return PrettyPrint_(node, show_meta_data, annotate, true); +} + +std::string PassDebugPrint(const NodeRef& node, + bool show_meta_data, + runtime::TypedPackedFunc annotate, + bool gnf) { + return PrettyPrint_(node, show_meta_data, annotate, gnf); +} + +TVM_REGISTER_API("relay._expr.RelayPrint") +.set_body_typed)>(RelayPrint); + +TVM_REGISTER_API("relay._ir_pass.pass_debug_print") +.set_body_typed, + bool)>(PassDebugPrint); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc deleted file mode 100644 index 932856a2055d..000000000000 --- a/src/relay/ir/text_printer.cc +++ /dev/null @@ -1,904 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file text_printer.cc - * \brief Text printer to print relay in text form. - */ -#include -#include -#include -#include -#include "type_functor.h" -#include "../../lang/attr_functor.h" - -namespace tvm { -namespace relay { - -/*! - * \brief the text value used in text printer. - * Defined as a struct for future compatibility reason - */ -struct TextValue { - /*! \brief The str representation */ - std::string name; - // constructor - TextValue() {} - // constructor - explicit TextValue(std::string name) : name(name) {} - TextValue operator+(const TextValue& rhs) const { - return TextValue(name + rhs.name); - } - TextValue operator+(const std::string& str) const { - return TextValue(name + str); - } -}; - -// operator overloading -inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NOLINT(*) - return os << val.name; -} - -/*! - * \brief Meta data context for TextPrinter. - * - * This is an important part to enable bi-directional serializability. - * We use tvm's Node system to build the current IR. - * It can be hard to design a text format for all the possible nodes - * as the set of nodes can grow when we do more extensions. - * - * Instead of trying to design readable text format for every node, - * we support a meta-data section in the text format. - * We allow the text format to refer to a node in the meta-data section. - * - * The meta-data section is a json serialized string of an Map>. - * Each element in the meta-data section can be referenced by the text format. - * Each meta data node is printed in the following format. - * - * meta[type-key-of-node>][] - * - * Specifically, consider the following IR(constructed by python). - * - * \code - * - * n = tvm.var("n") - * x = tvm.relay.var("x", shape=(n, 1)) - * f = tvm.relay.Function([x], x) - * print(f.astext()) - * - * \endcode - * - * The corresponding text format is shown in the following code block. - * - * \code - * - * fn (%x: Tensor[(meta[Variable][0],), float32]) { - * %x - * } - * # Meta data section is a json-serialized string - * # of the following array. - * # [tvm.var("n")] - * - * \endcode - * - * Note that we store tvm.var("n") in the meta data section. - * Since it is stored in the index-0 in the meta-data section, - * we print it as meta[Variable][0]. - * - * The text parser can recover this object by loading from the corresponding - * location in the meta data section. - * - * This is is a design trade-off. - * It allows us to embedded any meta-data in the text format, - * while still being able to tweak the text part of the printed IR easily. - */ -class TextMetaDataContext { - public: - /*! - * \brief Get text representation of meta node. - * \param node The node to be converted to meta node. - * \return A string representation of the meta node. - */ - std::string GetMetaNode(const NodeRef& node) { - auto it = meta_repr_.find(node); - if (it != meta_repr_.end()) { - return it->second; - } - Array& mvector = - meta_data_[node->type_key()]; - int64_t index = static_cast(mvector.size()); - mvector.push_back(node); - std::ostringstream os; - os << "meta[" << node->type_key() << "][" << index << "]"; - meta_repr_[node] = os.str(); - return meta_repr_[node]; - } - /*! - * \brief Get the metadata section in json format. - * \return the meta datastring. - */ - std::string GetMetaSection() const { - if (meta_data_.size() == 0) return std::string(); - return SaveJSON(Map( - meta_data_.begin(), meta_data_.end())); - } - - /*! \return whether the meta data context is empty. */ - bool empty() const { - return meta_data_.empty(); - } - - private: - /*! \brief additional metadata stored in TVM json format */ - std::unordered_map > meta_data_; - /*! \brief map from meta data into its string representation */ - std::unordered_map meta_repr_; -}; - -class TextPrinter : - public ExprFunctor, - public PatternFunctor, - public TypeFunctor, // NOLINT(*) - public AttrFunctor { // NOLINT(*) - public: - explicit TextPrinter(bool show_meta_data, - runtime::TypedPackedFunc annotate) - : show_meta_data_(show_meta_data), annotate_(annotate) {} - /*! - * \brief Print a node to string. - * \param node. - * \return The string representation. - */ - std::string Print(const NodeRef& node) { - if (node.as()) { - this->PrintFunc(Downcast(node)); - } else if (node.as()) { - this->PrintEnv(Downcast(node)); - } else if (node.as_derived()) { - this->PrintType(Downcast(node), stream_); - } else if (node.as_derived()) { - this->PrintExpr(Downcast(node)); - } else { - stream_ << node; - } - if (!meta_.empty()) { - if (show_meta_data_) { - std::string meta_json = meta_.GetMetaSection(); - // append meta data in the end. - stream_ << "# meta data\n" - << "r\"\"\"\n" - << meta_json << "\n" - << "\"\"\""; - } else { - stream_ << "# meta data omitted. you can use show_meta_data=True to include meta-data\n"; - } - } - return stream_.str(); - } - - void PrintFunc(const Function& func) { - this->PrintFuncInternal("fn ", func); - stream_ << "\n"; - } - - void PrintEnv(const Module& mod) { - int counter = 0; - for (const auto& kv : mod->functions) { - std::ostringstream os; - if (counter++ != 0) { - stream_ << "\n"; - } - os << "def @" << kv.first->name_hint; - this->PrintFuncInternal(os.str(), kv.second); - stream_ << "\n"; - } - } - - void PrintExpr(const Expr& expr) { - TextValue val = GetValue(expr); - stream_ << val << "\n"; - } - - /*! - * \brief Get text representation of expr. - * - * This function may generate additional instructions - * in order to compute the final result id of expr. - * - * When trying to recursively print out an Expr. - * The caller should always call GetValue of its children first. - * Then the caller can print out to stream_ using the obtained value. - * - * This is to avoid the call of subsequent GetValue print out - * additional instructions which get mixed with the partial instruction - * printed by the caller. - * - * \param expr The input expression. - * \return The text value of Expr. - */ - TextValue GetValue(const Expr& expr) { - auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - TextValue val = this->VisitExpr(expr); - memo_[expr] = val; - return val; - } - TextValue GetValue(const Pattern& p) { - return this->VisitPattern(p); - } - //------------------------------------ - // Overload of Expr printing functions - //------------------------------------ - TextValue VisitExpr_(const ConstantNode* op) final { - // Print out simple scalar directly. - if (op->is_scalar()) { - std::ostringstream os; - DataType dtype = TVMType2Type(op->data->dtype); - CHECK_EQ(op->data->ctx.device_type, kDLCPU); - if (dtype == Int(32)) { - return ConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Int(64)) { - return ConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Float(32)) { - return ConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Float(64)) { - return ConstScalar(dtype, static_cast(op->data->data)); - } else if (dtype == Bool()) { - return ConstScalar(dtype, static_cast(op->data->data)); - } - } - // default fall-back, record it as meta node. - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << meta_.GetMetaNode(GetRef(op)); - this->PrintEndInst(""); - this->PrintOptionalInfo(GetRef(op)); - stream_ << '\n'; - return id; - } - - TextValue VisitExpr_(const TupleNode* op) final { - std::vector fields; - for (Expr field : op->fields) { - fields.push_back(GetValue(field)); - } - // NOTE: always recursively visit to get ids, - // before print out the current line - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = ("; - for (size_t i = 0; i < fields.size(); ++i) { - stream_ << fields[i]; - if (i + 1 != fields.size()) { - stream_ << ", "; - } - } - if (fields.size() == 1) { - stream_ << ','; - } - stream_ << ')'; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const VarNode* op) final { - Var var = GetRef(op); - // This is an unbounded var. - TextValue val = AllocVarName(var); - this->PrintIndent(); - stream_ << "free_var "; - this->PrintVarDecl(var, stream_); - this->PrintEndInst("\n"); - return val; - } - - TextValue VisitExpr_(const GlobalVarNode* op) final { - return TextValue('@' + op->name_hint); - } - - TextValue VisitExpr_(const FunctionNode* op) final { - TextValue id = AllocTempVar(); - std::ostringstream os; - os << id << " = fn"; - this->PrintFuncInternal(os.str(), GetRef(op)); - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const CallNode* op) final { - // possibly through meta-data - std::vector args; - for (Expr arg : op->args) { - args.emplace_back(GetValue(arg)); - } - TextValue call_op = GetValue(op->op); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - - stream_ << id << " = " << call_op; - - auto type_args = op->type_args; - - if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) { - stream_ << "<"; - for (size_t i = 0; i < op->type_args.size(); ++i) { - this->PrintType(type_args[i], stream_); - if (i + 1 != type_args.size()) { - stream_ << ", "; - } - } - stream_ << ">"; - } - - stream_ << "("; - for (size_t i = 0; i < args.size(); ++i) { - stream_ << args[i]; - if (i + 1 != args.size()) { - stream_ << ", "; - } - } - this->PrintCallAttrs(op->op, op->attrs, stream_); - stream_ << ")"; - this->PrintEndInst(""); - this->PrintOptionalInfo(GetRef(op)); - stream_ << '\n'; - return id; - } - - TextValue VisitExpr_(const LetNode* op) final { - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = "; - this->PrintScope(GetRef(op)); - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const IfNode* op) final { - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = "; - this->PrintScope(GetRef(op)); - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const OpNode* op) final { - return TextValue(op->name); - } - - TextValue VisitExpr_(const TupleGetItemNode* op) final { - TextValue tuple = GetValue(op->tuple); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << tuple << "." << op->index << ""; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const RefCreateNode* op) final { - TextValue value = GetValue(op->value); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << "RefCreate(" << op->value << ")"; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const RefReadNode* op) final { - TextValue ref = GetValue(op->ref); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << "RefRead(" << ref << ")"; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const RefWriteNode* op) final { - TextValue ref = GetValue(op->ref); - TextValue value = GetValue(op->value); - TextValue id = this->AllocTempVar(); - this->PrintIndent(); - stream_ << id << " = " << "RefWrite(" << ref << ", " << value << ")"; - this->PrintEndInst("\n"); - return id; - } - - TextValue VisitExpr_(const MatchNode* op) final { - TextValue data = GetValue(op->data); - this->PrintIndent(); - TextValue id = this->AllocTempVar(); - stream_ << id << " = " << "Match " << data << " with"; - this->PrintEndInst("\n"); - for (const auto& c : op->clauses) { - this->PrintIndent(); - stream_ << GetValue(c->lhs) << " to " << GetValue(c->rhs); - this->PrintEndInst("\n"); - } - return id; - } - - TextValue VisitPattern_(const PatternConstructorNode* p) final { - TextValue ret(p->constructor->name_hint + "("); - for (const Pattern& pat : p->patterns) { - ret = ret + " " + GetValue(pat); - } - return ret + ")"; - } - - TextValue VisitPattern_(const PatternVarNode* pv) final { - return GetValue(pv->var); - } - - TextValue VisitExpr_(const ConstructorNode* n) final { - return TextValue(n->name_hint); - } - - /*! - * \brief Print the type to os - * \param type The type to be printed. - * \param os The output type. - */ - void PrintType(const Type& type, std::ostream& os) { // NOLINT(*) - this->VisitType(type, os); - } - //------------------------------------ - // Overload of Expr printing functions - //------------------------------------ - void VisitType_(const TensorTypeNode* node, std::ostream& os) final { // NOLINT(*) - // scalar type - if (node->shape.size() == 0) { - os << runtime::TVMType2String(Type2TVMType(node->dtype)); - return; - } - os << "Tensor[("; - for (size_t i = 0; i < node->shape.size(); ++i) { - this->PrintAttr(node->shape[i], os); - if (i + 1 != node->shape.size()) { - os << ", "; - } - } - // conform to python tuple format (1,) - if (node->shape.size() == 1) { - os << ","; - } - os << "), " << runtime::TVMType2String(Type2TVMType(node->dtype)) << "]"; - } - - void VisitType_(const TupleTypeNode* node, std::ostream& os) final { // NOLINT(*) - os << "Tuple["; - for (size_t i = 0; i < node->fields.size(); ++i) { - this->PrintType(node->fields[i], os); - if (i + 1 != node->fields.size()) { - os << ", "; - } - } - os << "]"; - } - - void VisitType_(const RefTypeNode* node, std::ostream& os) final { - VisitTypeDefault_(node, os); - } - - void VisitType_(const TypeCallNode* node, std::ostream& os) final { - os << node->func << "(" << node->args << ")"; - } - - void VisitType_(const GlobalTypeVarNode* node, std::ostream& os) final { - VisitTypeDefault_(node, os); - } - - void VisitType_(const TypeDataNode* node, std::ostream& os) final { - VisitTypeDefault_(node, os); - } - - void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) - // by default always print as meta-data - os << meta_.GetMetaNode(GetRef(node)); - } - - /*! - * \brief Print an attribute value to os. - * \param value The value to be printed. - * \param os The output type. - */ - void PrintAttr(const NodeRef& value, std::ostream& os) { // NOLINT(*) - if (value.defined()) { - this->VisitAttr(value, os); - } else { - os << "None"; - } - } - //------------------------------------ - // Overload of Attr printing functions - //------------------------------------ - void VisitAttr_(const ArrayNode* op, std::ostream& os) final { // NOLINT(*) - os << "["; - for (size_t i = 0; i < op->data.size(); ++i) { - this->PrintAttr(NodeRef(op->data[i]), os); - if (i + 1 != op->data.size()) { - os << ", "; - } - } - os << "]"; - } - void VisitAttrDefault_(const Node* op, std::ostream& os) final { // NOLINT(*) - os << meta_.GetMetaNode(GetRef(op)); - } - - void VisitAttr_(const ir::IntImm* op, std::ostream& os) final { // NOLINT(*) - this->PrintConstScalar(op->type, &(op->value), os); - } - - void VisitAttr_(const ir::UIntImm* op, std::ostream& os) final { // NOLINT(*) - this->PrintConstScalar(op->type, &(op->value), os); - } - - void VisitAttr_(const ir::FloatImm* op, std::ostream& os) final { // NOLINT(*) - this->PrintConstScalar(op->type, &(op->value), os); - } - - void VisitAttr_(const ir::StringImm* op, std::ostream& os) final { // NOLINT(*) - this->PrintString(op->value, os); - } - - protected: - /*! - * \brief Print attributes after call. - * \param op The operator to be called. - * \param attrs The attributes. - * \param os The output stream. - */ - void PrintCallAttrs(const Expr& op, const Attrs& attrs, std::ostream& os); // NOLINT(*) - - /*! - * \brief Print the a new scopr. - * \param body The body. - */ - void PrintScope(Expr body) { - stream_ << "{\n"; - int sid = this->BeginScope(); - this->PrintScopeBody(body); - this->EndScope(sid); - this->PrintIndent(); - stream_ << "}"; - } - /*! - * \brief Print the body of a new scope without {} - * - * This function will keep printing continuous sequence - * of let/if scope without introducing a new scope in the text. - * - * \param body The body. - */ - void PrintScopeBody(Expr body) { - if (const LetNode* let = body.as()) { - TextValue value = GetValue(let->value); - AllocVarName(let->var); - // let var = value; - this->PrintIndent(); - stream_ << "let "; - this->PrintVarDecl(let->var, stream_); - stream_ << " = " << value; - this->PrintEndInst("\n"); - this->PrintScopeBody(let->body); - } else if (const IfNode* ifnode = body.as()) { - TextValue cond = GetValue(ifnode->cond); - this->PrintIndent(); - stream_ << "if (" << cond << ") "; - this->PrintScope(ifnode->true_branch); - this->PrintIndent(); - stream_ << "else "; - this->PrintScope(ifnode->false_branch); - this->PrintEndInst("\n"); - } else { - TextValue value = GetValue(body); - this->PrintIndent(); - stream_ << value; - this->PrintEndInst("\n"); - } - } - - /*! - * \brief Internal function to print a function argument list and its body. - * \param prefix The prefix before argument list. - * \param fn The function to be printed. - */ - void PrintFuncInternal(std::string prefix, const Function& fn) { - // TODO(tqchen, M.K.) support generic function - // Possibly through meta-data - CHECK_EQ(fn->type_params.size(), 0U) - << "generic fn not yet supported"; - this->PrintIndent(); - stream_ << prefix << "("; - size_t decl_indent = prefix.length() + 1; - for (size_t i = 0; i < fn->params.size(); ++i) { - if (i != 0) { - this->PrintIndent(decl_indent); - } - AllocVarName(fn->params[i]); - this->PrintVarDecl(fn->params[i], stream_); - if (i + 1 != fn->params.size()) { - stream_ << ",\n"; - } - } - stream_ << ')'; - if (fn->ret_type.defined()) { - stream_ << '\n'; - this->PrintIndent(decl_indent); - stream_ << "-> "; - this->PrintType(fn->ret_type, stream_); - } - stream_ << ' '; - this->PrintScope(fn->body); - } - /*! - * \brief Print additional info about expr in comment. - * \param expr The expression. - */ - void PrintOptionalInfo(const Expr& expr) { - // additional information in comment. - if (annotate_ != nullptr) { - stream_ << " # " << annotate_(expr); - } else if (expr->checked_type_.defined()) { - stream_ << " # ty="; - this->PrintType(expr->checked_type(), stream_); - } - } - /*! - * \brief print var_name[:type] - * \param var The variable to be printed - * \param os The output stream - */ - void PrintVarDecl(const Var& var, std::ostream& os) { // NOLINT(*) - TextValue v = GetValue(var); - os << v; - if (var->type_annotation.defined()) { - os << ": "; - this->PrintType(var->type_annotation, os); - } - } - /*! - * \brief Get a constant scalar value. - * \param dtype The data type. - * \param data The pointer to the data. - * \tparam T the content data type holding the data. - */ - template - TextValue ConstScalar(DataType dtype, const T* data) { - std::ostringstream os; - PrintConstScalar(dtype, data, os); - return TextValue(os.str()); - } - /*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param data The pointer to hold the data. - * \param os The output stream. - */ - template - void PrintConstScalar(DataType dtype, const T* data, std::ostream& os) { // NOLINT(*) - if (dtype == Int(32)) { - os << data[0]; - } else if (dtype == Float(32)) { - os << data[0] << 'f'; - } else if (dtype == Bool()) { - PrintBool(data[0] != 0, os); - } else { - os << dtype << "(" << data[0] << ")"; - } - } - /*! - * \brief Print constant bool value. - * \param value The value to be printed. - * \param os The output stream - */ - void PrintBool(bool value, std::ostream& os) { // NOLINT(*) - if (value) { - os << "True"; - } else { - os << "False"; - } - } - /*! - * \brief Print constant string. - * \param value The value to be printed. - * \param os The output stream - */ - void PrintString(const std::string& value, std::ostream& os) { // NOLINT(*) - // TODO(M.K.): add escape. - os << "\"" << value << "\""; - } - /*! - * \brief get a unique name with the corresponding prefix - * \param prefix The prefix of the name - * \return The returned name. - */ - std::string GetUniqueName(std::string prefix) { - auto it = name_alloc_map_.find(prefix); - if (it != name_alloc_map_.end()) { - while (true) { - std::ostringstream os; - os << prefix << (++it->second); - std::string name = os.str(); - if (name_alloc_map_.count(name) == 0) { - prefix = name; - break; - } - } - } - name_alloc_map_[prefix] = 0; - return prefix; - } - /*! - * \brief mark the beginning of a new scope - * \return The scope id. - */ - int BeginScope() { - int sid = static_cast(scope_valid_.size()); - scope_valid_.push_back(true); - indent_ += 2; - return sid; - } - /*! - * \brief mark the end of an old scope. - * \param scope_id The scope id to be ended. - */ - void EndScope(int scope_id) { - scope_valid_[scope_id] = false; - indent_ -= 2; - } - /*! - * \brief Print the indent to the stream. - * \param more_indent More indentation besides the current one. - */ - void PrintIndent(int64_t more_indent = 0) { - for (int i = 0; i < indent_ + more_indent; ++i) { - stream_ << ' '; - } - } - /*! - * \brief print end of the line. - */ - void PrintEndInst(const char* suffix) { - stream_ << suffix; - } - /*! - * \brief Allocate temporary value - * \return A new text value. - */ - TextValue AllocTempVar() { - std::ostringstream os; - os << '%' << temp_var_counter_++; - return TextValue(os.str()); - } - /*! - * \brief Allocate name to a variable. - * \param var The input variable. - * \return The corresponding name. - */ - TextValue AllocVarName(const Var& var) { - std::string name = var->name_hint(); - // always make sure first name is alpha - if (name.length() != 0 && !std::isalpha(name[0])) { - name = "%v" + name; - } else { - name = "%" + name; - } - TextValue val(GetUniqueName(name)); - // still print if ir is malformed, but show the error. - if (memo_.count(var)) { - memo_[var] = TextValue(val.name + "-malformed-ir"); - } - memo_[var] = val; - return val; - } - - private: - class AttrPrinter; - friend class AttrPrinter; - /*! \brief Whether to print meta data. */ - bool show_meta_data_; - /*! \brief additional comment function */ - runtime::TypedPackedFunc annotate_; - /*! \brief meta data context */ - TextMetaDataContext meta_; - /*! \brief Check whether scope is still valid */ - std::vector scope_valid_; - /*! \brief The current indentation value */ - int indent_{0}; - /*! \brief name allocation map */ - std::unordered_map name_alloc_map_; - /*! \brief Map from expression to its text value */ - std::unordered_map memo_; - /*! \brief counter of temporary variable */ - int64_t temp_var_counter_{0}; - /*! \brief Output stream */ - std::ostringstream stream_; -}; - -/*! - * \brief Attribute printer which prints the attributes in the call. - */ -class TextPrinter::AttrPrinter: public AttrVisitor { - public: - AttrPrinter(std::ostream& stream, TextPrinter* parent) // NOLINT(*) - : stream_(stream), parent_(parent) {} - - void Visit(const char* key, double* value) final { - PrintSep(); - stream_ << key << "=" << value[0]; - } - void Visit(const char* key, int64_t* value) final { - PrintSep(); - stream_ << key << "=" << value[0]; - } - void Visit(const char* key, uint64_t* value) final { - PrintSep(); - stream_ << key << "=" << value[0]; - } - void Visit(const char* key, int* value) final { - PrintSep(); - stream_ << key << "=" << value[0]; - } - void Visit(const char* key, bool* value) final { - PrintSep(); - stream_ << key << "="; - parent_->PrintBool(value[0], stream_); - } - void Visit(const char* key, std::string* value) final { - PrintSep(); - stream_ << key << "="; - parent_->PrintString(value[0], stream_); - } - void Visit(const char* key, void** value) final { - LOG(FATAL) << "do not allow void as argument"; - } - void Visit(const char* key, DataType* value) final { - PrintSep(); - stream_ << key << "="; - parent_->PrintString(runtime::TVMType2String(Type2TVMType(value[0])), stream_); - } - void Visit(const char* key, NodeRef* value) final { - PrintSep(); - stream_ << key << "="; - parent_->PrintAttr(value[0], stream_); - } - void Visit(const char* key, runtime::NDArray* value) final { - LOG(FATAL) << "do not allow NDarray as argument"; - } - - private: - void PrintSep() { - stream_ << ", "; - } - std::ostream& stream_; // NOLINT(*) - TextPrinter* parent_; -}; - -void TextPrinter::PrintCallAttrs(const Expr& op, - const Attrs& attrs, - std::ostream& os) { // NOLINT(*) - if (!attrs.defined()) return; - if (const auto* op_node = op.as()) { - if (attrs->type_index() == op_node->attrs_type_index) { - AttrPrinter printer(os, this); - const_cast(attrs.operator->()) - ->VisitNonDefaultAttrs(&printer); - return; - } - } - os << ", " << meta_.GetMetaNode(attrs); -} - -std::string RelayPrint(const NodeRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - return TextPrinter(show_meta_data, annotate).Print(node); -} - -TVM_REGISTER_API("relay._expr.RelayPrint") -.set_body_typed)>(RelayPrint); - -} // namespace relay -} // namespace tvm diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 21bd85a3eb37..626436d9573f 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -33,8 +33,8 @@ def test_env(): text = env.astext() assert "def @myf" in text assert "def @myf" in str(env) - assert "%1 = add(%0, %0) # ty=float32" in text - assert "%1 = add(%0, %0) # ty=float32" in str(env) + assert "%1 = add(%0, %0) // ty=float32" in text + assert "%1 = add(%0, %0) // ty=float32" in str(env) show(env.astext(annotate=lambda x: str(x.checked_type.dtype))) show(text) @@ -95,7 +95,7 @@ def test_let_if_scope(): f = relay.Function([x, y, cond], result) text = f.astext() - assert text.count("{") == 4 + assert text.count("{") == 6 assert "%cond: bool" in text show(f.astext()) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 8c8e7dfd1fcc..8fb83ece0ebd 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -28,7 +28,7 @@ def initialize_box_adt(mod): def test_monomorphic_let(): - "Program: let x = 1; return x" + "Program: let x = 1; x" sb = relay.ScopeBuilder() x = sb.let('x', relay.const(1.0, "float64")) sb.ret(x) @@ -48,7 +48,7 @@ def test_add_broadcast_op(): """ Program: fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { - return x + y; + x + y } """ x = relay.var('x', shape=(10, 4)) @@ -67,7 +67,7 @@ def test_dual_op(): fn (x : Tensor[f32, (10, 10)]) { let t1 = log(x); let t2 = add(t1, x); - return t1; + t1 } """ tp = relay.TensorType((10, 10), "float32") @@ -84,7 +84,7 @@ def test_dual_op(): def test_decl(): """Program: def f(x : Tensor[(10, 10), f32]) { - return log(x); + log(x) } """ tp = relay.TensorType((10, 10)) @@ -99,9 +99,9 @@ def test_recursion(): Program: def f(n: i32, data: f32) -> f32 { if (n == 0) { - return data; + data } else { - return f(n - 1, log(data)); + f(n - 1, log(data)) } } """