Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][COMPILER] Initial Relay interpreter and compiler for TVM runtime system. #1954

Merged
merged 47 commits into from
Oct 30, 2018
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
796096e
Add evaluator and runtime system.
jroesch Sep 21, 2018
4d3d9e0
Remove uneeded C++ test
jroesch Oct 29, 2018
e5f2d87
Reformat file
jroesch Oct 29, 2018
efe62d3
Add target field for schedule
jroesch Oct 29, 2018
eb130ea
Type Compute and Schedule
jroesch Oct 29, 2018
21a52cf
Clean up deadcode and whitespace
jroesch Oct 29, 2018
1df6f07
Reformat test
jroesch Oct 29, 2018
b930f51
Fix linting
jroesch Oct 29, 2018
070ed09
Fix more Python linting
jroesch Oct 29, 2018
7c2f404
Add doc string to evaluate
jroesch Oct 29, 2018
c88b465
Add evaluate_rts docstring
jroesch Oct 29, 2018
9241d09
Fixing last couple lints
jroesch Oct 29, 2018
c05c3c5
Fix lint
jroesch Oct 29, 2018
8e4fb55
Fix another lint
jroesch Oct 29, 2018
af80e67
Remove type annotation
jroesch Oct 29, 2018
bd999a6
Add docs for pass.h
jroesch Oct 29, 2018
b677d89
remove warning
jroesch Oct 29, 2018
5967333
Fix Python2 tests
jroesch Oct 29, 2018
5c9ae4a
Try to fix 2.7 error
jroesch Oct 29, 2018
161e14a
Address some code review feedback
jroesch Oct 29, 2018
3ce8376
Rename interpreter
jroesch Oct 29, 2018
6286f04
Add TOPI to path
jroesch Oct 29, 2018
bf70e13
Address MK's comments
jroesch Oct 30, 2018
9f6412e
A couple more pieces of feedback
jroesch Oct 30, 2018
baef91b
Fix unsigned vs. signed check
jroesch Oct 30, 2018
82b51fd
Add code for generating node_row_ptr
jroesch Oct 30, 2018
2167c47
Address some more feedback
jroesch Oct 30, 2018
3c76f0f
Update python/tvm/relay/graph_runtime_codegen.py
joshpoll Oct 30, 2018
e853433
Update python/tvm/relay/interpreter.py
joshpoll Oct 30, 2018
98ae9b9
Address a few more comments
jroesch Oct 30, 2018
b4d5d31
Add doc strings
jroesch Oct 30, 2018
b69485f
Update include/tvm/relay/pass.h
joshpoll Oct 30, 2018
3f9dc5e
Update include/tvm/relay/pass.h
joshpoll Oct 30, 2018
4e3f92d
Update include/tvm/relay/pass.h
joshpoll Oct 30, 2018
604ea7b
Update include/tvm/relay/pass.h
junrushao Oct 30, 2018
a61d663
Add more docs
jroesch Oct 30, 2018
95e49a0
Rename tests for graph runtime.
jroesch Oct 30, 2018
341a0c3
Remove old debugging
jroesch Oct 30, 2018
002f376
Update include/tvm/relay/interpreter.h
zhiics Oct 30, 2018
c562275
Update python/tvm/relay/graph_runtime_codegen.py
joshpoll Oct 30, 2018
9a229c3
Update src/relay/pass/fuse_ops.cc
junrushao Oct 30, 2018
a050815
A little more clean up
jroesch Oct 30, 2018
28bfdba
Fix CPP lint
jroesch Oct 30, 2018
ecfe6e5
Fix lint issue
jroesch Oct 30, 2018
315f9c0
Fix bad space
jroesch Oct 30, 2018
764bc5e
Fix testing error (hopefully)
jroesch Oct 30, 2018
86ff607
Fix integer types issue
jroesch Oct 30, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,15 @@ namespace tvm {
* You can find more about Relay by reading the language reference.
*/
namespace relay {

#define RELAY_DEBUG(...) \
{ auto fdebug = runtime::Registry::Get("relay.debug"); \
CHECK(fdebug) << "Could not find Relay Python debugger function."; \
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
}

/*!
* \brief we always used NodeRef for referencing nodes.
* \brief We always used NodeRef for referencing nodes.
*
* By default, NodeRef is a std::shared_ptr of node
*/
Expand Down
12 changes: 11 additions & 1 deletion include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,18 @@ class FunctionNode : public ExprNode {
*/
tvm::Array<TypeVar> type_params;

/*!
* \brief The attributes which store metadata about functions.
*/
tvm::Attrs attrs;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("span", &span);
v->Visit("attrs", &attrs);
v->Visit("_checked_type_", &checked_type_);
}

Expand All @@ -230,10 +236,14 @@ class FunctionNode : public ExprNode {
*/
TVM_DLL FuncType func_type_annotation() const;

TVM_DLL NodeRef GetAttr(const std::string& key) const;
jroesch marked this conversation as resolved.
Show resolved Hide resolved
TVM_DLL Function SetAttr(const std::string& key, const NodeRef& data) const;

TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params);
tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());

static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
Expand Down
140 changes: 140 additions & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/interpreter.h
* \brief An interpreter for Relay.
*
* This file implements a simple reference interpreter for Relay programs.
* Given a Relay environment, and a Relay expression it produces a value.
*
* The interpreter's values are a naive representation of the values that
* can be produced by a Relay program and are exposed via tvm::Node's
* system to Python for introspection and debugging.
*
* The interpreter's intent is to serve as a reference semantics for the Relay IR,
* as well as for debugging and testing.
*/
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_

#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>

namespace tvm {
namespace relay {

/*!
* \brief A Relay value.
*/
class Value;

/*! \brief Evaluate an expression using the interpreter producing a value.
*
* The resulting value can be passed to Python, making it easy to use
* for testing and debugging.
*
* The interpreter interprets the program fragments not supported by the
* TVM runtime, although the interpreter is naively implemented it uses
* TVM operators for evaluating all operators.
*
* Our intent is that this will never be an the most efficient implementation of
jroesch marked this conversation as resolved.
Show resolved Hide resolved
* Relay's semantics, but a readable and clear one.
*/
Value Evaluate(Environment env, Expr e);

/*! \brief The base container type of Relay values. */
class ValueNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Value";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
};

class Value : public NodeRef {
public:
Value() {}
explicit Value(NodePtr<Node> n) : NodeRef(n) {}
const ValueNode* operator->() const {
return static_cast<const ValueNode*>(node_.get());
}

using ContainerType = ValueNode;
};

/*! \brief A Relay closure, i.e a scope and a function. */
class Closure;

/*! \brief The container type of Closures. */
class ClosureNode : public ValueNode {
public:
/*! \brief The set of free variables in the closure.
*
* These are the captured variables which are required for
* evaluation when we call the closure.
*/
tvm::Map<Var, Value> env;
/*! \brief The function which implements the closure.
*
* \note May reference the variables contained in the env.
*/
Function func;

ClosureNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("env", &env);
v->Visit("func", &func);
}

TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func);

static constexpr const char* _type_key = "relay.Closure";
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);

