Skip to content

Commit

Permalink
[Relay] Reference (apache#2489)
Browse files Browse the repository at this point in the history
* move

fix test

fix lint

fix test

add more code

fix lint

better type infer ability

* fix build

* address comment
  • Loading branch information
MarisaKirisame authored and ZihengJiang committed Feb 15, 2019
1 parent 895ef97 commit d05fed2
Show file tree
Hide file tree
Showing 25 changed files with 578 additions and 29 deletions.
68 changes: 67 additions & 1 deletion include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,78 @@ class TupleGetItemNode : public ExprNode {

TVM_DLL static TupleGetItem make(Expr tuple, int index);

static constexpr const char * _type_key = "relay.TupleGetItem";
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*! \brief Create a new Reference out of initial value. */
class RefCreate;
class RefCreateNode : public ExprNode {
public:
/*! \brief The initial value of the Reference. */
Expr value;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefCreate make(Expr value);

static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);

/*! \brief Get value out of Reference. */
class RefRead;
class RefReadNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefRead make(Expr ref);

static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);

/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
class RefWriteNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;
/*! \brief The value to write into. */
Expr value;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefWrite make(Expr ref, Expr value);

static constexpr const char* _type_key = "relay.RefWrite";
TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);

/*!
* \brief Base class of the temporary expression.
*
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -108,6 +111,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
return vtable;
}
};
Expand All @@ -133,6 +139,9 @@ class ExprVisitor
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
virtual void VisitType(const Type& t);

protected:
Expand Down Expand Up @@ -168,6 +177,9 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
/*!
* \brief Used to visit the types inside of expressions.
*
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,25 @@ struct TensorValueNode : ValueNode {

RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);

/*! \brief A reference value. */
class RefValue;

struct RefValueNode : ValueNode {
mutable Value value;

RefValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
}

TVM_DLL static RefValue make(Value val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);

} // namespace relay
} // namespace tvm
Expand Down
27 changes: 27 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,33 @@ class TupleTypeNode : public TypeNode {

RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);

/*!
* \brief The type of reference values.
*/
class RefType;
/*!
* \brief Reference Type in relay.
*/
class RefTypeNode : public TypeNode {
public:
/*! \brief The type of value in the Reference. */
Type value;

RefTypeNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
}

TVM_DLL static RefType make(Type value);

static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type);

class TypeReporter;

/*!
Expand Down
12 changes: 8 additions & 4 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type
RefType = ty.RefType

# Expr
Expr = expr.Expr
Expand All @@ -56,15 +57,18 @@
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator
RefCreate = expr.RefCreate
RefRead = expr.RefRead
RefWrite = expr.RefWrite

# helper functions
var = expr.var
const = expr.const
bind = expr.bind

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator

# Parser
fromtext = parser.fromtext
9 changes: 9 additions & 0 deletions python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,15 @@ def visit_call(self, call):
def visit_op(self, _):
raise Exception("can not compile op in non-eta expanded form")

def visit_ref_create(self, _):
raise RuntimeError("reference not supported")

def visit_ref_read(self, _):
raise RuntimeError("reference not supported")

def visit_ref_write(self, _):
raise RuntimeError("reference not supported")

def _get_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __repr__(self):
def __iter__(self):
return iter(self.fields)


@register_relay_node
class Closure(Value):
"""A closure produced by the interpreter."""
Expand Down Expand Up @@ -79,6 +80,13 @@ def __str__(self):
return str(self.data)


@register_relay_node
class RefValue(Value):
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.RefValue, value)


def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(_nd.cpu(0)))
Expand Down
40 changes: 40 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,46 @@ def __init__(self, tuple_value, index):
_make.TupleGetItem, tuple_value, index)


@register_relay_node
class RefCreate(Expr):
"""Create a new reference from initial value.
Parameters
----------
value: tvm.relay.Expr
The initial value.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefCreate, value)


@register_relay_node
class RefRead(Expr):
"""Get the value inside the reference.
Parameters
----------
ref: tvm.relay.Expr
The reference.
"""
def __init__(self, ref):
self.__init_handle_by_constructor__(_make.RefRead, ref)


@register_relay_node
class RefWrite(Expr):
"""
Update the value inside the reference.
The whole expression will evaluate to an empty tuple.
Parameters
----------
ref: tvm.relay.Expr
The reference.
value: tvm.relay.Expr
The new value.
"""
def __init__(self, ref, value):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value)


class TempExpr(Expr):
"""Baseclass of all TempExpr.
Expand Down
18 changes: 16 additions & 2 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def visit(self, expr):
res = self.visit_constant(expr)
elif isinstance(expr, Op):
res = self.visit_op(expr)
elif isinstance(expr, RefCreate):
res = self.visit_ref_create(expr)
elif isinstance(expr, RefRead):
res = self.visit_ref_read(expr)
elif isinstance(expr, RefWrite):
res = self.visit_ref_write(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))

Expand Down Expand Up @@ -81,6 +87,14 @@ def visit_op(self, _):
def visit_constant(self, _):
raise NotImplementedError()

def visit_ref_create(self, _):
raise NotImplementedError()

def visit_ref_write(self, _):
raise NotImplementedError()

def visit_ref_read(self, _):
raise NotImplementedError()

class ExprMutator(ExprFunctor):
"""
Expand Down Expand Up @@ -145,8 +159,8 @@ def visit_constructor(self, con):
def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])

def visit_ref_new(self, r):
return RefNew(self.visit(r.value))
def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))

def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,19 @@ def __init__(self, func, args, num_inputs, attrs):
func, args, num_inputs, attrs)


@register_relay_node
class RefType(Type):
"""Reference Type in relay.
Parameters
----------
value: Type
The value type.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefType, value)


def scalar_type(dtype):
"""Creates a scalar type.
Expand Down
Loading

0 comments on commit d05fed2

Please sign in to comment.