Skip to content

Commit

Permalink
[REFACTOR] IRPrinter->NodePrinter, move to node/printer.h
Browse files Browse the repository at this point in the history
Rationale: printer is a common infra that is shared across all nodes.
  • Loading branch information
tqchen committed Jan 5, 2020
1 parent 8152360 commit f2feb4b
Show file tree
Hide file tree
Showing 35 changed files with 365 additions and 303 deletions.
31 changes: 0 additions & 31 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,37 +470,6 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
}
return ret;
}

// Printer infra.
/*! \brief A Pretty printer class to print the IR. */
class IRPrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};
explicit IRPrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}

/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, IRPrinter *)>;
TVM_DLL static FType& vtable();
};
} // namespace tvm

namespace tvm {
namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
IRPrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm

namespace std {
Expand Down
20 changes: 11 additions & 9 deletions include/tvm/node/functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@
#define TVM_NODE_FUNCTOR_H_

#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/node.h>
#include <tvm/runtime/object.h>

#include <vector>
#include <type_traits>
#include <utility>

namespace tvm {

using runtime::ObjectRef;

/*!
* \brief A dynamically dispatched functor on the type of the first argument.
*
Expand Down Expand Up @@ -137,11 +139,11 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement IRPrinter similar to Visitor Pattern.
* // Use NodeFunctor to implement NodePrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of IRPrinter.
* // interface of NodePrinter.
*
* class IRPrinter {
* class NodePrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
Expand All @@ -150,18 +152,18 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
* f(e, this);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, IRPrinter *)>;
* using FType = NodeFunctor<void (const ObjectRef&, NodePrinter* )>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* IRPrinter::FType& IRPrinter::vtable() { // NOLINT(*)
* NodePrinter::FType& NodePrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, IRPrinter* p) {
* TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, NodePrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
Expand Down
1 change: 1 addition & 0 deletions include/tvm/node/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/printer.h>

#include <string>
#include <vector>
Expand Down
61 changes: 61 additions & 0 deletions include/tvm/node/printer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/printer.h
* \brief Printer class to print repr string of each AST/IR nodes.
*/
#ifndef TVM_NODE_PRINTER_H_
#define TVM_NODE_PRINTER_H_

#include <tvm/node/functor.h>
#include <iostream>

namespace tvm {
/*! \brief A printer class to print the AST/IR nodes. */
class NodePrinter {
public:
/*! \brief The output stream */
std::ostream& stream;
/*! \brief The indentation level. */
int indent{0};

explicit NodePrinter(std::ostream& stream) // NOLINT(*)
: stream(stream) {}

/*! \brief The node to be printed. */
TVM_DLL void Print(const ObjectRef& node);
/*! \brief Print indent to the stream */
TVM_DLL void PrintIndent();
// Allow registration to be printer.
using FType = NodeFunctor<void(const ObjectRef&, NodePrinter*)>;
TVM_DLL static FType& vtable();
};
} // namespace tvm

namespace tvm {
namespace runtime {
// default print function for all objects
// provide in the runtime namespace as this is where objectref originally comes from.
inline std::ostream& operator<<(std::ostream& os, const ObjectRef& n) { // NOLINT(*)
NodePrinter(os).Print(n);
return os;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_NODE_PRINTER_H_
4 changes: 2 additions & 2 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) {
}
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ConstIntBoundNode*>(node.get());
p->stream << "ConstIntBound[";
PrintBoundValue(p->stream, op->min_value);
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,8 @@ IntSet EvalSet(Range r,

TVM_REGISTER_NODE_TYPE(IntervalSetNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntervalSetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IntervalSetNode*>(node.get());
p->stream << "IntervalSet"
<< "[" << op->min_value << ", "
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) {
data_ = std::move(node);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<ModularSetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const ModularSetNode*>(node.get());
p->stream << "ModularSet("
<< "coeff=" << op->coeff << ", base="
Expand Down
8 changes: 4 additions & 4 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ using runtime::PackedFunc;
TVM_REGISTER_NODE_TYPE(TargetNode);
TVM_REGISTER_NODE_TYPE(GenericFuncNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TargetNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const TargetNode*>(node.get());
p->stream << op->str();
});
Expand Down Expand Up @@ -665,8 +665,8 @@ tvm::BuildConfig BuildConfig::Current() {

TVM_REGISTER_NODE_TYPE(BuildConfigNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const BuildConfigNode*>(node.get());
p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", ";
Expand Down
9 changes: 5 additions & 4 deletions src/ir/span.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \brief The span data structure.
*/
#include <tvm/ir/span.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>

namespace tvm {
Expand Down Expand Up @@ -48,8 +49,8 @@ SourceName SourceName::Get(const std::string& name) {
TVM_REGISTER_GLOBAL("relay._make.SourceName")
.set_body_typed(SourceName::Get);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SourceNameNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const SourceNameNode*>(ref.get());
p->stream << "SourceName(" << node->name << ", " << node << ")";
});
Expand All @@ -73,8 +74,8 @@ TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_GLOBAL("relay._make.Span")
.set_body_typed(SpanNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<SpanNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const SpanNode*>(ref.get());
p->stream << "Span(" << node->source << ", " << node->lineno << ", "
<< node->col_offset << ")";
Expand Down
13 changes: 7 additions & 6 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief Common type system AST nodes throughout the IR.
*/
#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>

namespace tvm {
Expand All @@ -40,8 +41,8 @@ TVM_REGISTER_GLOBAL("relay._make.TypeVar")
return TypeVarNode::make(name, static_cast<TypeKind>(kind));
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVar(" << node->name_hint << ", "
<< node->kind << ")";
Expand All @@ -61,8 +62,8 @@ TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind));
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVar(" << node->name_hint << ", "
<< node->kind << ")";
Expand All @@ -85,8 +86,8 @@ TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_GLOBAL("relay._make.FuncType")
.set_body_typed(FuncTypeNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FuncTypeNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const FuncTypeNode*>(ref.get());
p->stream << "FuncType(" << node->type_params << ", "
<< node->arg_types << ", " << node->ret_type << ", "
Expand Down
4 changes: 2 additions & 2 deletions src/lang/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
return Attrs(n);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
});
Expand Down
4 changes: 2 additions & 2 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,8 @@ Buffer BufferNode::make(Var data,
return Buffer(n);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, IRPrinter *p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BufferNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const BufferNode*>(node.get());
p->stream << "buffer(" << op->name << ", " << op << ")";
});
Expand Down
8 changes: 4 additions & 4 deletions src/lang/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
return -1;
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<LayoutNode>([](const ObjectRef& node, NodePrinter* p) {
auto* l = static_cast<const LayoutNode*>(node.get());
p->stream << "Layout(" << l->name << ")";
});
Expand Down Expand Up @@ -361,8 +361,8 @@ BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
return BijectiveLayout(n);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BijectiveLayoutNode>([](const ObjectRef& node, NodePrinter* p) {
auto* b = static_cast<const BijectiveLayoutNode*>(node.get());
p->stream << "BijectiveLayout(" << b->src_layout.name()
<< "->" << b->dst_layout.name() << ")";
Expand Down
37 changes: 6 additions & 31 deletions src/lang/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,33 +97,8 @@ Var var(std::string name_hint, DataType t) {
return Var(name_hint, t);
}

void IRPrinter::Print(const ObjectRef& ir) {
static const FType& f = vtable();
if (!ir.defined()) {
stream << "(nullptr)";
} else {
if (f.can_dispatch(ir)) {
f(ir, this);
} else {
// default value, output type key and addr.
stream << ir->GetTypeKey() << "(" << ir.get() << ")";
}
}
}

void IRPrinter::PrintIndent() {
for (int i = 0; i < indent; ++i) {
stream << ' ';
}
}

IRPrinter::FType& IRPrinter::vtable() {
static FType inst;
return inst;
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntImm>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IntImm>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IntImm*>(node.get());
if (op->dtype == DataType::Int(32)) {
p->stream << op->value;
Expand All @@ -132,8 +107,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const IterVarNode*>(node.get());
p->stream << "iter_var(";
if (op->var->name_hint.length() != 0) {
Expand All @@ -148,8 +123,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << ")";
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, IRPrinter* p) {
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
Expand Down
Loading

0 comments on commit f2feb4b

Please sign in to comment.