Skip to content

Commit

Permalink
Rename NodePrinter -> ReprPrinter to distinguish it from other printers
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 20, 2020
1 parent 9cf846b commit 685fc9e
Show file tree
Hide file tree
Showing 39 changed files with 259 additions and 259 deletions.
14 changes: 7 additions & 7 deletions include/tvm/node/functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern.
* // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of NodePrinter.
* // interface of ReprPrinter.
*
* class NodePrinter {
* class ReprPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
Expand All @@ -152,18 +152,18 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* f(e, this);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, NodePrinter* )>;
* using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*)
* ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, NodePrinter* p) {
* TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>

#include <string>
#include <vector>
Expand Down
16 changes: 8 additions & 8 deletions include/tvm/node/printer.h → include/tvm/node/repr_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@
* under the License.
*/
/*!
* \file tvm/node/printer.h
* \file tvm/node/repr_printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
#ifndef TVM_NODE_PRINTER_H_
#define TVM_NODE_PRINTER_H_
#ifndef TVM_NODE_REPR_PRINTER_H_
#define TVM_NODE_REPR_PRINTER_H_

#include <tvm/node/functor.h>
#include <iostream>

namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
class NodePrinter {
class ReprPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};

explicit NodePrinter(std::ostream& stream) // NOLINT(*)
explicit ReprPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}

/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>;
using FType = NodeFunctor<void(const ObjectRef&, ReprPrinter*)>;
TVM_DLL static FType& vtable();
};

Expand All @@ -60,9 +60,9 @@ namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
NodePrinter(os).Print(n);
ReprPrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_NODE_PRINTER_H_
#endif // TVM_NODE_REPR_PRINTER_H_
4 changes: 2 additions & 2 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
}
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
Expand Down
4 changes: 2 additions & 2 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -813,8 +813,8 @@ IntSet EvalSet(Range r,

TVM_REGISTER_NODE_TYPE(IntervalSetNode);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet"
<< "[" << op->min_value << ", "
Expand Down
4 changes: 2 additions & 2 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
data_ = std::move(node);
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ModularSetNode*>(node.get());
p->stream << "ModularSet("
<< "coeff=" << op->coeff << ", base="
Expand Down
8 changes: 4 additions & 4 deletions src/ir/adt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ TVM_REGISTER_GLOBAL("relay._make.Constructor")
return Constructor(name_hint, inputs, belong_to);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ConstructorNode*>(ref.get());
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
Expand All @@ -71,8 +71,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeData")
return TypeData(header, type_vars, constructors);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TypeDataNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TypeDataNode*>(ref.get());
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
Expand Down
4 changes: 2 additions & 2 deletions src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
return Attrs(n);
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
});
Expand Down
4 changes: 2 additions & 2 deletions src/ir/env_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EnvFuncNode*>(node.get());
p->stream << "EnvFunc(" << op->name << ")";
});
Expand Down
28 changes: 14 additions & 14 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ TVM_REGISTER_GLOBAL("make.IntImm")

TVM_REGISTER_NODE_TYPE(IntImmNode);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntImmNode*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
Expand All @@ -104,8 +104,8 @@ TVM_REGISTER_GLOBAL("make.FloatImm")

TVM_REGISTER_NODE_TYPE(FloatImmNode);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
auto& stream = p->stream;
switch (op->dtype.bits()) {
Expand Down Expand Up @@ -134,8 +134,8 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
return Range(make_object<RangeNode>(min, extent));
}

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
Expand All @@ -159,15 +159,15 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalVar")
return GlobalVar(name);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const GlobalVarNode*>(ref.get());
p->stream << "GlobalVar(" << node->name_hint << ")";
});

// Container printer
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ArrayNode*>(node.get());
p->stream << '[';
for (size_t i = 0 ; i < op->data.size(); ++i) {
Expand All @@ -179,8 +179,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << ']';
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
Expand All @@ -194,8 +194,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << '}';
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StrMapNode*>(node.get());
p->stream << '{';
for (auto it = op->data.begin(); it != op->data.end(); ++it) {
Expand Down
4 changes: 2 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
mod->ImportFromStd(path);
});;

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const IRModuleNode*>(ref.get());
p->stream << "IRModuleNode( " << node->functions << ")";
});
Expand Down
4 changes: 2 additions & 2 deletions src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ TVM_REGISTER_NODE_TYPE(OpNode)
return static_cast<const OpNode*>(n)->name;
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<OpNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const OpNode*>(ref.get());
p->stream << "Op(" << node->name << ")";
});
Expand Down
8 changes: 4 additions & 4 deletions src/ir/span.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ SourceName SourceName::Get(const std::string& name) {
TVM_REGISTER_GLOBAL("relay._make.SourceName")
.set_body_typed(SourceName::Get);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
Expand All @@ -73,8 +73,8 @@ TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_GLOBAL("relay._make.Span")
.set_body_typed(SpanNode::make);

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "Span(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
Expand Down
6 changes: 3 additions & 3 deletions src/ir/tensor_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

namespace tvm {

using tvm::NodePrinter;
using tvm::ReprPrinter;
using namespace tvm::runtime;

TensorType::TensorType(Array<PrimExpr> shape, DataType dtype) {
Expand Down Expand Up @@ -60,8 +60,8 @@ TVM_REGISTER_GLOBAL("relay._make.TensorType")
return TensorType(shape, dtype);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TensorTypeNode*>(ref.get());
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
Expand Down
20 changes: 10 additions & 10 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/node/printer.h>
#include <tvm/node/repr_printer.h>
#include <tvm/ir/transform.h>

// TODO(tqchen): Update to use String container after it is merged.
Expand All @@ -38,7 +38,7 @@ namespace transform {

using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::NodePrinter;
using tvm::ReprPrinter;

struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
Expand Down Expand Up @@ -341,8 +341,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Info")
*ret = pass->Info();
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::ReprPrinter* p) {
auto* node = static_cast<const PassInfoNode*>(ref.get());
p->stream << "The meta data of the pass: ";
p->stream << "pass name: " << node->name;
Expand Down Expand Up @@ -371,8 +371,8 @@ TVM_REGISTER_GLOBAL("relay._transform.RunPass")
*ret = pass(mod);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModulePassNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ModulePassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Module pass: " << info->name
Expand All @@ -391,8 +391,8 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential")
*ret = Sequential(passes, pass_info);
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SequentialNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const SequentialNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Sequential pass: " << info->name
Expand Down Expand Up @@ -421,8 +421,8 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
*ret = pctx;
});

TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, NodePrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PassContextNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PassContextNode*>(ref.get());
p->stream << "Pass context information: " << "\n";
p->stream << "\topt_level: " << node->opt_level << "\n";
Expand Down
Loading

0 comments on commit 685fc9e

Please sign in to comment.