/*! \brief A tuple value. */
class TupleValue;

/*! \brief Tuple (x, ... y). */
struct TupleValueNode : ValueNode {
tvm::Array<Value> fields;

TupleValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }

TVM_DLL static TupleValue make(tvm::Array<Value> value);

static constexpr const char* _type_key = "relay.TupleValue";
TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value);

/*! \brief A tensor value. */
class TensorValue;

/*! \brief The tensor value container, wrapping an NDArray. */
struct TensorValueNode : ValueNode {
runtime::NDArray data;

TensorValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); }

/*! \brief Build a value from an NDArray. */
TVM_DLL static TensorValue make(runtime::NDArray data);

/*! \brief Construct an empty tensor value from t. */
TVM_DLL static TensorValue FromType(const Type& t);

static constexpr const char* _type_key = "relay.TensorValue";
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);


} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_INTERPRETER_H_
95 changes: 86 additions & 9 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

#include <tvm/lowered_func.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>
#include <string>

namespace tvm {
namespace relay {
Expand All @@ -20,7 +22,8 @@ namespace relay {
* populated with the result type.
*
* \param expr The expression to type check.
* \param env The environment used for referencing global functions, can be None.
* \param env The environment used for referencing global functions, can be
* None.
*
* \return A type checked expression with its checked_type field populated.
*/
Expand All @@ -35,7 +38,8 @@ Expr InferType(const Expr& expr, const Environment& env);
* \return A type checked Function with its checked_type field populated.
* \note this function mutates env and is not thread-safe.
*/
Function InferType(const Function& f, const Environment& env, const GlobalVar& var);
Function InferType(const Function& f, const Environment& env,
const GlobalVar& var);

/*!
* \brief Check that types are well kinded by applying "kinding rules".
Expand Down Expand Up @@ -94,28 +98,30 @@ bool AlphaEqual(const Type& t1, const Type& t2);
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
*
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, although x is not shadowed.
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice,
* although x is not shadowed.
*
* \param e the expression to check.
jroesch marked this conversation as resolved.
Show resolved Hide resolved
*
* \return true iff all Var in e is bound at most once.
jroesch marked this conversation as resolved.
Show resolved Hide resolved
*/
bool WellFormed(const Expr& e);
jroesch marked this conversation as resolved.
Show resolved Hide resolved

/*! \brief Get free Vars from expr in PostDFS order.
/*! \brief Get free type parameters from expression e.
jroesch marked this conversation as resolved.
Show resolved Hide resolved
*
* Free variables are variables that are not bound by a
* let or a function parameter in the context.
*
* \param expr the expression.
*
* \return List of free vars, in the PostDFS order visited by expr.
* \return List of free vars, in the PostDFS order in the expression.
*/
tvm::Array<Var> FreeVars(const Expr& expr);

/*! \brief Get free TypeVars from expression expr.
*
* Free type parameters are type parameters that are not bound by a function type in the context.
* Free type parameters are type parameters that are not bound by a function
* type in the context.
*
* \param expr the expression.
*
Expand All @@ -125,10 +131,12 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);

/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let binding that are not referenced, and if branch that are not entered.
* It will remove let bindings which are not referenced, and branches that will
* not be entered.
*
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a.
* Another example is `if (true) then 1 else 2` will be optimized into 1.
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of
* the expression does not depend on a. Another example is `if (true) then 1
* else 2` will be optimized into 1.
*
* \param e the expression to optimize.
*
Expand Down Expand Up @@ -156,7 +164,76 @@ size_t StructuralHash(const Type& type);
*/
size_t StructuralHash(const Expr& expr);

/*! \brief The hash struct for expressions. */
struct ExprHash {
jroesch marked this conversation as resolved.
Show resolved Hide resolved
size_t operator()(const Expr& a) const {
return StructuralHash(a);
}
};

/*! \brief The equal comparator for expressions. */
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider AlphaEqual struct

struct ExprEqual {
bool operator()(const Expr& a, const Expr& b) const {
return AlphaEqual(a, b);
}
};

/*! \brief A lowered Relay operation.
*
* A lowered operation is a pair containing the "primitive" function used
* to produce the lowered function as well as the lowered function itself.
jroesch marked this conversation as resolved.
Show resolved Hide resolved
*/
class LoweredOp;
jroesch marked this conversation as resolved.
Show resolved Hide resolved
/*! \brief Call container. */
class LoweredOpNode : public Node {
public:
/*!
* \brief The primitive function to be lowered.
*
* A primitive function consists only of calls to relay::Op which
* can be fused.
*/
Function func;

/*!
* \brief The lowered function.
*/
LoweredFunc lowered_func;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("lowered_func", &lowered_func);
}

TVM_DLL static LoweredOp make(
Function func,
LoweredFunc lowered_func);

static constexpr const char* _type_key = "relay.LoweredOp";
TVM_DECLARE_NODE_TYPE_INFO(LoweredOpNode, Node);
};

RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef);

/*!
* \brief Lower the operations contained in a Relay expression.
*
* The lowering pass will only lower functions marked as primitive,
* the FuseOps pass will provide this behavior, if run before LowerOps.
*
* \note This will do a reachability analysis and lower all definitions
* reachable from the provided expression.
*
* \param env The environment.
* \param expr The expression with operations to be lowered.
* \param target The target to lower the functions to.
*
* \return The set of lowered operations.
*/
Array<LoweredOp> LowerOps(const Environment& env, const Expr& expr,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function could be

LoweredFunc LowerFunc(Function);

And we run an extraction pass to get the necessary pass. Notably, this lowering function can happen inside the cached JIT engine

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand there is a rewriting going on here trying to rewrite the body to use func with special loweredFunc attr. Alternatively, we can expose a CompileEngine, which have CompileEngine.lookup(func: Function)-> LoweredFunc

const std::string& target = "llvm");

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_PASS_H_
19 changes: 19 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
from ..api import register_func
from . import base
from . import ty
from . import expr
from . import env
from . import ir_pass
from . import testing

# Root operators
from .op import Op
Expand All @@ -15,6 +18,7 @@
from . import vision
from . import image


from .scope_builder import ScopeBuilder

# Span
Expand Down Expand Up @@ -46,6 +50,21 @@
If = expr.If
TupleGetItem = expr.TupleGetItem


# helper functions
var = expr.var
const = expr.const

@register_func("relay._tensor_value_repr")
def _tensor_value_repr(tv):
return str(tv.data.asnumpy())

@register_func("relay._constant_repr")
def _tensor_constant_repr(tv):
return str(tv.data.asnumpy())

# pylint: disable=unused-argument
@register_func("relay.debug")
def _debug(*args):
import pdb
pdb.set_trace()
Loading