Skip to content

Commit

Permalink
[RELAY][PASS] General OpFusion. (apache#2090)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and Wei Chen committed Feb 19, 2019
1 parent f2dd10d commit 7f8b2d8
Show file tree
Hide file tree
Showing 19 changed files with 1,026 additions and 113 deletions.
10 changes: 10 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
return node;
}

/*!
* \brief Print node as text format.
* \param node The node to be printed.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std::string RelayPrint(
const NodeRef& node,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
2 changes: 2 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
using TSelf = TypedPackedFunc<R(Args...)>;
/*! \brief default constructor */
TypedPackedFunc() {}
/*! \brief constructor from null */
TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*)
/*!
* \brief construct by wrap a PackedFunc
*
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/relay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,20 @@ def register_relay_node(type_key=None):


class RelayNode(NodeBase):
def astext(self):
"""Base class of all relay node."""
def astext(self, annotate=None):
"""Get the text format of the expression.
Returns
-------
text : str
The text format of the expression.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
"""
return _expr._text_print(self)
return _expr.RelayPrint(self, annotate)


@register_relay_node
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,13 @@ def build(func,
else:
tophub_context = autotvm.util.EmptyContext()

cfg = BuildConfig.current

with tophub_context:
func = optimize(func)
# Fuse ops before running code gen
func = ir_pass.infer_type(func)
func = ir_pass.fuse_ops(func)
func = ir_pass.fuse_ops(func, cfg.opt_level)
# Graph code generation
func = ir_pass.infer_type(func)
graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
Expand Down Expand Up @@ -477,7 +476,7 @@ def astext(self):
text : str
The text format of the tuple expression.
"""
return _expr._text_print(self.tuple_value)
return self.tuple_value.astext()

def __getitem__(self, index):
if index >= len(self):
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,17 +259,20 @@ def structural_hash(value):
raise TypeError(msg)


def fuse_ops(expr):
def fuse_ops(expr, opt_level=1):
"""Fuse operators in expr together.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
opt_level : int
The level of fuse optimization.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr)
return _ir_pass.FuseOps(expr, opt_level)
58 changes: 57 additions & 1 deletion src/common/arena.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,29 @@ class Arena {
/*!
* \brief Allocate a space from Arena for type T
* \param T the data type to be allocated
* \note The space of T is not initialized.
*/
template<typename T>
T* Alloc() {
T* allocate_() {
return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
}
/*!
* \brief Create a new instance of type T.
* \param args The constructor argument.
* \tparam T the type to be created.
* \tparam Args Arguments to the constructor.
*
* \return The allocated object.
* \note The type T must be simple type, or only contain
* memory allocated from the same arena.
* Otherwise the destructor needs to be called explicitly.
*/
template<typename T, typename... Args>
T* make(Args&&... args) {
T* ptr = allocate_<T>();
new (ptr) T(std::forward<Args>(args)...);
return ptr;
}

private:
// page size 16 KB
Expand Down Expand Up @@ -87,6 +105,44 @@ class Arena {
}
};

/*!
* \brief Link list node
* \tparam T the content data type
*/
template<typename T>
struct LinkNode {
/*! \brief The content value */
T value;
/*! \brief pointer to the next location */
LinkNode<T>* next{nullptr};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
* \note This is a simple data structure that can be used together with the arena.
* \sa LinkNode
*/
template<typename T>
struct LinkedList {
/*! \brief Head pointer */
LinkNode<T>* head{nullptr};
/*! \brief Tail pointer */
LinkNode<T>* tail{nullptr};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void Push(LinkNode<T>* node) {
node->next = nullptr;
if (this->tail != nullptr) {
this->tail->next = node;
this->tail = node;
} else {
head = tail = node;
}
}
};

} // namespace common
} // namespace tvm
#endif // TVM_COMMON_ARENA_H_
23 changes: 23 additions & 0 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,29 @@ class ScheduleGetter :
return {};
}

Array<Tensor> VisitExpr_(const ConstantNode* op) final {
CHECK(op->is_scalar());
void* data = op->data->data;
DataType dtype = TVMType2Type(op->data->dtype);
Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
if (dtype == Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == Int(64)) {
return make_const(dtype, static_cast<const int64_t*>(data)[0]);
} else if (dtype == Float(32)) {
return make_const(dtype, static_cast<const float*>(data)[0]);
} else if (dtype == Float(64)) {
return make_const(dtype, static_cast<const double*>(data)[0]);
} else if (dtype == Bool()) {
return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
} else {
LOG(FATAL) << "not handled";
return tvm::Expr();
}
});
return {value};
}

Array<Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fcompute =
Op::GetAttr<FTVMCompute>("FTVMCompute");
Expand Down
26 changes: 19 additions & 7 deletions src/relay/ir/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class TextPrinter :
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public:
explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate)
: annotate_(annotate) {}
/*!
* \brief Print a node to string.
* \param node.
Expand Down Expand Up @@ -279,11 +281,11 @@ class TextPrinter :

TextValue VisitExpr_(const CallNode* op) final {
// possibly through meta-data
TextValue call_op = GetValue(op->op);
std::vector<TextValue> args;
for (Expr arg : op->args) {
args.emplace_back(GetValue(arg));
}
TextValue call_op = GetValue(op->op);
TextValue id = this->AllocTempVar();
this->PrintIndent();

Expand Down Expand Up @@ -532,7 +534,9 @@ class TextPrinter :
*/
void PrintOptionalInfo(const Expr& expr) {
// additional information in comment.
if (expr->checked_type_.defined()) {
if (annotate_ != nullptr) {
stream_ << " # " << annotate_(expr);
} else if (expr->checked_type_.defined()) {
stream_ << " # ty=";
this->PrintType(expr->checked_type(), stream_);
}
Expand Down Expand Up @@ -678,14 +682,19 @@ class TextPrinter :
name = "%" + name;
}
TextValue val(GetUniqueName(name));
CHECK(!memo_.count(var)) << "Duplicated variable " << var;
// still print if ir is malformed, but show the error.
if (memo_.count(var)) {
memo_[var] = TextValue(val.name + "-malformed-ir");
}
memo_[var] = val;
return val;
}

private:
class AttrPrinter;
friend class AttrPrinter;
/*! \brief additional comment function */
runtime::TypedPackedFunc<std::string(Expr)> annotate_;
/*! \brief meta data context */
TextMetaDataContext meta_;
/*! \brief Check whether scope is still valid */
Expand Down Expand Up @@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
os << ", " << meta_.GetMetaNode(attrs);
}

std::string RelayPrint(const NodeRef& node) {
return TextPrinter().Print(node);
std::string RelayPrint(const NodeRef& node,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return TextPrinter(annotate).Print(node);
}

TVM_REGISTER_API("relay._expr._text_print")
.set_body_typed<std::string(const NodeRef&)>(RelayPrint);
TVM_REGISTER_API("relay._expr.RelayPrint")
.set_body_typed<std::string(
const NodeRef&,
runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);

} // namespace relay
} // namespace tvm
18 changes: 1 addition & 17 deletions src/relay/pass/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "../op/nn/layout.h"

namespace tvm {
Expand Down Expand Up @@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
class ExprRefCounter : private ExprVisitor {
public:
std::unordered_map<const Node*, size_t>
Get(const Expr& body) {
this->VisitExpr(body);
return std::move(this->visit_counter_);
}
};
return ExprRefCounter().Get(body);
}

class BackwardPrep : private ExprVisitor {
public:
Expand Down
Loading

0 comments on commit 7f8b2d8

Please sign in to comment.