-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RELAY][COMPILER] Initial Relay interpreter and compiler for TVM runtime system. #1954
Changes from 18 commits
796096e
4d3d9e0
e5f2d87
efe62d3
eb130ea
21a52cf
1df6f07
b930f51
070ed09
7c2f404
c88b465
9241d09
c05c3c5
8e4fb55
af80e67
bd999a6
b677d89
5967333
5c9ae4a
161e14a
3ce8376
6286f04
bf70e13
9f6412e
baef91b
82b51fd
2167c47
3c76f0f
e853433
98ae9b9
b4d5d31
b69485f
3f9dc5e
4e3f92d
604ea7b
a61d663
95e49a0
341a0c3
002f376
c562275
9a229c3
a050815
28bfdba
ecfe6e5
315f9c0
764bc5e
86ff607
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
/*! | ||
* Copyright (c) 2018 by Contributors | ||
* \file tvm/relay/interpreter.h | ||
* \brief An interpreter for Relay. | ||
* | ||
* This file implements a simple reference interpreter for Relay programs. | ||
* Given a Relay environment, and a Relay expression it produces a value. | ||
* | ||
* The interpreter's values are a naive representation of the values that | ||
* can be produced by a Relay program and are exposed via tvm::Node's | ||
* system to Python for introspection and debugging. | ||
* | ||
* The interpreter's intent is to serve as a reference semantics for the Relay IR, | ||
* as well as for debugging and testing. | ||
*/ | ||
#ifndef TVM_RELAY_INTERPRETER_H_ | ||
#define TVM_RELAY_INTERPRETER_H_ | ||
|
||
#include <tvm/relay/environment.h> | ||
#include <tvm/relay/expr.h> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
/*! | ||
* \brief A Relay value. | ||
*/ | ||
class Value; | ||
|
||
/*! \brief Evaluate an expression using the interpreter producing a value. | ||
* | ||
* The resulting value can be passed to Python, making it easy to use | ||
* for testing and debugging. | ||
* | ||
* The interpreter interprets the program fragments not supported by the | ||
* TVM runtime, although the interpreter is naively implemented it uses | ||
* TVM operators for evaluating all operators. | ||
* | ||
* Our intent is that this will never be an the most efficient implementation of | ||
jroesch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* Relay's semantics, but a readable and clear one. | ||
*/ | ||
Value Evaluate(Environment env, Expr e); | ||
|
||
/*! \brief The base container type of Relay values. */ | ||
class ValueNode : public RelayNode { | ||
public: | ||
static constexpr const char* _type_key = "relay.Value"; | ||
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); | ||
}; | ||
|
||
class Value : public NodeRef { | ||
public: | ||
Value() {} | ||
explicit Value(NodePtr<Node> n) : NodeRef(n) {} | ||
const ValueNode* operator->() const { | ||
return static_cast<const ValueNode*>(node_.get()); | ||
} | ||
|
||
using ContainerType = ValueNode; | ||
}; | ||
|
||
/*! \brief A Relay closure, i.e a scope and a function. */ | ||
class Closure; | ||
|
||
/*! \brief The container type of Closures. */ | ||
class ClosureNode : public ValueNode { | ||
public: | ||
/*! \brief The set of free variables in the closure. | ||
* | ||
* These are the captured variables which are required for | ||
* evaluation when we call the closure. | ||
*/ | ||
tvm::Map<Var, Value> env; | ||
/*! \brief The function which implements the closure. | ||
* | ||
* \note May reference the variables contained in the env. | ||
*/ | ||
Function func; | ||
|
||
ClosureNode() {} | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("env", &env); | ||
v->Visit("func", &func); | ||
} | ||
|
||
TVM_DLL static Closure make(tvm::Map<Var, Value> env, Function func); | ||
|
||
static constexpr const char* _type_key = "relay.Closure"; | ||
TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); | ||
|
||
/*! \brief A tuple value. */ | ||
class TupleValue; | ||
|
||
/*! \brief Tuple (x, ... y). */ | ||
struct TupleValueNode : ValueNode { | ||
tvm::Array<Value> fields; | ||
|
||
TupleValueNode() {} | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } | ||
|
||
TVM_DLL static TupleValue make(tvm::Array<Value> value); | ||
|
||
static constexpr const char* _type_key = "relay.TupleValue"; | ||
TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value); | ||
|
||
/*! \brief A tensor value. */ | ||
class TensorValue; | ||
|
||
/*! \brief The tensor value container, wrapping an NDArray. */ | ||
struct TensorValueNode : ValueNode { | ||
runtime::NDArray data; | ||
|
||
TensorValueNode() {} | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); } | ||
|
||
/*! \brief Build a value from an NDArray. */ | ||
TVM_DLL static TensorValue make(runtime::NDArray data); | ||
|
||
/*! \brief Construct an empty tensor value from t. */ | ||
TVM_DLL static TensorValue FromType(const Type& t); | ||
|
||
static constexpr const char* _type_key = "relay.TensorValue"; | ||
TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); | ||
|
||
|
||
} // namespace relay | ||
} // namespace tvm | ||
#endif // TVM_RELAY_INTERPRETER_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,8 +6,10 @@ | |
#ifndef TVM_RELAY_PASS_H_ | ||
#define TVM_RELAY_PASS_H_ | ||
|
||
#include <tvm/lowered_func.h> | ||
#include <tvm/relay/environment.h> | ||
#include <tvm/relay/expr.h> | ||
#include <string> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
@@ -20,7 +22,8 @@ namespace relay { | |
* populated with the result type. | ||
* | ||
* \param expr The expression to type check. | ||
* \param env The environment used for referencing global functions, can be None. | ||
* \param env The environment used for referencing global functions, can be | ||
* None. | ||
* | ||
* \return A type checked expression with its checked_type field populated. | ||
*/ | ||
|
@@ -35,7 +38,8 @@ Expr InferType(const Expr& expr, const Environment& env); | |
* \return A type checked Function with its checked_type field populated. | ||
* \note this function mutates env and is not thread-safe. | ||
*/ | ||
Function InferType(const Function& f, const Environment& env, const GlobalVar& var); | ||
Function InferType(const Function& f, const Environment& env, | ||
const GlobalVar& var); | ||
|
||
/*! | ||
* \brief Check that types are well kinded by applying "kinding rules". | ||
|
@@ -94,7 +98,8 @@ bool AlphaEqual(const Type& t1, const Type& t2); | |
* | ||
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice. | ||
* | ||
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, although x is not shadowed. | ||
* `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, | ||
* although x is not shadowed. | ||
* | ||
* \param e the expression to check. | ||
jroesch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* | ||
|
@@ -103,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2); | |
bool WellFormed(const Expr& e); | ||
jroesch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
/*! \brief Get free Vars from expr in PostDFS order. | ||
* | ||
* Free variables are variables that are not bound by a let or a function | ||
* parameter in the context. | ||
* | ||
* \param e the expression. | ||
* | ||
* \return the set of free variable. | ||
*/ | ||
tvm::Array<Var> FreeVariables(const Expr& e); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FreeVars? seems it is already there There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably just a merge issue. |
||
|
||
/*! \brief Get free type parameters from expression e. | ||
jroesch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* | ||
* Free variables are variables that are not bound by a | ||
* let or a function parameter in the context. | ||
|
@@ -115,7 +131,8 @@ tvm::Array<Var> FreeVars(const Expr& expr); | |
|
||
/*! \brief Get free TypeVars from expression expr. | ||
* | ||
* Free type parameters are type parameters that are not bound by a function type in the context. | ||
* Free type parameters are type parameters that are not bound by a function | ||
* type in the context. | ||
* | ||
* \param expr the expression. | ||
* | ||
|
@@ -125,10 +142,12 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr); | |
|
||
/*! \brief Remove expressions which does not effect the program result. | ||
* | ||
* It will remove let binding that are not referenced, and if branch that are not entered. | ||
* It will remove let bindings which are not referenced, and branches that will | ||
* not be entered. | ||
* | ||
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a. | ||
* Another example is `if (true) then 1 else 2` will be optimized into 1. | ||
* For example, this pass should turn `let a = 1 in 2` into `2`, as the value of | ||
* the expression does not depend on a. Another example is `if (true) then 1 | ||
* else 2` will be optimized into 1. | ||
* | ||
* \param e the expression to optimize. | ||
* | ||
|
@@ -156,7 +175,74 @@ size_t StructuralHash(const Type& type); | |
*/ | ||
size_t StructuralHash(const Expr& expr); | ||
|
||
/*! \brief The hash struct for expressions. */ | ||
struct ExprHash { | ||
jroesch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
size_t operator()(const Expr& a) const { return StructuralHash(a); } | ||
}; | ||
|
||
/*! \brief The equal comparator for expressions. */ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider AlphaEqual struct |
||
struct ExprEqual { | ||
bool operator()(const Expr& a, const Expr& b) const { | ||
return AlphaEqual(a, b); | ||
} | ||
}; | ||
|
||
/*! \brief A lowered Relay operation. | ||
* | ||
* A lowered operation is a pair containing the "primitive" function used | ||
* to produce the lowered function as well as the lowered function itself. | ||
jroesch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
*/ | ||
class LoweredOp; | ||
jroesch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/*! \brief Call container. */ | ||
class LoweredOpNode : public Node { | ||
public: | ||
/*! | ||
* \brief The primitive function to be lowered. | ||
* | ||
* A primitive function consists only of calls to relay::Op which | ||
* can be fused. | ||
*/ | ||
Function func; | ||
|
||
/*! | ||
* \brief The lowered function. | ||
*/ | ||
LoweredFunc lowered_func; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) final { | ||
v->Visit("func", &func); | ||
v->Visit("lowered_func", &lowered_func); | ||
} | ||
|
||
TVM_DLL static LoweredOp make( | ||
Function func, | ||
LoweredFunc lowered_func); | ||
|
||
static constexpr const char* _type_key = "relay.LoweredOp"; | ||
TVM_DECLARE_NODE_TYPE_INFO(LoweredOpNode, Node); | ||
}; | ||
|
||
RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef); | ||
|
||
/*! | ||
* \brief Lower the operations contained in a Relay expression. | ||
* | ||
* The lowering pass will only lower functions marked as primitive, | ||
* the FuseOps pass will provide this behavior, if run before LowerOps. | ||
* | ||
* \note This will do a reachability analysis and lower all definitions | ||
* reachable from the provided expression. | ||
* | ||
* \param env The environment. | ||
* \param expr The expression with operations to be lowered. | ||
* \param target The target to lower the functions to. | ||
* | ||
* \return The set of lowered operations. | ||
*/ | ||
Array<LoweredOp> LowerOps(const Environment& env, const Expr& expr, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function could be
And we run an extraction pass to get the necessary pass. Notably, this lowering function can happen inside the cached JIT engine There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand there is a rewriting going on here trying to rewrite the body to use func with special loweredFunc attr. Alternatively, we can expose a CompileEngine, which have CompileEngine.lookup(func: Function)-> LoweredFunc |
||
const std::string& target = "llvm"); | ||
|
||
} // namespace relay | ||
} // namespace tvm | ||
|
||
#endif // TVM_RELAY_PASS_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
"""The interface to the Evaluator exposed from C++.""" | ||
from tvm._ffi.function import _init_api | ||
|
||
_init_api("relay._eval", __name__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove mutable