diff --git a/relay/include/relay/environment.h b/relay/include/relay/environment.h index 85f80af010bc3..dbf086a408187 100644 --- a/relay/include/relay/environment.h +++ b/relay/include/relay/environment.h @@ -17,22 +17,30 @@ namespace relay { struct Environment; /*! \brief Integer literal `0`, `1000`. */ -struct EnvironmentNode : ValueNode { - std::unordered_map table; +class EnvironmentNode : public ValueNode { + private: + std::unordered_map global_map_; // What if there are two globalid with the same name? // This should be fixed in the python code, // But I havent take much look into it, so I will just hack around. + inline void add_global(const std::string & str, GlobalId id); + + public: tvm::Map items; EnvironmentNode() {} - void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("items", &items); } + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("items", &items); + } TVM_DLL static Environment make(tvm::Map items); void add(const Item& item); - dmlc::optional lookup(const GlobalIdNode* id); + GlobalId global_id(const std::string & str); + Item lookup(const GlobalId & id); + Item lookup(const std::string & str); static constexpr const char* _type_key = "nnvm.Environment"; TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node); diff --git a/relay/include/relay/evaluator.h b/relay/include/relay/evaluator.h index 3a7a2f340337e..602993c366d80 100644 --- a/relay/include/relay/evaluator.h +++ b/relay/include/relay/evaluator.h @@ -19,10 +19,12 @@ namespace relay { class Evaluator : public ExprFunctor { public: Environment env; - std::vector> + // Because we know variables will always be unique we can collapse this into a single mapping? + std::vector> stack; // eventually should be encoded as a stack machine Evaluator(); Evaluator(Environment env) : env(env) {} + void extend(const LocalId & id, Value v); Value Eval(const Expr& expr); Value VisitExpr_(const LocalIdNode* op) override; Value VisitExpr_(const GlobalIdNode* op) override; @@ -39,7 +41,7 @@ class Evaluator : public ExprFunctor { Value VisitExpr_(const DebugNode* op) override; Value VisitExpr_(const UnaryOpNode* op) override; Value VisitExpr_(const BinaryOpNode* op) override; - Value VisitExpr_(const AssignmentNode* op) override; + Value VisitExpr_(const LetNode* op) override; }; } // namespace relay diff --git a/relay/include/relay/expr_functor.h b/relay/include/relay/expr_functor.h index f29ca279b819c..c0a4d44d031aa 100644 --- a/relay/include/relay/expr_functor.h +++ b/relay/include/relay/expr_functor.h @@ -122,7 +122,7 @@ class ExprFunctor { Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const BinaryOpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const AssignmentNode* op, + virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ReverseNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -153,7 +153,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(CallNode); IR_EXPR_FUNCTOR_DISPATCH(UnaryOpNode); IR_EXPR_FUNCTOR_DISPATCH(BinaryOpNode); - IR_EXPR_FUNCTOR_DISPATCH(AssignmentNode); + IR_EXPR_FUNCTOR_DISPATCH(LetNode); IR_EXPR_FUNCTOR_DISPATCH(ReverseNode); IR_EXPR_FUNCTOR_DISPATCH(AccumulateNode); IR_EXPR_FUNCTOR_DISPATCH(ZeroNode); diff --git a/relay/include/relay/node.h b/relay/include/relay/node.h index 7e6b13db0d589..278a057e86cff 100644 --- a/relay/include/relay/node.h +++ b/relay/include/relay/node.h @@ -123,6 +123,8 @@ class FloatValueNode : public ValueNode { TVM_DEFINE_NODE_REF(FloatValue, FloatValueNode); + + // end move me /*! \brief Base type of the Relay type hiearchy. */ @@ -432,6 +434,8 @@ class GlobalIdNode : public ExprNode { TVM_DLL static GlobalId make(std::string name); + inline size_t hash() const { return std::hash{}(name); } + static constexpr const char* _type_key = "nnvm.GlobalId"; TVM_DECLARE_NODE_TYPE_INFO(GlobalIdNode, ExprNode); }; @@ -610,30 +614,32 @@ class BinaryOpNode : public ExprNode { TVM_DEFINE_NODE_REF(BinaryOp, BinaryOpNode) -class Assignment; +class Let; // TODO(jroesch) : make me contain the proper fields. -/*! \brief Assignment. */ -class AssignmentNode : public ExprNode { +/*! \brief A binding of a sub-network. */ +class LetNode : public ExprNode { public: - tvm::Map assignments; + LocalId id; + Expr value; Expr body; - AssignmentNode() {} + LetNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("assignments", &assignments); + v->Visit("id", &id); + v->Visit("value", &value); v->Visit("body", &body); } - TVM_DLL static Assignment make(NodeRef node); + TVM_DLL static Let make(LocalId id, Expr value, Expr body); - static constexpr const char* _type_key = "nnvm.Assignment"; - TVM_DECLARE_NODE_TYPE_INFO(AssignmentNode, ExprNode); + static constexpr const char* _type_key = "nnvm.Let"; + TVM_DECLARE_NODE_TYPE_INFO(LetNode, ExprNode); }; -TVM_DEFINE_NODE_REF(Assignment, AssignmentNode); +TVM_DEFINE_NODE_REF(Let, LetNode); class Primitive; @@ -781,6 +787,30 @@ class DefnNode : public ItemNode { TVM_DEFINE_NODE_REF(Defn, DefnNode); +// Move me too +class FnValue; + +/*! \brief A floating point value. */ +class FnValueNode : public ValueNode { + public: + tvm::Map env; + Function func; + + FnValueNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("env", &env); + v->Visit("func", &func); + } + + TVM_DLL static FnValue make(tvm::Map env, Function func); + + static constexpr const char* _type_key = "nnvm.FnValue"; + TVM_DECLARE_NODE_TYPE_INFO(FnValueNode, ValueNode); +}; + +TVM_DEFINE_NODE_REF(FnValue, FnValueNode); + } // namespace relay } // namespace nnvm @@ -788,16 +818,17 @@ namespace std { /*! \brief std hash function for GlobalId */ template <> -struct hash { +struct hash { /*! * \brief returns hash of global id. * \param id global id. * \return hash code. */ - size_t operator()(const nnvm::relay::GlobalId& id) const { - return hash{}(id->name); + size_t operator()(const nnvm::relay::LocalId& id) const { + return id.hash(); } }; + } // namespace std #endif // NNVM_RELAY_NODE_H_ diff --git a/relay/include/relay/pretty_printer.h b/relay/include/relay/pretty_printer.h index 2f1e7b98795ed..7c568237bae31 100644 --- a/relay/include/relay/pretty_printer.h +++ b/relay/include/relay/pretty_printer.h @@ -38,7 +38,7 @@ struct PrettyPrinter : public ExprFunctor { void VisitExpr_(const DebugNode * op, ostream & os) override; void VisitExpr_(const UnaryOpNode * op, ostream & os) override; void VisitExpr_(const BinaryOpNode * op, ostream & os) override; - void VisitExpr_(const AssignmentNode * op, ostream & os) override; + void VisitExpr_(const LetNode * op, ostream & os) override; }; } // namespace relay diff --git a/relay/include/relay/type_functor.h b/relay/include/relay/type_functor.h index 664331affd2ab..b75c6c8f6f67a 100644 --- a/relay/include/relay/type_functor.h +++ b/relay/include/relay/type_functor.h @@ -112,7 +112,7 @@ class TypeFunctor { virtual R VisitType_(const DebugNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitType_(const UnaryOpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitType_(const BinaryOpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitType_(const AssignmentNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitType_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitType_(const ReverseNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitType_(const AccumulateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitType_(const ZeroNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -140,7 +140,7 @@ class TypeFunctor { IR_EXPR_FUNCTOR_DISPATCH(CallNode); IR_EXPR_FUNCTOR_DISPATCH(UnaryOpNode); IR_EXPR_FUNCTOR_DISPATCH(BinaryOpNode); - IR_EXPR_FUNCTOR_DISPATCH(AssignmentNode); + IR_EXPR_FUNCTOR_DISPATCH(LetNode); IR_EXPR_FUNCTOR_DISPATCH(ReverseNode); IR_EXPR_FUNCTOR_DISPATCH(AccumulateNode); IR_EXPR_FUNCTOR_DISPATCH(ZeroNode); diff --git a/relay/include/relay/typechecker.h b/relay/include/relay/typechecker.h index 66354a7c50f90..ca7e402300b81 100644 --- a/relay/include/relay/typechecker.h +++ b/relay/include/relay/typechecker.h @@ -35,7 +35,7 @@ class Typechecker : public ExprFunctor { Type VisitExpr_(const DebugNode* op) override; Type VisitExpr_(const UnaryOpNode* op) override; Type VisitExpr_(const BinaryOpNode* op) override; - Type VisitExpr_(const AssignmentNode* op) override; + Type VisitExpr_(const LetNode* op) override; }; } // namespace relay diff --git a/relay/include/relay/visitor.h b/relay/include/relay/visitor.h index c05f433ae881c..f81f9ada2918c 100644 --- a/relay/include/relay/visitor.h +++ b/relay/include/relay/visitor.h @@ -169,12 +169,12 @@ class FunctorNode } } - NodeRef VisitExpr_(const AssignmentNode* bop, tvm::Array args) override { + NodeRef VisitExpr_(const LetNode* bop, tvm::Array args) override { throw "foo"; // if (visit_assignment != nullptr) { // return visit_assignment(intrinsic->name, args); // } else { - // return AssignmentNode::make(intrinsic->name); + // return LetNode::make(intrinsic->name); // } } diff --git a/relay/python/relay/datasets.py b/relay/python/relay/datasets.py new file mode 100644 index 0000000000000..cb28a9ce109e1 --- /dev/null +++ b/relay/python/relay/datasets.py @@ -0,0 +1,107 @@ +#pylint: disable-all +# We don't need linting this code is all temporary and will probably be moved/changed. +"""A module borrowed from tinyflow for importing some sample datasets.""" +import numpy as np +from collections import namedtuple +from sklearn.datasets import fetch_mldata +import pickle +import sys +import os +from subprocess import call + +class ArrayPacker(object): + """Dataset packer for iterator""" + def __init__(self, X, Y): + self.images = X + self.labels = Y + self.ptr = 0 + + def next_batch(self, batch_size): + if self.ptr + batch_size >= self.labels.shape[0]: + self.ptr = 0 + X = self.images[self.ptr:self.ptr+batch_size] + Y = self.labels[self.ptr:self.ptr+batch_size] + self.ptr += batch_size + return X, Y + +MNISTData = namedtuple("MNISTData", ["train", "test"]) + +def get_mnist(flatten=False, onehot=False): + mnist = fetch_mldata('MNIST original') + np.random.seed(1234) # set seed for deterministic ordering + p = np.random.permutation(mnist.data.shape[0]) + X = mnist.data[p] + Y = mnist.target[p] + X = X.astype(np.float32) / 255.0 + if flatten: + X = X.reshape((70000, 28 * 28)) + else: + X = X.reshape((70000, 1, 28, 28)) + if onehot: + onehot = np.zeros((Y.shape[0], 10)) + onehot[np.arange(Y.shape[0]), Y.astype(np.int32)] = 1 + Y = onehot + X_train = X[:60000] + Y_train = Y[:60000] + X_test = X[60000:] + Y_test = Y[60000:] + return MNISTData(train=ArrayPacker(X_train, Y_train), + test=ArrayPacker(X_test, Y_test)) + + +CIFAR10Data = namedtuple("CIFAR10Data", ["train", "test"]) + +def load_batch(fpath, label_key='labels'): + f = open(fpath, 'rb') + if sys.version_info < (3,): + d = cPickle.load(f) + else: + d = cPickle.load(f, encoding="bytes") + # decode utf8 + for k, v in d.items(): + del(d[k]) + d[k.decode("utf8")] = v + f.close() + data = d["data"] + labels = d[label_key] + + data = data.reshape(data.shape[0], 3, 32, 32).astype(np.float32) + labels = np.array(labels, dtype="float32") + return data, labels + + +def get_cifar10(swap_axes=False): + path = "cifar-10-batches-py" + if not os.path.exists(path): + tar_file = "cifar-10-python.tar.gz" + origin = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + if os.path.exists(tar_file): + need_download = False + else: + need_download = True + if need_download: + call(["wget", origin]) + call(["tar", "-xvf", "cifar-10-python.tar.gz"]) + else: + call(["tar", "-xvf", "cifar-10-python.tar.gz"]) + + nb_train_samples = 50000 + + X_train = np.zeros((nb_train_samples, 3, 32, 32), dtype="float32") + y_train = np.zeros((nb_train_samples,), dtype="float32") + + for i in range(1, 6): + fpath = os.path.join(path, 'data_batch_' + str(i)) + data, labels = load_batch(fpath) + X_train[(i - 1) * 10000: i * 10000, :, :, :] = data + y_train[(i - 1) * 10000: i * 10000] = labels + + fpath = os.path.join(path, 'test_batch') + X_test, y_test = load_batch(fpath) + + if swap_axes: + X_train = np.swapaxes(X_train, 1, 3) + X_test = np.swapaxes(X_test, 1, 3) + + return CIFAR10Data(train=ArrayPacker(X_train, y_train), + test=ArrayPacker(X_test, y_test)) diff --git a/relay/python/relay/expr.py b/relay/python/relay/expr.py index 281872b7870cd..ebfdc1c52fb81 100644 --- a/relay/python/relay/expr.py +++ b/relay/python/relay/expr.py @@ -1,16 +1,42 @@ +#pylint: disable=no-else-return, unidiomatic-typecheck """All the expression nodes""" from enum import IntEnum +from typing import Union from .base import register_nnvm_node, NodeBase from . import make +# TODO(@jroesch): Add meta-programming dunder methods to a new Base class + + @register_nnvm_node class Environment(NodeBase): - pass + """Global Environment + """ + + def add(self, func: "GlobalId") -> None: + return make.Environment_add(self, func) + + def global_id(self, name: str) -> "GlobalId": + return make.Environment_global_id(self, name) + + def lookup(self, ident: Union["GlobalId", str]) -> "Item": + if isinstance(ident, str): + return make.Environment_lookup_str(self, ident) + else: + return make.Environment_lookup(self, ident) + + def ilookup(self, _: "InstrinsicId") -> None: + assert False + class Item(NodeBase): """Base class of all expressions. """ + def __eq__(self, other): + if type(self) != type(other): + return False + for field in dir(self): if self.__getattr__(field) == other.__getattr__(field): pass @@ -19,10 +45,15 @@ def __eq__(self, other): return True + class Type(NodeBase): """Base class of all types. """ + def __eq__(self, other): + if type(self) != type(other): + return False + for field in dir(self): if self.__getattr__(field) == other.__getattr__(field): pass @@ -31,8 +62,10 @@ def __eq__(self, other): return True + class Builder(object): """A helper class for building partial AST fragments.""" + def __init__(self): pass @@ -75,10 +108,15 @@ def __neg__(self): def cast(self, target): return make.Cast(target, self) + class Expr(NodeBase, Builder): """Base class of all expressions. """ + def __eq__(self, other): + if type(self) != type(other): + return False + for field in dir(self): if self.__getattr__(field) == other.__getattr__(field): pass @@ -87,6 +125,7 @@ def __eq__(self, other): return True + class Value(NodeBase): """Base class of all expressions. """ @@ -100,82 +139,104 @@ def __eq__(self, other): return True + @register_nnvm_node class Defn(Item): pass + @register_nnvm_node class Primitive(Item): pass + @register_nnvm_node class String(Expr): pass + @register_nnvm_node class IntLit(Expr): pass + @register_nnvm_node class FloatLit(Expr): pass + @register_nnvm_node class BoolLit(Expr): pass + @register_nnvm_node class TensorLit(Expr): pass + @register_nnvm_node class ProductLit(Expr): pass + @register_nnvm_node class IntType(Type): pass + @register_nnvm_node class Cast(Expr): pass + @register_nnvm_node class LocalId(Expr): - pass + def __hash__(self) -> int: + return hash(self.name) + @register_nnvm_node class GlobalId(Expr): pass + @register_nnvm_node class IntrinsicId(Expr): pass + @register_nnvm_node class Param(NodeBase): pass + @register_nnvm_node class Function(Expr): pass + @register_nnvm_node class Call(Expr): pass + @register_nnvm_node class Debug(Expr): pass + class UOp(IntEnum): NEG = 0 + @register_nnvm_node class UnaryOp(Expr): pass #pylint: disable=invalid-name + + class BOp(IntEnum): """The set of builtin binary ops supported by Relay.""" PLUS = 0 @@ -189,27 +250,51 @@ class BOp(IntEnum): LE = 8 GE = 9 + @register_nnvm_node class BinaryOp(Expr): pass + +@register_nnvm_node +class Let(Expr): + pass + + @register_nnvm_node class Functor(Expr): pass + @register_nnvm_node class Reverse(Expr): pass + @register_nnvm_node class Accumulate(Expr): pass + @register_nnvm_node class Zero(Expr): pass # MOve me? + + @register_nnvm_node class IntValue(Value): pass + +@register_nnvm_node +class FloatValue(Value): + pass + +@register_nnvm_node +class BoolValue(Value): + pass + +@register_nnvm_node +class FnValue(Value): + pass diff --git a/relay/python/relay/relay.py b/relay/python/relay/relay.py index 64a1b90385ec8..89587e1023501 100644 --- a/relay/python/relay/relay.py +++ b/relay/python/relay/relay.py @@ -2,7 +2,10 @@ """A decorator for rewriting Python code into Relay.""" import ast import inspect -#pylint: disable=wildcard-import +# from typing import Dict, List +from collections import OrderedDict +# pylint: disable=wildcard-import +import nnvm.relay.eval as re from .make import * # This contains a global environment of all items. @@ -40,14 +43,15 @@ def compile_args_to_params(args): class DefToRelay(ast.NodeVisitor): """Compile a single Python def to a Realy definition.""" + # local_scopes: List[Dict[LocalId, Expr]] def __init__(self, python_def): + self.local_scopes = [] self.python_def = python_def #pylint: disable=invalid-name def visit_Name(self, name_node): - ident = name_node.id - return LocalId(ident) + return self.translate_ident(name_node) #pylint: disable=invalid-name def visit_Return(self, return_node): @@ -57,11 +61,50 @@ def visit_Return(self, return_node): else: raise Exception("return must have a value") + def visit_Call(self, call_node): + """Transform a Python call into a Relay call""" + func = call_node.func + # args = call_node.args + # keywords = call_node.keywords + if isinstance(func, ast.Attribute): + if func.value.id == 'relay': + relay_func = IntrinsicId(func.attr) + else: + raise Exception( + "only supported namespace is relay right now") # improve me + else: + raise Exception("unsupported calls") + # Todo(jroesch): Handle args + return Call(relay_func, []) + + def visit_Assign(self, assign_node): + targets = assign_node.targets + value = assign_node.value + assert len(targets) == 1 + ident = self.translate_ident(targets[0]) + rhs = self.visit(value) + self.local_scopes[-1][ident] = rhs + + # Need to put types some-where def compile_stmt_seq_to_body(self, stmts): - x = self.visit(stmts[0]) - #assert x is not None - #todo(M.K.) somehow it is returning null. very bad. - return x + """Compile a sequence of statements into a Relay expression.""" + assert stmts + ret = stmts[-1] + stmts = stmts[0:-1] + self.local_scopes.append(OrderedDict([])) + for stmt in stmts: + self.visit(stmt) + + cont = self.visit(ret) + scope = self.local_scopes.pop() + for key in reversed(scope): + value = scope[key] + cont = Let(key, value, cont) + return cont + + def translate_ident(self, ident): + # import pdb; pdb.set_trace() + return LocalId(ident.id) def run(self): """executes visitor""" @@ -70,15 +113,13 @@ def run(self): args = func.args body = func.body params = compile_args_to_params(args) - print(args) - print(body) relay_body = self.compile_stmt_seq_to_body(body) func = Function(params, relay_body) defunc = Defn(GlobalId(name), None, func) return defunc -def compile_def_to_defunc(func): +def compile_def_to_defn(func): def_to_relay = DefToRelay(func) return def_to_relay.run() @@ -86,7 +127,16 @@ def compile_def_to_defunc(func): def get_env(): return __relay_environment__ -def relay(f): +def relay_compile(f): + """Compile the Python function to a Relay function.""" + source = inspect.getsource(f) + mod = ast.parse(source) + mod_body = mod.body + assert len(mod_body) == 1 # We only handle one function at a time. + func = mod_body[0] + return compile_def_to_defn(func) + +def relay(func): """ This should be the implementation of the decorator eventually. @@ -102,22 +152,21 @@ def relay(f): Adds to global environment. """ - # Returns handle to function. - # unify with C++ code path - global __relay_environment__ - source = inspect.getsource(f) - mod = ast.parse(source) - mod_body = mod.body - assert len(mod_body) == 1 # We only handle one function at a time. - func = mod_body[0] - name = GlobalId(func.name) - defunc = compile_def_to_defunc(func) - Environment_add(__relay_environment__, defunc) - print(__relay_environment__.items) - return name # for the time being add thing to environment, and then - # Eventually - # def wrapper(args): - # relay_args = args - # this might require full compiler run - # and dynamic linking - # return eval(Call(name, relay_args)) + defn = relay_compile(func) + get_env().add(defn) + + def wrapper(*_): + return re.eval(get_env(), Call(defn.id, [])) + + return wrapper + + +# Store Python line and columb information for errors +# +# Make it possible to handle calls with keywords arguments +# +# Ensure names prefixed with relay. becomes a Intrinsic +# +# Handle assignments +# +# Ensure the Environment is updated with the function. diff --git a/relay/src/relay/environment.cc b/relay/src/relay/environment.cc index e94af7c576d90..7dd6bfaa64ba7 100644 --- a/relay/src/relay/environment.cc +++ b/relay/src/relay/environment.cc @@ -4,6 +4,7 @@ * \brief Relay node data structure. */ #include +#include namespace nnvm { namespace relay { @@ -11,24 +12,48 @@ namespace relay { using tvm::IRPrinter; using namespace tvm::runtime; +struct EnvError : dmlc::Error { + explicit EnvError(const std::string & msg) : dmlc::Error(msg) {} +}; + Environment EnvironmentNode::make(tvm::Map items) { std::shared_ptr n = std::make_shared(); n->items = std::move(items); return Environment(n); } +void EnvironmentNode::add_global(const std::string & str, GlobalId id) { + global_map_[str] = id; +} + +GlobalId EnvironmentNode::global_id(const std::string & str) { + if (global_map_.find(str) != global_map_.end()) { + return global_map_.at(str); + } else { + GlobalId id = GlobalIdNode::make(str); + this->add_global(str, id); + return id; + } +} + void EnvironmentNode::add(const Item &item) { - this->table.insert({item->id->name, item->id}); + // Should we first check to see if any Global of this name + // been allocated and even disallow duplicate hints? + add_global(item->id->name, item->id); this->items.Set(item->id, item); } -dmlc::optional EnvironmentNode::lookup(const GlobalIdNode *id) { - auto nit = this->table.find(id->name); - if (nit == this->table.end()) { - return dmlc::optional(); - } - auto it = this->items.find((*nit).second); - return it == this->items.end() ? dmlc::optional() : dmlc::optional((*it).second); +Item EnvironmentNode::lookup(const GlobalId & id) { + if (items.find(id) != items.end()) { + return items.at(id); + } else { + throw EnvError("there is no definition of " + id->name); + } +} + +Item EnvironmentNode::lookup(const std::string & str) { + GlobalId id = this->global_id(str); + return this->lookup(id); } TVM_REGISTER_API("nnvm.make.Environment") @@ -44,10 +69,31 @@ TVM_REGISTER_API("nnvm.make.Environment_add") env->add(item); }); +TVM_REGISTER_API("nnvm.make.Environment_lookup") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + GlobalId id = args[1]; + *ret = env->lookup(id); + }); + +TVM_REGISTER_API("nnvm.make.Environment_lookup_str") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string str = args[1]; + *ret = env->lookup(str); + }); + + +TVM_REGISTER_API("nnvm.make.Environment_global_id") + .set_body([](TVMArgs args, TVMRetValue *ret) { + Environment env = args[0]; + std::string str = args[1]; + *ret = env->global_id(str); + }); + TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) - .set_dispatch([](const EnvironmentNode *node, - tvm::IRPrinter *p) { - p->stream << "EnvironmentNode(TODO)"; + .set_dispatch([](const EnvironmentNode *node, tvm::IRPrinter *p) { + p->stream << "EnvironmentNode(" << node->items << ")"; }); } // namespace relay diff --git a/relay/src/relay/evaluator.cc b/relay/src/relay/evaluator.cc index a6810f971d04b..a4bad15674f35 100644 --- a/relay/src/relay/evaluator.cc +++ b/relay/src/relay/evaluator.cc @@ -23,17 +23,18 @@ Evaluator::Evaluator() : env() {} Value Evaluator::Eval(const Expr& expr) { return this->operator()(expr); } -Value Evaluator::VisitExpr_(const LocalIdNode* local) { +Value Evaluator::VisitExpr_(const LocalIdNode* local_node) { // We should instead compile this to a stack machine with statically resolved // offsets. - for (auto frame = this->stack.rbegin(); frame != this->stack.rend(); - frame++) { - if (frame->find(local->name) != frame->end()) { - return frame->at(local->name); - } else { - continue; - } - } + // LocalId local = LocalId(local_node->GetNodeRef().node_); + // for (auto frame = this->stack.rbegin(); frame != this->stack.rend(); + // frame++) { + // if (frame->find(local) != frame->end()) { + // return frame->at(local); + // } else { + // continue; + // } + // } // If we get here this is a bug, we are referencing a variable that doesn't // exist. @@ -41,15 +42,30 @@ Value Evaluator::VisitExpr_(const LocalIdNode* local) { } Value Evaluator::VisitExpr_(const GlobalIdNode* op) { - if (auto global = this->env->lookup(op)) { - Item i = *global; - if (const DefnNode* def = i.as()) { - return this->VisitExpr(def->body); - } else { - throw EvalError("unknown global id"); - } + // This is not memory safe ... + // + // There are three issues with this current design of visitor, because we receive the unboxed + // node we can not directly index the map by the NodeRef, because we no longer have handle + // to the NodeRef. + // + // We can box the Node back into a shared pointer, and then into a NodeRef, + // but this is very memory unsafe in the long run, + // because the original shared pointer's lifetime may be greater then + // the lifetime of the shared pointer we build inside the function. + // + // This is always memory unsafe if the new shared pointer escapes the frame, because + // its near guaranteed that the two diferent shared pointers reference counts will + // not decmement in lockstep, leading to a dangling pointer in one of them. + // + // Finally we don't want to directly store the pointers as keys either because it + // fails to capture the desired semantics of the identifier, not the location being + // semantically meaningful. + GlobalId id = GlobalIdNode::make(op->name); + Item i = this->env->lookup(GlobalId(id)); + if (const DefnNode* def = i.as()) { + return this->VisitExpr(def->body); } else { - throw EvalError("unknown global value"); + throw EvalError("unknown global id"); } } @@ -90,6 +106,7 @@ Value Evaluator::VisitExpr_(const FunctionNode* op) { } Value Evaluator::VisitExpr_(const CallNode* op) { + // auto fn = this->VisitExpr(op->fn); throw EvalError("Call NYI"); } @@ -127,7 +144,7 @@ Value Evaluator::VisitExpr_(const BinaryOpNode* op) { } } -Value Evaluator::VisitExpr_(const AssignmentNode* op) { +Value Evaluator::VisitExpr_(const LetNode* op) { throw EvalError("assignment node"); // fix me } diff --git a/relay/src/relay/node.cc b/relay/src/relay/node.cc index 333b18499e645..f1de15806216a 100644 --- a/relay/src/relay/node.cc +++ b/relay/src/relay/node.cc @@ -425,6 +425,25 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "ZeroNode(" << node->type << ")"; }); +Let LetNode::make(LocalId id, Expr value, Expr body) { + std::shared_ptr n = std::make_shared(); + n->id = std::move(id); + n->value = std::move(value); + n->body = std::move(body); + return Let(n); +} + +TVM_REGISTER_API("nnvm.make.Let") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LetNode::make(args[0], args[1], args[2]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const LetNode *node, + tvm::IRPrinter *p) { + // TODO(jroesch): p->stream << "LetNode(" << node->id << ", " << node->type << ")"; + }); + Primitive PrimitiveNode::make(GlobalId id, Type type) { std::shared_ptr n = std::make_shared(); n->id = id; @@ -511,5 +530,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "FloatValueNode(" << node->value << ")"; }); +FnValue FnValueNode::make(tvm::Map env, Function func) { + std::shared_ptr n = std::make_shared(); + n->env = std::move(env); + n->func = std::move(func); + return FnValue(n); +} + +TVM_REGISTER_API("nnvm.make.FnValue") + .set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FnValueNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const FnValueNode *node, + tvm::IRPrinter *p) { + p->stream << "FnValueNode(todo)"; + }); + } // namespace relay } // namespace nnvm diff --git a/relay/src/relay/pretty_printer.cc b/relay/src/relay/pretty_printer.cc index 873e4fd0309a2..ddd49db4392d9 100644 --- a/relay/src/relay/pretty_printer.cc +++ b/relay/src/relay/pretty_printer.cc @@ -30,15 +30,11 @@ void PrettyPrinter::VisitExpr_(const LocalIdNode * local, ostream & os) { } void PrettyPrinter::VisitExpr_(const GlobalIdNode * op, ostream & os) { - if (auto global = this->env->lookup(op)) { - Item i = *global; - if (const DefnNode* def = i.as()) { - this->PrettyPrint(def->body, os); - } else { - throw PrintError("unknown global id"); - } + Item i = this->env->lookup(op->name); + if (const DefnNode* def = i.as()) { + this->PrettyPrint(def->body, os); } else { - throw PrintError("unknown global value"); + throw PrintError("unknown global id"); } } @@ -119,7 +115,7 @@ void PrettyPrinter::VisitExpr_(const BinaryOpNode * op, ostream & os) { } } -void PrettyPrinter::VisitExpr_(const AssignmentNode * op, ostream & os) { +void PrettyPrinter::VisitExpr_(const LetNode * op, ostream & os) { throw PrintError("Assignment NYI"); } diff --git a/relay/src/relay/typechecker.cc b/relay/src/relay/typechecker.cc index 00833b4273b1e..82e4ee9a2ff42 100644 --- a/relay/src/relay/typechecker.cc +++ b/relay/src/relay/typechecker.cc @@ -84,8 +84,8 @@ Type Typechecker::VisitExpr_(const BinaryOpNode *op) { throw TypecheckerError("BinaryOpNode not implmented"); } -Type Typechecker::VisitExpr_(const AssignmentNode *op) { - throw TypecheckerError("AssignmentNode not implemented"); +Type Typechecker::VisitExpr_(const LetNode *op) { + throw TypecheckerError("LetNode not implemented"); } TVM_REGISTER_API("nnvm.tyck.check")