Skip to content

Commit

Permalink
Add MNIST Lenet example, including WIP front-end improvements (apache#19
Browse files Browse the repository at this point in the history
)

Add an example of LeNet, modify front-end and refactor remaining AST nodes

* Add mnist_lenet example and dataset code from TinyFlow

* Move datasets and add comments/examples

* Add local scopes and fix a few issues

* First attempt at writing lenet test

A few problems we noticed:
  - Looking up a value in Environment.items crashed with a TVM error.
  - The C++ AssignmentNode definition does not encode an ordering for
    the assignments (since it uses a tvm::Map).

* failed attempt to implement lookup

* Add notes about why current situation isn't good

* Address annoying pointer liveness problems by copying

* fix typo that prevented compilation

* Refactor Let node
Modify assignment and rename to Let

* Remove hashing hacking

* Fix environment equality handling

* Test now passes

* Fix some of the linting issues

* fix linting errors

* Add function values

* Fix NYI message, and trigger CI again

* Integrate with changes from master

* Use ostreamstring?

* Use string concat instead

* Repair pretty printing tests

* Address CR comment

* Tweak style, auto-formatter is not working
  • Loading branch information
jroesch committed Aug 16, 2018
1 parent 2c79399 commit 475f77e
Show file tree
Hide file tree
Showing 16 changed files with 477 additions and 99 deletions.
16 changes: 12 additions & 4 deletions relay/include/relay/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,30 @@ namespace relay {
struct Environment;

/*! \brief The environment: a collection of items keyed by GlobalId. */
struct EnvironmentNode : ValueNode {
std::unordered_map<std::string, GlobalId> table;
class EnvironmentNode : public ValueNode {
private:
std::unordered_map<std::string, GlobalId> global_map_;
// What happens if two GlobalIds share the same name?
// This should be fixed in the Python code,
// but I haven't looked into it closely yet, so for now we just hack around it.

inline void add_global(const std::string & str, GlobalId id);

public:
tvm::Map<GlobalId, Item> items;

EnvironmentNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("items", &items); }
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("items", &items);
}

TVM_DLL static Environment make(tvm::Map<GlobalId, Item> items);

void add(const Item& item);
dmlc::optional<Item> lookup(const GlobalIdNode* id);
GlobalId global_id(const std::string & str);
Item lookup(const GlobalId & id);
Item lookup(const std::string & str);

static constexpr const char* _type_key = "nnvm.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
Expand Down
6 changes: 4 additions & 2 deletions relay/include/relay/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ namespace relay {
class Evaluator : public ExprFunctor<Value(const Expr& n)> {
public:
Environment env;
std::vector<std::unordered_map<std::string, Value>>
// Because we know variables will always be unique we can collapse this into a single mapping?
std::vector<std::unordered_map<LocalId, Value>>
stack; // eventually should be encoded as a stack machine
Evaluator();
Evaluator(Environment env) : env(env) {}
void extend(const LocalId & id, Value v);
Value Eval(const Expr& expr);
Value VisitExpr_(const LocalIdNode* op) override;
Value VisitExpr_(const GlobalIdNode* op) override;
Expand All @@ -39,7 +41,7 @@ class Evaluator : public ExprFunctor<Value(const Expr& n)> {
Value VisitExpr_(const DebugNode* op) override;
Value VisitExpr_(const UnaryOpNode* op) override;
Value VisitExpr_(const BinaryOpNode* op) override;
Value VisitExpr_(const AssignmentNode* op) override;
Value VisitExpr_(const LetNode* op) override;
};

} // namespace relay
Expand Down
4 changes: 2 additions & 2 deletions relay/include/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const BinaryOpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const AssignmentNode* op,
virtual R VisitExpr_(const LetNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ReverseNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -153,7 +153,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
IR_EXPR_FUNCTOR_DISPATCH(UnaryOpNode);
IR_EXPR_FUNCTOR_DISPATCH(BinaryOpNode);
IR_EXPR_FUNCTOR_DISPATCH(AssignmentNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(ReverseNode);
IR_EXPR_FUNCTOR_DISPATCH(AccumulateNode);
IR_EXPR_FUNCTOR_DISPATCH(ZeroNode);
Expand Down
57 changes: 44 additions & 13 deletions relay/include/relay/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ class FloatValueNode : public ValueNode {

TVM_DEFINE_NODE_REF(FloatValue, FloatValueNode);



// end move me

/*! \brief Base type of the Relay type hierarchy. */
Expand Down Expand Up @@ -432,6 +434,8 @@ class GlobalIdNode : public ExprNode {

TVM_DLL static GlobalId make(std::string name);

inline size_t hash() const { return std::hash<std::string>{}(name); }

static constexpr const char* _type_key = "nnvm.GlobalId";
TVM_DECLARE_NODE_TYPE_INFO(GlobalIdNode, ExprNode);
};
Expand Down Expand Up @@ -610,30 +614,32 @@ class BinaryOpNode : public ExprNode {

TVM_DEFINE_NODE_REF(BinaryOp, BinaryOpNode)

class Assignment;
class Let;

// TODO(jroesch) : make me contain the proper fields.

/*! \brief Assignment. */
class AssignmentNode : public ExprNode {
/*! \brief A binding of a sub-network. */
class LetNode : public ExprNode {
public:
tvm::Map<LocalId, Expr> assignments;
LocalId id;
Expr value;
Expr body;

AssignmentNode() {}
LetNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("assignments", &assignments);
v->Visit("id", &id);
v->Visit("value", &value);
v->Visit("body", &body);
}

TVM_DLL static Assignment make(NodeRef node);
TVM_DLL static Let make(LocalId id, Expr value, Expr body);

static constexpr const char* _type_key = "nnvm.Assignment";
TVM_DECLARE_NODE_TYPE_INFO(AssignmentNode, ExprNode);
static constexpr const char* _type_key = "nnvm.Let";
TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode);
};

TVM_DEFINE_NODE_REF(Assignment, AssignmentNode);
TVM_DEFINE_NODE_REF(Let, LetNode);

class Primitive;

Expand Down Expand Up @@ -781,23 +787,48 @@ class DefnNode : public ItemNode {

TVM_DEFINE_NODE_REF(Defn, DefnNode);

// Move me too
class FnValue;

/*!
 * \brief A function value: a Function paired with an environment that maps
 *  LocalId to Value.
 *
 *  NOTE(review): `env` presumably holds the local bindings captured when the
 *  function value was created (i.e. a closure) -- confirm against the evaluator.
 */
class FnValueNode : public ValueNode {
public:
// Bindings from local identifiers to values carried with the function.
tvm::Map<LocalId, Value> env;
// The function this value wraps.
Function func;

FnValueNode() {}

// Exposes `env` and `func` to the attribute visitor under those names.
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("env", &env);
v->Visit("func", &func);
}

// Factory taking the environment map and the function to wrap.
TVM_DLL static FnValue make(tvm::Map<LocalId, Value> env, Function func);

static constexpr const char* _type_key = "nnvm.FnValue";
TVM_DECLARE_NODE_TYPE_INFO(FnValueNode, ValueNode);
};

TVM_DEFINE_NODE_REF(FnValue, FnValueNode);

} // namespace relay
} // namespace nnvm

namespace std {

/*! \brief std hash function for LocalId */
template <>
struct hash<nnvm::relay::GlobalId> {
struct hash<nnvm::relay::LocalId> {
/*!
* \brief returns hash of local id.
* \param id local id.
* \return hash code.
*/
size_t operator()(const nnvm::relay::GlobalId& id) const {
return hash<std::string>{}(id->name);
size_t operator()(const nnvm::relay::LocalId& id) const {
return id.hash();
}
};

} // namespace std

#endif // NNVM_RELAY_NODE_H_
2 changes: 1 addition & 1 deletion relay/include/relay/pretty_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct PrettyPrinter : public ExprFunctor<void(const Expr& n, ostream & os)> {
void VisitExpr_(const DebugNode * op, ostream & os) override;
void VisitExpr_(const UnaryOpNode * op, ostream & os) override;
void VisitExpr_(const BinaryOpNode * op, ostream & os) override;
void VisitExpr_(const AssignmentNode * op, ostream & os) override;
void VisitExpr_(const LetNode * op, ostream & os) override;
};

} // namespace relay
Expand Down
4 changes: 2 additions & 2 deletions relay/include/relay/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class TypeFunctor<R(const Expr& n, Args...)> {
virtual R VisitType_(const DebugNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitType_(const UnaryOpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitType_(const BinaryOpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitType_(const AssignmentNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitType_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitType_(const ReverseNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitType_(const AccumulateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitType_(const ZeroNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -140,7 +140,7 @@ class TypeFunctor<R(const Expr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
IR_EXPR_FUNCTOR_DISPATCH(UnaryOpNode);
IR_EXPR_FUNCTOR_DISPATCH(BinaryOpNode);
IR_EXPR_FUNCTOR_DISPATCH(AssignmentNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(ReverseNode);
IR_EXPR_FUNCTOR_DISPATCH(AccumulateNode);
IR_EXPR_FUNCTOR_DISPATCH(ZeroNode);
Expand Down
2 changes: 1 addition & 1 deletion relay/include/relay/typechecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Typechecker : public ExprFunctor<Type(const Expr & n)> {
Type VisitExpr_(const DebugNode* op) override;
Type VisitExpr_(const UnaryOpNode* op) override;
Type VisitExpr_(const BinaryOpNode* op) override;
Type VisitExpr_(const AssignmentNode* op) override;
Type VisitExpr_(const LetNode* op) override;
};

} // namespace relay
Expand Down
4 changes: 2 additions & 2 deletions relay/include/relay/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ class FunctorNode
}
}

NodeRef VisitExpr_(const AssignmentNode* bop, tvm::Array<NodeRef> args) override {
NodeRef VisitExpr_(const LetNode* bop, tvm::Array<NodeRef> args) override {
throw "foo";
// if (visit_assignment != nullptr) {
// return visit_assignment(intrinsic->name, args);
// } else {
// return AssignmentNode::make(intrinsic->name);
// return LetNode::make(intrinsic->name);
// }
}

Expand Down
107 changes: 107 additions & 0 deletions relay/python/relay/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#pylint: disable-all
# Linting is disabled here: this code is temporary and will probably be moved or changed.
"""A module borrowed from tinyflow for importing some sample datasets."""
import numpy as np
from collections import namedtuple
from sklearn.datasets import fetch_mldata
import pickle
import sys
import os
from subprocess import call

class ArrayPacker(object):
    """Sequential mini-batch iterator over index-aligned images and labels.

    ``X`` and ``Y`` are stored as given. ``Y`` must expose ``shape[0]`` and
    both must support slicing (the loaders in this module pass numpy arrays).
    """

    def __init__(self, X, Y):
        self.images = X
        self.labels = Y
        # Offset of the first sample of the next batch.
        self.ptr = 0

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` samples as an ``(X, Y)`` pair.

        When fewer than ``batch_size`` samples remain, the cursor wraps back
        to the start of the data before slicing, so the trailing partial
        batch is dropped rather than returned short.
        """
        # Strict `>`: a batch that ends exactly on the last sample is a full
        # batch and must be served. The previous `>=` wrapped one batch early
        # and never returned the final batch when the dataset size was a
        # multiple of batch_size.
        if self.ptr + batch_size > self.labels.shape[0]:
            self.ptr = 0
        start = self.ptr
        stop = start + batch_size
        self.ptr = stop
        return self.images[start:stop], self.labels[start:stop]

# Result type of get_mnist: the `train` and `test` fields each hold an ArrayPacker.
MNISTData = namedtuple("MNISTData", ["train", "test"])

def get_mnist(flatten=False, onehot=False):
    """Fetch MNIST and return shuffled train/test batch iterators.

    The 'MNIST original' dataset is fetched with ``fetch_mldata``, shuffled
    with a fixed seed (1234), converted to float32 and divided by 255.

    Parameters
    ----------
    flatten : bool
        If true, images are reshaped to ``(70000, 784)``; otherwise to
        ``(70000, 1, 28, 28)``.
    onehot : bool
        If true, labels are replaced by a ``(N, 10)`` array holding a 1 in
        the column given by each label and 0 elsewhere.

    Returns
    -------
    MNISTData
        ``train`` wraps the first 60000 shuffled samples, ``test`` the rest.
    """
    mnist = fetch_mldata('MNIST original')
    np.random.seed(1234)  # set seed for deterministic ordering
    order = np.random.permutation(mnist.data.shape[0])
    images = mnist.data[order]
    targets = mnist.target[order]
    images = images.astype(np.float32) / 255.0

    image_shape = (70000, 28 * 28) if flatten else (70000, 1, 28, 28)
    images = images.reshape(image_shape)

    if onehot:
        num_samples = targets.shape[0]
        encoded = np.zeros((num_samples, 10))
        encoded[np.arange(num_samples), targets.astype(np.int32)] = 1
        targets = encoded

    split = 60000
    train_packer = ArrayPacker(images[:split], targets[:split])
    test_packer = ArrayPacker(images[split:], targets[split:])
    return MNISTData(train=train_packer, test=test_packer)


# Result type of get_cifar10: the `train` and `test` fields each hold an ArrayPacker.
CIFAR10Data = namedtuple("CIFAR10Data", ["train", "test"])

def load_batch(fpath, label_key='labels'):
    """Load one pickled CIFAR batch file.

    Parameters
    ----------
    fpath : str
        Path of the pickled batch file.
    label_key : str
        Key under which the labels are stored in the pickled dict.

    Returns
    -------
    (data, labels)
        ``data`` is the ``"data"`` entry reshaped to ``(N, 3, 32, 32)`` and
        cast to float32; ``labels`` is a float32 numpy array.
    """
    # `with` closes the file even if unpickling raises.
    with open(fpath, 'rb') as f:
        # The module imports `pickle`; the previous `cPickle` name was never
        # imported and raised NameError.
        if sys.version_info < (3,):
            d = pickle.load(f)
        else:
            d = pickle.load(f, encoding="bytes")

    # Decode byte-string keys into a *new* dict. The previous code deleted and
    # inserted keys while iterating d.items(), which raises RuntimeError on
    # Python 3. Keys that are already text are kept unchanged.
    d = {(k.decode("utf8") if isinstance(k, bytes) else k): v
         for k, v in d.items()}

    data = d["data"]
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32).astype(np.float32)
    labels = np.array(labels, dtype="float32")
    return data, labels


def get_cifar10(swap_axes=False):
    """Load CIFAR-10 from ``cifar-10-batches-py`` and return batch iterators.

    If the directory is missing, the archive ``cifar-10-python.tar.gz`` is
    extracted with ``tar``, after first being downloaded with ``wget`` when it
    is not already present. Both tools are run through ``subprocess.call``
    and their exit status is not checked.

    Parameters
    ----------
    swap_axes : bool
        If true, axes 1 and 3 of the image arrays are swapped.

    Returns
    -------
    CIFAR10Data
        ``train`` wraps the 50000 samples read from ``data_batch_1`` to
        ``data_batch_5``; ``test`` wraps the contents of ``test_batch``.
    """
    path = "cifar-10-batches-py"
    if not os.path.exists(path):
        tar_file = "cifar-10-python.tar.gz"
        origin = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        # Download only when the archive is absent; extraction always runs.
        if not os.path.exists(tar_file):
            call(["wget", origin])
        call(["tar", "-xvf", tar_file])

    nb_train_samples = 50000
    batch_len = 10000

    X_train = np.zeros((nb_train_samples, 3, 32, 32), dtype="float32")
    y_train = np.zeros((nb_train_samples,), dtype="float32")

    # Each training batch file fills one consecutive 10000-sample slot.
    for batch_no in range(1, 6):
        lo = (batch_no - 1) * batch_len
        hi = batch_no * batch_len
        batch_path = os.path.join(path, 'data_batch_' + str(batch_no))
        data, labels = load_batch(batch_path)
        X_train[lo:hi, :, :, :] = data
        y_train[lo:hi] = labels

    X_test, y_test = load_batch(os.path.join(path, 'test_batch'))

    if swap_axes:
        X_train = np.swapaxes(X_train, 1, 3)
        X_test = np.swapaxes(X_test, 1, 3)

    train_packer = ArrayPacker(X_train, y_train)
    test_packer = ArrayPacker(X_test, y_test)
    return CIFAR10Data(train=train_packer, test=test_packer)
Loading

0 comments on commit 475f77e

Please sign in to comment.