Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR] Establish printer in the source folder #4752

Merged
merged 2 commits into from
Jan 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ file(GLOB_RECURSE COMPILER_SRCS
src/autotvm/*.cc
src/tir/*.cc
src/driver/*.cc
src/printer/*.cc
src/api/*.cc
)

Expand Down
4 changes: 2 additions & 2 deletions apps/lldb/tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _GetContext(debugger):
def PrettyPrint(debugger, command, result, internal_dict):
ctx = _GetContext(debugger)
rc = ctx.EvaluateExpression(
"tvm::relay::PrettyPrint({command})".format(command=command)
"tvm::PrettyPrint({command})".format(command=command)
)
result.AppendMessage(str(rc))

Expand Down Expand Up @@ -175,7 +175,7 @@ def _EvalExpressionAsString(logger, ctx, expr):

def _EvalAsNodeRef(logger, ctx, value):
return _EvalExpressionAsString(
logger, ctx, "tvm::relay::PrettyPrint({name})".format(name=value.name)
logger, ctx, "tvm::PrettyPrint({name})".format(name=value.name)
)


Expand Down
28 changes: 28 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,5 +308,33 @@ class IRModule : public ObjectRef {
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
};

/*!
* \brief Pretty print a node for debug purposes.
*
* \param node The node to be printed.
* \return The text reperesentation.
* \note This function does not show version or meta-data.
* Use AsText if you want to store the text.
* \sa AsText.
*/
TVM_DLL std::string PrettyPrint(const ObjectRef& node);

/*!
* \brief Render the node as a string in the text format.
*
* \param node The node to be rendered.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
*
* \note We support a limited set of IR nodes that are part of
* relay IR and
*
* \sa PrettyPrint.
* \return The text representation.
*/
TVM_DLL std::string AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate = nullptr);
} // namespace tvm
#endif // TVM_IR_MODULE_H_
14 changes: 7 additions & 7 deletions include/tvm/node/functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern.
* // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of NodePrinter.
* // interface of ReprPrinter.
*
* class NodePrinter {
* class ReprPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
Expand All @@ -152,18 +152,18 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* f(e, this);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, NodePrinter* )>;
* using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*)
* ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, NodePrinter* p) {
* TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>

#include <string>
#include <vector>
Expand Down
16 changes: 8 additions & 8 deletions include/tvm/node/printer.h → include/tvm/node/repr_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@
* under the License.
*/
/*!
* \file tvm/node/printer.h
* \file tvm/node/repr_printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
#ifndef TVM_NODE_PRINTER_H_
#define TVM_NODE_PRINTER_H_
#ifndef TVM_NODE_REPR_PRINTER_H_
#define TVM_NODE_REPR_PRINTER_H_

#include <tvm/node/functor.h>
#include <iostream>

namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
class NodePrinter {
class ReprPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};

explicit NodePrinter(std::ostream& stream) // NOLINT(*)
explicit ReprPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}

/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>;
using FType = NodeFunctor<void(const ObjectRef&, ReprPrinter*)>;
TVM_DLL static FType& vtable();
};

Expand All @@ -60,9 +60,9 @@ namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
NodePrinter(os).Print(n);
ReprPrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_NODE_PRINTER_H_
#endif // TVM_NODE_REPR_PRINTER_H_
16 changes: 2 additions & 14 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <string>
#include <functional>
#include "./base.h"
Expand All @@ -40,6 +41,7 @@ using BaseFunc = tvm::BaseFunc;
using BaseFuncNode = tvm::BaseFuncNode;
using GlobalVar = tvm::GlobalVar;
using GlobalVarNode = tvm::GlobalVarNode;
using tvm::PrettyPrint;

/*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device.
Expand Down Expand Up @@ -539,20 +541,6 @@ class TempExpr : public Expr {
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};

/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
std::string PrettyPrint(const ObjectRef& node);

/*!
* \brief Render the node as a string in the Relay text format.
* \param node The node to be rendered.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std::string AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);

/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
Expand Down
4 changes: 2 additions & 2 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
}
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
Expand Down
4 changes: 2 additions & 2 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -813,8 +813,8 @@ IntSet EvalSet(Range r,

TVM_REGISTER_NODE_TYPE(IntervalSetNode);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet"
<< "[" << op->min_value << ", "
Expand Down
4 changes: 2 additions & 2 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
data_ = std::move(node);
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ModularSetNode*>(node.get());
p->stream << "ModularSet("
<< "coeff=" << op->coeff << ", base="
Expand Down
8 changes: 4 additions & 4 deletions src/ir/adt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ TVM_REGISTER_GLOBAL("relay._make.Constructor")
return Constructor(name_hint, inputs, belong_to);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
Expand All @@ -71,8 +71,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeData")
return TypeData(header, type_vars, constructors);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
Expand Down
4 changes: 2 additions & 2 deletions src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
return Attrs(n);
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
});
Expand Down
4 changes: 2 additions & 2 deletions src/ir/env_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EnvFuncNode*>(node.get());
p->stream << "EnvFunc(" << op->name << ")";
});
Expand Down
2 changes: 1 addition & 1 deletion src/ir/error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
//
// The annotation callback will annotate the error messages
// contained in the map.
annotated_prog << relay::AsText(func, false, [&err_map](tvm::relay::Expr expr) {
annotated_prog << AsText(func, false, [&err_map](const ObjectRef& expr) {
auto it = err_map.find(expr);
if (it != err_map.end()) {
CHECK_NE(it->second.size(), 0);
Expand Down
28 changes: 14 additions & 14 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("make.IntImm")

TVM_REGISTER_NODE_TYPE(IntImmNode);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
Expand All @@ -104,8 +104,8 @@ TVM_REGISTER_GLOBAL("make.FloatImm")

TVM_REGISTER_NODE_TYPE(FloatImmNode);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
auto& stream = p->stream;
switch (op->dtype.bits()) {
Expand Down Expand Up @@ -134,8 +134,8 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
Expand All @@ -159,15 +159,15 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
return GlobalVar(name);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});

// Container printer
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ArrayNode*>(node.get());
p->stream << '[';
for (size_t i = 0 ; i < op->data.size(); ++i) {
Expand All @@ -179,8 +179,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ']';
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
Expand All @@ -194,8 +194,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << '}';
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StrMapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
Expand Down
4 changes: 2 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
mod->ImportFromStd(path);
});;

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IRModuleNode*>(ref.get());
p->stream << "IRModuleNode( " << node->functions << ")";
});
Expand Down
4 changes: 2 additions & 2 deletions src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ TVM_REGISTER_NODE_TYPE(OpNode)
return static_cast<const OpNode*>(n)->name;
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const OpNode*>(ref.get());
p->stream << "Op(" << node->name << ")";
});
Expand Down
Loading