Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Text Format] Pretty Printer Smart Inlining #2881

Merged
merged 9 commits into from
Apr 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,39 +925,6 @@ def eliminate_common_subexpr(expr, fskip=None):
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)


def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True):
"""
THIS SHOULD BE USED ONLY FOR DEBUGGING, NOT AS AN INTERCHANGE FORMAT!
USE `.astext()` INSTEAD!

A version of the pretty printer intended for debugging passes. Contains
advanced printing options.

Parameters
----------
ast : Union[relay.Expr, relay.Module, relay.Type]
The relay fragment to be turned into text.

show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.

annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.

gnf : bool
Whether to print in GNF. If it is disabled, pointers are left implicit.

Returns
-------
text : str
A text representation of `ast`.
"""
return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf)


def partial_evaluate(expr):
"""
Evaluate the static fragment of the code.
Expand Down
112 changes: 63 additions & 49 deletions src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,22 @@
* \file pretty_printer.cc
* \brief Pretty printer for Relay programs
* Supports ANF, GNF, and metadata.
*
* Inlining heuristics:
* - Always inline:
* - GlobalVar
* - Constant
* - Op
* - Var
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/

#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../lang/attr_functor.h"

namespace tvm {
Expand Down Expand Up @@ -135,10 +145,8 @@ class PrettyPrinter :
public TypeFunctor<Doc(const Type&)>,
public AttrFunctor<Doc(const NodeRef&)> {
public:
explicit PrettyPrinter(bool GNF,
bool show_meta_data,
explicit PrettyPrinter(bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) :
GNF_(GNF),
show_meta_data_(show_meta_data),
annotate_(annotate) {}

Expand All @@ -150,10 +158,9 @@ class PrettyPrinter :
Doc doc;
// additional information in comment.
if (annotate_ != nullptr) {
return doc << " // " << annotate_(expr);
return doc << " /* " << annotate_(expr) << " */";
} else if (expr->checked_type_.defined()) {
doc << " // ty=";
return doc << Print(expr->checked_type());
return doc << " /* ty=" << Print(expr->checked_type()) << " */";
} else {
return doc;
}
Expand All @@ -176,13 +183,18 @@ class PrettyPrinter :
// print in a new scope
doc_stack_.push_back(Doc());
// must print first so doc_stack_.back() reference doesn't become stale
Doc doc = Print(node);
Doc doc = Print(node, false, true);
doc = doc_stack_.back() << doc;
doc_stack_.pop_back();
return doc;
}

Doc PrintFinal(const NodeRef& node) {
if (node.as_derived<ExprNode>()) {
Expr expr = Downcast<Expr>(node);
dg_ = DependencyGraph::Create(&arena_, expr);
}

Doc doc;
doc << PrintScope(node);
if (!meta_.empty()) {
Expand All @@ -200,9 +212,9 @@ class PrettyPrinter :

Doc PrintAttrs(const Attrs& attrs, const Expr& op);

Doc Print(const NodeRef& node, bool meta = false) {
Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make these two member of the language instead of passing them around.
If you pass them around you should change them. If they are constant save yourself the work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are flags for the Print function and are changed by the callers in specific instances. They are only changed for one invocation of Print at a time and importantly not for recursive invocations. I think they make the most sense as flags on Print rather than as instance variables.

if (node.as_derived<ExprNode>()) {
return PrintExpr(Downcast<Expr>(node), meta);
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<ModuleNode>()) {
Expand Down Expand Up @@ -308,25 +320,38 @@ class PrettyPrinter :
return val;
}

inline bool IsAtomicExpr(const Expr& expr) {
bool IsUnique(const Expr& expr) {
return !(dg_.expr_node.at(expr)->parents.head &&
dg_.expr_node.at(expr)->parents.head->next);
}

bool AlwaysInline(const Expr& expr) {
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
expr.as<OpNode>() || expr.as<VarNode>();
}

//------------------------------------
// Overload of Expr printing functions
//------------------------------------
Doc PrintExpr(const Expr& expr, bool meta) {
Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
// Exploit memoization to print GNF.
// The first time we visit an expression, we need to allocate a temp var
// for it. Every subsequent time we can just use its assigned variable.
// This works since hashing uses pointer equality.

// determine whether to inline
bool inline_expr = AlwaysInline(expr);
if (try_inline) {
inline_expr |= IsUnique(expr);
}

auto it = memo_.find(expr);
if (it != memo_.end()) return it->second;

Doc printed_expr;
if (meta) {
printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get()));
} else if (GNF_ && expr.as<LetNode>()) {
} else if (!inline_expr && expr.as<LetNode>()) {
// wrap GNFed let in brackets
Doc body;
printed_expr << "{";
Expand All @@ -335,28 +360,26 @@ class PrettyPrinter :
} else {
printed_expr = VisitExpr(expr);
}
// we choose to inline atomic exprs
if (GNF_ && !IsAtomicExpr(expr)) {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr;
if (expr.as<CallNode>()) {
doc_stack_.back() << PrintOptionalInfo(expr);
}
doc_stack_.back() << "\n";
return temp_var;
} else if (expr.as<VarNode>()) {

if (expr.as<CallNode>()) {
printed_expr << PrintOptionalInfo(expr);
}

// add expr to doc
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << "\n";
// Memoization is done in AllocVar.
return memo_[expr];
} else {
} else if (inline_expr) {
memo_[expr] = printed_expr;
if (GNF_ && expr.as<CallNode>()) {
printed_expr << PrintOptionalInfo(expr);
}
return printed_expr;
} else {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr << "\n";
return temp_var;
}
}

Expand Down Expand Up @@ -420,8 +443,9 @@ class PrettyPrinter :

Doc VisitExpr_(const LetNode* op) final {
Doc doc;
doc << "let " << AllocVar(op->var) << " = " << Print(op->value) << "\n";
doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << "\n";
// we use a scope here so GNF hoisting doesn't escape too far
// and nested, unique lets are not hoisted
doc << PrintScope(op->body);
return doc;
}
Expand Down Expand Up @@ -456,6 +480,8 @@ class PrettyPrinter :
Doc doc;
int counter = 0;
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);

std::ostringstream os;
if (counter++ != 0) {
doc << "\n";
Expand Down Expand Up @@ -664,8 +690,6 @@ class PrettyPrinter :
}

private:
/*! \brief Whether to use GNF. */
bool GNF_;
/*! \brief Whether to print meta data. */
bool show_meta_data_;
/*! \brief additional comment function */
Expand All @@ -682,6 +706,10 @@ class PrettyPrinter :
TextMetaDataContext meta_;
/*! \brief counter of temporary variable */
size_t temp_var_counter_{0};
/*! \brief arena for dependency graph */
common::Arena arena_;
/*! \brief dependency graph of the expr */
DependencyGraph dg_;
class AttrPrinter;
friend class AttrPrinter;
};
Expand Down Expand Up @@ -751,37 +779,23 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) {

std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate,
bool gnf) {
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc;
doc << "v0.0.1" << "\n"
<< PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node);
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}

std::string AsText(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return PrettyPrint_(node, show_meta_data, annotate, true);
}

std::string PassDebugPrint(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate,
bool gnf) {
return PrettyPrint_(node, show_meta_data, annotate, gnf);
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return PrettyPrint_(node, show_meta_data, annotate);
}

TVM_REGISTER_API("relay._expr.AsText")
.set_body_typed<std::string(const NodeRef&,
bool,
runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);

TVM_REGISTER_API("relay._ir_pass.pass_debug_print")
.set_body_typed<std::string(const NodeRef&,
bool,
runtime::TypedPackedFunc<std::string(Expr)>,
bool)>(PassDebugPrint);

} // namespace relay
} // namespace tvm
Loading