diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 4ae35f585c6f..b7621e20cf6a 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -22,8 +22,15 @@ namespace tvm { * You can find more about Relay by reading the language reference. */ namespace relay { + +#define RELAY_DEBUG(...) \ +{ auto fdebug = runtime::Registry::Get("relay.debug"); \ + CHECK(fdebug) << "Could not find Relay Python debugger function."; \ + (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ +} + /*! - * \brief we always used NodeRef for referencing nodes. + * \brief We always used NodeRef for referencing nodes. * * By default, NodeRef is a std::shared_ptr of node */ diff --git a/include/tvm/relay/build_module.h b/include/tvm/relay/build_module.h new file mode 100644 index 000000000000..ed889eba0bd0 --- /dev/null +++ b/include/tvm/relay/build_module.h @@ -0,0 +1,76 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/build_module.h + * \brief The passes and data structures needed to build a + * tvm::Module from a Relay program. + */ +#ifndef TVM_RELAY_BUILD_MODULE_H_ +#define TVM_RELAY_BUILD_MODULE_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief A lowered Relay operation. + * + * A lowered operation is a pair containing the "primitive" function used + * to produce the lowered function as well as the lowered function itself. + */ +class LoweredOp; +/*! \brief Call container. */ +class LoweredOpNode : public Node { + public: + /*! + * \brief The primitive function to be lowered. + * + * A primitive function consists only of calls to relay::Op which + * can be fused. + */ + Function func; + + /*! + * \brief The lowered function. + */ + LoweredFunc lowered_func; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("func", &func); + v->Visit("lowered_func", &lowered_func); + } + + TVM_DLL static LoweredOp make( + Function func, + LoweredFunc lowered_func); + + static constexpr const char* _type_key = "relay.LoweredOp"; + TVM_DECLARE_NODE_TYPE_INFO(LoweredOpNode, Node); +}; + +RELAY_DEFINE_NODE_REF(LoweredOp, LoweredOpNode, NodeRef); + +/*! + * \brief Lower the operations contained in a Relay expression. + * + * The lowering pass will only lower functions marked as primitive, + * the FuseOps pass will provide this behavior, if run before LowerOps. + * + * \note This will do a reachability analysis and lower all definitions + * reachable from the provided expression. + * + * \param env The environment. + * \param expr The expression with operations to be lowered. + * \param target The target to lower the functions to. + * + * \return The set of lowered operations. + */ +Array LowerOps(const Environment& env, const Expr& expr, + const std::string& target = "llvm"); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BUILD_MODULE_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 2e3bbadb7841..029470c067ce 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -213,12 +213,18 @@ class FunctionNode : public ExprNode { */ tvm::Array type_params; + /*! + * \brief The attributes which store metadata about functions. + */ + tvm::Attrs attrs; + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("params", ¶ms); v->Visit("body", &body); v->Visit("ret_type", &ret_type); v->Visit("type_params", &type_params); v->Visit("span", &span); + v->Visit("attrs", &attrs); v->Visit("_checked_type_", &checked_type_); } @@ -233,7 +239,8 @@ class FunctionNode : public ExprNode { TVM_DLL static Function make(tvm::Array params, Expr body, Type ret_type, - tvm::Array ty_params); + tvm::Array ty_params, + tvm::Attrs attrs = Attrs()); static constexpr const char* _type_key = "relay.Function"; TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode); @@ -241,6 +248,11 @@ class FunctionNode : public ExprNode { RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr); + +TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key); +TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data); + + /*! * \brief Call corresponds to operator invocation. * Corresponds to the operator in computational graph terminology. diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h new file mode 100644 index 000000000000..1c382faaef04 --- /dev/null +++ b/include/tvm/relay/interpreter.h @@ -0,0 +1,140 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file tvm/relay/interpreter.h + * \brief An interpreter for Relay. + * + * This file implements a simple reference interpreter for Relay programs. + * Given a Relay environment, and a Relay expression it produces a value. + * + * The interpreter's values are a naive representation of the values that + * can be produced by a Relay program and are exposed via tvm::Node's + * system to Python for introspection and debugging. + * + * The interpreter's intent is to serve as a reference semantics for the Relay IR, + * as well as for debugging and testing. + */ +#ifndef TVM_RELAY_INTERPRETER_H_ +#define TVM_RELAY_INTERPRETER_H_ + +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief A Relay value. + */ +class Value; + +/*! \brief Evaluate an expression using the interpreter producing a value. + * + * The resulting value can be passed to Python, making it easy to use + * for testing and debugging. + * + * The interpreter interprets the program fragments not supported by the + * TVM runtime, although the interpreter is naively implemented it uses + * TVM operators for evaluating all operators. + * + * Our intent is that this will never be the most efficient implementation of + * Relay's semantics, but a readable and clear one. + */ +Value Evaluate(Environment env, Expr e); + +/*! \brief The base container type of Relay values. */ +class ValueNode : public RelayNode { + public: + static constexpr const char* _type_key = "relay.Value"; + TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); +}; + +class Value : public NodeRef { + public: + Value() {} + explicit Value(NodePtr n) : NodeRef(n) {} + const ValueNode* operator->() const { + return static_cast(node_.get()); + } + + using ContainerType = ValueNode; +}; + +/*! \brief A Relay closure, i.e a scope and a function. */ +class Closure; + +/*! \brief The container type of Closures. */ +class ClosureNode : public ValueNode { + public: + /*! \brief The set of free variables in the closure. + * + * These are the captured variables which are required for + * evaluation when we call the closure. + */ + tvm::Map env; + /*! \brief The function which implements the closure. + * + * \note May reference the variables contained in the env. + */ + Function func; + + ClosureNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("env", &env); + v->Visit("func", &func); + } + + TVM_DLL static Closure make(tvm::Map env, Function func); + + static constexpr const char* _type_key = "relay.Closure"; + TVM_DECLARE_NODE_TYPE_INFO(ClosureNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value); + +/*! \brief A tuple value. */ +class TupleValue; + +/*! \brief Tuple (x, ... y). */ +struct TupleValueNode : ValueNode { + tvm::Array fields; + + TupleValueNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); } + + TVM_DLL static TupleValue make(tvm::Array value); + + static constexpr const char* _type_key = "relay.TupleValue"; + TVM_DECLARE_NODE_TYPE_INFO(TupleValueNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(TupleValue, TupleValueNode, Value); + +/*! \brief A tensor value. */ +class TensorValue; + +/*! \brief The tensor value container, wrapping an NDArray. */ +struct TensorValueNode : ValueNode { + runtime::NDArray data; + + TensorValueNode() {} + + void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("data", &data); } + + /*! \brief Build a value from an NDArray. */ + TVM_DLL static TensorValue make(runtime::NDArray data); + + /*! \brief Construct an empty tensor value from t. */ + TVM_DLL static TensorValue FromType(const Type& t); + + static constexpr const char* _type_key = "relay.TensorValue"; + TVM_DECLARE_NODE_TYPE_INFO(TensorValueNode, ValueNode); +}; + +RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); + + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_INTERPRETER_H_ diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index bf16c7ed8e33..b29678106d21 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -8,6 +8,7 @@ #include #include +#include namespace tvm { namespace relay { @@ -20,7 +21,8 @@ namespace relay { * populated with the result type. * * \param expr The expression to type check. - * \param env The environment used for referencing global functions, can be None. + * \param env The environment used for referencing global functions, can be + * None. * * \return A type checked expression with its checked_type field populated. */ @@ -35,7 +37,8 @@ Expr InferType(const Expr& expr, const Environment& env); * \return A type checked Function with its checked_type field populated. * \note this function mutates env and is not thread-safe. */ -Function InferType(const Function& f, const Environment& env, const GlobalVar& var); +Function InferType(const Function& f, const Environment& env, + const GlobalVar& var); /*! * \brief Check that types are well kinded by applying "kinding rules". @@ -94,28 +97,30 @@ bool AlphaEqual(const Type& t1, const Type& t2); * * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice. * - * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, although x is not shadowed. + * `let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, + * although x is not shadowed. * - * \param e the expression to check. + * \param expr the expression to check. * - * \return true iff all Var in e is bound at most once. + * \return true iff all Var in expr is bound at most once. */ -bool WellFormed(const Expr& e); +bool WellFormed(const Expr& expr); -/*! \brief Get free Vars from expr in PostDFS order. +/*! \brief Get free type parameters from expression expr. * * Free variables are variables that are not bound by a * let or a function parameter in the context. * * \param expr the expression. * - * \return List of free vars, in the PostDFS order visited by expr. + * \return List of free vars, in the PostDFS order in the expression. */ tvm::Array FreeVars(const Expr& expr); /*! \brief Get free TypeVars from expression expr. * - * Free type parameters are type parameters that are not bound by a function type in the context. + * Free type parameters are type parameters that are not bound by a function + * type in the context. * * \param expr the expression. * @@ -125,10 +130,12 @@ tvm::Array FreeTypeVars(const Expr& expr); /*! \brief Remove expressions which does not effect the program result. * - * It will remove let binding that are not referenced, and if branch that are not entered. + * It will remove let bindings which are not referenced, and branches that will + * not be entered. * - * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a. - * Another example is `if (true) then 1 else 2` will be optimized into 1. + * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of + * the expression does not depend on a. Another example is `if (true) then 1 + * else 2` will be optimized into 1. * * \param e the expression to optimize. * @@ -136,27 +143,30 @@ tvm::Array FreeTypeVars(const Expr& expr); */ Expr DeadCodeElimination(const Expr& e); -/*! \brief Hash a Relay type. - * - * Implements structural hashing of a Relay type. - * - * \param type the type to hash. - * - * \return the hash value. - */ -size_t StructuralHash(const Type& type); - -/*! \brief Hash a Relay expression. - * - * Implements structural hashing of a Relay expression. - * - * \param expr the expression to hash. - * - * \return the hash value. - */ -size_t StructuralHash(const Expr& expr); +/*! \brief A hashing structure in the style of std::hash. */ +struct StructuralHash { + /*! \brief Hash a Relay type. + * + * Implements structural hashing of a Relay type. + * + * \param type the type to hash. + * + * \return the hash value. + */ + size_t operator()(const Type& type) const; + /*! \brief Hash a Relay expression. + * + * Implements structural hashing of a Relay expression. + * + * \param expr the expression to hash. + * + * \return the hash value. + */ + size_t operator()(const Expr& expr) const; +}; } // namespace relay } // namespace tvm + #endif // TVM_RELAY_PASS_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 731a816460ee..d3b60c1174fa 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,5 +1,7 @@ # pylint: disable=wildcard-import, redefined-builtin, invalid-name """The Relay IR namespace containing the IR definition and compiler.""" +from __future__ import absolute_import +from ..api import register_func from . import base from . import ty from . import expr @@ -15,6 +17,7 @@ from . import vision from . import image + from .scope_builder import ScopeBuilder # Span @@ -46,6 +49,21 @@ If = expr.If TupleGetItem = expr.TupleGetItem + # helper functions var = expr.var const = expr.const + +@register_func("relay._tensor_value_repr") +def _tensor_value_repr(tv): + return str(tv.data.asnumpy()) + +@register_func("relay._constant_repr") +def _tensor_constant_repr(tv): + return str(tv.data.asnumpy()) + +# pylint: disable=unused-argument +@register_func("relay.debug") +def _debug(*args): + import pdb + pdb.set_trace() diff --git a/python/tvm/relay/_interpreter.py b/python/tvm/relay/_interpreter.py new file mode 100644 index 000000000000..d04319c17a99 --- /dev/null +++ b/python/tvm/relay/_interpreter.py @@ -0,0 +1,4 @@ +"""The interface to the Evaluator exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay._interpreter", __name__) diff --git a/python/tvm/relay/env.py b/python/tvm/relay/env.py index 9c3241e18ef8..37e0999dce9e 100644 --- a/python/tvm/relay/env.py +++ b/python/tvm/relay/env.py @@ -45,9 +45,12 @@ def __setitem__(self, var, func): func: Function The function. """ + return self._add(var, func) + + def _add(self, var, func, update=False): if isinstance(var, _base.string_types): var = _expr.GlobalVar(var) - _env.Environment_Add(self, var, func) + return _env.Environment_Add(self, var, func, update) def __getitem__(self, var): """Lookup a global function by name or by variable. diff --git a/python/tvm/relay/graph_runtime_codegen.py b/python/tvm/relay/graph_runtime_codegen.py new file mode 100644 index 000000000000..d0ce239fa7fd --- /dev/null +++ b/python/tvm/relay/graph_runtime_codegen.py @@ -0,0 +1,551 @@ +""" +A compiler from a Relay expression to TVM's graph runtime. + +The compiler is built from a few pieces. + +First we define a compiler from a single Relay expression to the +graph langauge. We require the expression to be a function. +The function's parameters correpond to the placeholder/inputs +and model parameters found in the computation graph representation. +The body of the function represents the computation graph. + +The compiler's output is a program in the graph language, which is composed of +graph langauge is composed of Node, NodeRef, InputNode, OpNode. +This "little language" represents programs in TVM's graph format. + +To connect to the graph runtime, we use a printer that converts our graph format +into TVM's JSON format. The resulting string can be loaded by +contrib.graph_runtime or any other TVM runtime comptatible system. + +We expose this functionality in compile_to_tvm. +""" + +from __future__ import absolute_import +import json +import attr +from . import ir_pass +from .op import Op +from .expr import Var, Function, Call, If, GlobalVar, Constant, Let, Tuple +from ..build_module import build as tvm_build_module +from .. contrib import graph_runtime +from .ir_pass import infer_type +from .. import cpu + +class AbstractExprVisitor(object): + """A visitor over Expr in Python.""" + + def __init__(self): + self.memo_map = {} + + # pylint: disable=no-else-return + def visit(self, expr): + """Apply the visitor to an expression.""" + found = self.memo_map.get(expr) + if found: + return found + + if isinstance(expr, Function): + res = self.visit_function(expr) + elif isinstance(expr, Call): + res = self.visit_call(expr) + elif isinstance(expr, Let): + res = self.visit_let(expr) + elif isinstance(expr, Var): + res = self.visit_var(expr) + elif isinstance(expr, GlobalVar): + res = self.visit_global_var(expr) + elif isinstance(expr, If): + res = self.visit_if(expr) + elif isinstance(expr, Tuple): + res = self.visit_tuple(expr) + elif isinstance(expr, Constant): + res = self.visit_constant(expr) + else: + raise Exception("warning unhandled case: {0}".format(type(expr))) + + self.memo_map[expr] = res + return res + + def visit_function(self, _): + raise Exception("Abstract method please implement me.") + + def visit_let(self, _): + raise Exception("Abstract method please implement me.") + + def visit_call(self, _): + raise Exception("Abstract method please implement me.") + + def visit_var(self, _): + raise Exception("Abstract method please implement me.") + + def visit_type(self, typ): + return typ + + def visit_if(self, _): + raise Exception("Abstract method please implement me.") + + def visit_tuple(self, _): + raise Exception("Abstract method please implement me.") + + def visit_constant(self, _): + raise Exception("Abstract method please implement me.") + + def visit_global_var(self, _): + raise Exception("Abstract method please implement me.") + + +class ExprMutator(AbstractExprVisitor): + """A functional visitor over Expr in Python.""" + + def visit_function(self, fn): + new_body = self.visit(fn.body) + return Function( + list(fn.params), + fn.ret_type, new_body, + fn.type_params) + + def visit_let(self, let): + new_var = self.visit(let.var) + new_val = self.visit(let.value) + new_body = self.visit(let.body) + return Let(new_var, new_val, new_body) + + def visit_call(self, call): + new_fn = self.visit(call.op) + new_args = [self.visit(arg) for arg in call.args] + return Call(new_fn, new_args, call.attrs) + + def visit_var(self, var): + return var + + def visit_global_id(self, global_var): + return global_var + + def visit_if(self, ite): + return If( + self.visit(ite.guard), + self.visit(ite.true_b), + self.visit(ite.false_b)) + + def visit_tuple(self, tup): + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_constant(self, const): + return const + + +@attr.s +class NodeRef(object): + """A reference to a node, used for constructing the graph.""" + ident = attr.ib() + index = attr.ib(default=0) + version = attr.ib(default=0) + + def to_json(self): + return [self.ident, self.index, self.version] + + +@attr.s +class Node(object): + """The base class for nodes in the TVM runtime system graph input.""" + name = attr.ib() + attrs = attr.ib() + is_output = attr.ib() + + def to_json(self): + raise Exception("Abstract method, please implement me.") + + +@attr.s +class InputNode(Node): + """An input node in the TVM runtime system graph input.""" + name = attr.ib() + attrs = attr.ib() + is_output = attr.ib(default=False) + + def to_json(self): + return { + "op": "null", + "name": self.name, + "inputs": [] + } + + +@attr.s +class OpNode(Node): + """An operator node in the TVM runtime system's graph input.""" + op_name = attr.ib() + inputs = attr.ib() + op_attrs = attr.ib() + is_output = attr.ib(default=False) + + def to_json(self): + attrs = dict.copy(self.op_attrs) + # Extend ops with extra info. + attrs['func_name'] = self.op_name + # When do we flatten? + attrs['flatten_data'] = "0" + # Fix me! + attrs['num_inputs'] = str(len(self.inputs)) + attrs['num_outputs'] = "1" + + return { + "op": "tvm_op", + "name": self.name, + "attrs": attrs, + "inputs": self.inputs + } + + +def shape_to_json(shape): + return [sh.value for sh in shape] + + +def from_tensor(typ): + return (typ.dtype, shape_to_json(typ.shape)) + + +class GraphRuntimeCodegen(ExprMutator): + """The compiler from Relay to the TVM runtime system.""" + nodes = attr.ib() + id_map = attr.ib() + + def __init__(self, env): + ExprMutator.__init__(self) + self.nodes = [] + self.id_map = {} + self.env = env + + def add_node(self, node): + """ + Add a node to the graph. + + Parameters + ---------- + node: Node + The node to add to the graph. + + Returns + ------- + node_ref: NodeRef + A reference to the node. + + """ + self.nodes.append(node) + ident = len(self.nodes) - 1 + return NodeRef(ident) + + def add_binding(self, ident, ref): + """ + Add a identifier to node mapping. + + Parameters + ---------- + ident: relay.Var + The variable to map + + ref: NodeRef + The node the identifier points. + """ + self.id_map[ident] = ref + + def let_bind(self, ident, node): + """ + Let bind node to ident. + + Parameters + ---------- + ident: relay.Var + The variable to map. + + ref: NodeRef + The node the identifier points. + + Returns + ------- + ref: NodeRef + Return reference to the node. + """ + ref = self.add_node(node) + self.add_binding(ident, ref) + return ref + + def get_node(self, ref): + """ + Lookup a node by a node reference. + + Parameters + ---------- + ref: NodeRef + The reference to lookup. + + Returns + ------- + node: Node + The node. + """ + return self.nodes[ref.ident] + + def lookup(self, ident): + """ + Lookup a node by identifier. + + Parameters + ---------- + ident: relay.Var + The reference to lookup. + + Returns + ------- + node: Node + The node. + """ + return self.id_map[ident] + + def codegen(self, func): + """Compile a single function into a graph. + + Parameters + ---------- + func: tvm.relay.Expr + The function to compile. + """ + # First we convert all the parameters into input nodes. + params = func.params + + for param in params: + dtype, shape = from_tensor(param.type_annotation) + node = InputNode("{0}".format(param.name_hint), { + "shape": shape, + "dtype": dtype, + }) + self.let_bind(param, node) + + # Then we compile the body into a graph which can depend + # on input variables. + output_ref = self.visit(func.body) + + # Finally we retreive return value of program, which will + # become our output node. + self.get_node(output_ref).is_output = True + + def visit_let(self, let): + """ + Visit the let binding, by first traversing its value, + then setting the metadata on the returned NodeRef. + + Finally visit the body, and return the NodeRef corresponding + to it. + + Parameters + ---------- + let: tvm.relay.Expr + The let binding to transform. + + Returns + ------- + ref: NodeRef + The node reference to the body. + """ + ident = let.var + val = let.value + body = let.body + + val_ref = self.visit(val) + dtype, shape = from_tensor(val.checked_type()) + val_node = self.get_node(val_ref) + val_node.attrs["dtype"] = dtype + val_node.attrs["shape"] = shape + self.add_binding(ident, val_ref) + return self.visit(body) + + def visit_var(self, var): + return self.lookup(var) + + def visit_call(self, call): + """Transform a ::tvm.relay.Call into an operator in the TVM graph.""" + inputs = [] + for arg in call.args: + inputs.append(self.visit(arg).to_json()) + + if isinstance(call.op, Op): + raise Exception( + "Operators should be transformed away; try applying" + + "the fuse_ops transformation to the expression.") + elif isinstance(call.op, GlobalVar): + func = self.env[call.op] + elif isinstance(call.op, Function): + func = call.op + else: + raise Exception( + "TVM runtime does not support calls to {0}".format(type(call.op))) + + if int(func.attrs.Primitive) != 1: + raise Exception( + "TVM only support calls to primitive functions " + + "(i.e functions composed of fusable operator invocations)") + + op_name = func.attrs.LoweredFunc.name + + attrs = {'shape': shape_to_json(call.checked_type.shape), + 'dtype': call.checked_type.dtype} + call_hash = str(ir_pass.structural_hash(call)) + op_node = OpNode("call_" + call_hash, attrs, op_name, inputs, {}) + return self.add_node(op_node) + + def to_json(self): + """ + Convert the sequence of nodes stored by the compiler into the + TVM graph runtime format. + + Returns + ------- + graph_json : str + The generated JSON as a string. + """ + nodes = [] + # First we compute "nodes" field. + for node in self.nodes: + nodes.append(node.to_json()) + + arg_nodes = [] + heads = [] + # Compute "arg_nodes" and "heads" fields. + for i, node in enumerate(self.nodes): + if isinstance(node, InputNode): + arg_nodes.append(i) + + if node.is_output: + # Need to fix this. + heads.append(NodeRef(i).to_json()) + + def compute_node_row_ptr(nodes): + """Calculate the node_row_ptr field by doing a DFS backwards + from the output and reversing the path. + """ + row_ptr = [len(nodes)] + discovered = set() + stack = [] + stack.append(len(nodes) - 1) + while stack: + i = stack.pop() + if i not in discovered: + discovered.add(i) + row_ptr.append(i) + node = nodes[i] + if isinstance(node, OpNode): + for inp in node.inputs: + stack.append(inp[0]) + row_ptr.reverse() + return row_ptr + + # Compute "node_row_ptr". + node_row_ptr = compute_node_row_ptr(self.nodes) + + # Compute "attrs" field. + attrs = {} + + # These fields are mandatory. + shapes = [] + storage_ids = [] + dtype = [] + dltype = [] + + for i, node in enumerate(self.nodes): + storage_ids.append(i) + shapes.append(node.attrs['shape']) + if node.attrs['dtype'] == 'float32': + dtype.append(0) + dltype.append('float32') + + attrs["shape"] = ["list_shape", shapes] + attrs["storage_id"] = ["list_int", storage_ids] + attrs["dtype"] = ["list_int", dtype] + attrs["dltype"] = ["list_str", dltype] + + json_dict = { + "nodes": nodes, + "arg_nodes": arg_nodes, + "heads": heads, + "attrs": attrs, + "node_row_ptr": node_row_ptr + } + + return json.dumps(json_dict) + + +def build(env, func, target=None): + """ + Compile a single function to the components needed by the + TVM RTS. + + Parameters + ---------- + func: relay.Expr + The function to build. + + target: optional str + The target platform. + + Returns + ------- + (graph_json, mod, params): tuple of (str, tvm.Module, dict) + The outputs of building a Relay function for the TVM runtime. + + """ + if target is None: + target = 'llvm' + + comp = GraphRuntimeCodegen(env) + # NB(@jroesch) This creates lowered functions, and generates names for them + # + # We need these names to emit the correct graph as these are names of the + # functions contained in the module. + lowered_ops = ir_pass.lower_ops(env, func) + mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target) + + # Therefore the call to compile must come after. + comp.codegen(func) + graph_json = comp.to_json() + return graph_json, mod, None # params currently isn't supported by API + + +def graph_evaluate(env, func, *args): + """ + Corresponding function to tvm.relay.eval.evaluate. + + This function evaluates a Relay expression on the + TVM graph_runtime. + + Parameters + ---------- + env: tvm.relay.Environment + The global environment used. + + expr: tvm.relay.Expr + The expression to evaluate. + + args: list of tvm.relay.Expr + The arguments to apply to the expression, only works + if the expression has a function type. + + Returns + ------- + value: tvm.NDArray + The output Tensor produced by evaluating the expression. + """ + func = infer_type(func, env) + func = ir_pass.fuse_ops(env, func) + func = infer_type(func, env) + graph_json, mod, params = build(env, func) + assert params is None + gmodule = graph_runtime.create(graph_json, mod, cpu(0)) + # Create map of inputs. + inputs = {} + for i, arg in enumerate(args): + inputs[func.params[i].name_hint] = arg + # Set the inputs here. + gmodule.set_input(**inputs) + # Run the module, and fetch the output. + gmodule.run() + return gmodule.get_output(0) diff --git a/python/tvm/relay/interpreter.py b/python/tvm/relay/interpreter.py new file mode 100644 index 000000000000..06dc3c79fba4 --- /dev/null +++ b/python/tvm/relay/interpreter.py @@ -0,0 +1,130 @@ +#pylint: disable=no-else-return +"""An interface to the Realy interpreter.""" +from __future__ import absolute_import +import numpy as np +from .. import register_func, nd +from .base import NodeBase, register_relay_node +from . import _make +from . import _interpreter +from . import ir_pass +from .expr import Call, Constant, GlobalVar +from . import const +from .._ffi.base import integer_types + +class Value(NodeBase): + """Base class of all values. + """ + + @staticmethod + @register_func("relay.from_scalar") + def from_scalar(i, dtype=None): + """Convert a Python scalar to a Relay scalar.""" + if dtype is None: + if isinstance(i, integer_types): + dtype = 'int32' + elif isinstance(i, float): + dtype = 'float32' + elif isinstance(i, bool): + dtype = 'uint8' + else: + raise Exception("unable to infer dtype {0}".format(type(i))) + + return TensorValue(nd.array(np.array(i, dtype=dtype))) + + +@register_relay_node +class TupleValue(Value): + def __init__(self, *fields): + self.__init_handle_by_constructor__( + _make.TupleValue, fields) + + def __getitem__(self, field_no): + return self.fields[field_no] + + +@register_relay_node +class Closure(Value): + pass + + +@register_relay_node +class TensorValue(Value): + """A Tensor value produced by the evaluator.""" + + def __init__(self, data): + """Allocate a new TensorValue and copy the data from `array` into + the new array. + """ + if isinstance(data, np.ndarray): + data = nd.array(data) + + self.__init_handle_by_constructor__( + _make.TensorValue, data) + + def as_ndarray(self): + """Convert a Relay TensorValue into a tvm.ndarray.""" + return self.data + + def asnumpy(self): + """Convert a Relay TensorValue into a numpy.ndarray.""" + return self.data.asnumpy() + + def __eq__(self, other): + return self.data == other.data + + +def _arg_to_ast(arg): + if isinstance(arg, TensorValue): + return Constant(arg.data) + elif isinstance(arg, np.ndarray): + return Constant(nd.array(arg)) + elif isinstance(arg, Constant): + return arg + else: + return const(arg) + + +def apply_passes(expr, env=None): + ck_expr = ir_pass.infer_type(expr, env=env) + fused_expr = ir_pass.fuse_ops(env, ck_expr) + return fused_expr + + +def evaluate(env, expr, *args): + """ + Evaluate a Relay expression on the interpreter. + + Parameters + ---------- + env: tvm.relay.Environment + The global environment used. + + expr: tvm.relay.Expr + The expression to evaluate. + + args: list of tvm.relay.Expr + The arguments to apply to the expression, only works + if the expression has a function type. + + Returns + ------- + value: tvm.relay.eval.Value + The value produced by evaluating the expression. + """ + # assert len(args) == 0 + relay_args = [] + for arg in args: + relay_args.append(_arg_to_ast(arg)) + + # TODO: We need to move this optimization code into the optimizer/pass manager + if isinstance(expr, GlobalVar): + func = env[expr] + func = apply_passes(func, env) + env._add(expr, func, True) + opt_expr = Call(expr, relay_args) + # import pdb; pdb.set_trace() + return _interpreter.evaluate(env, opt_expr) + else: + expr = Call(expr, relay_args) + opt_expr = apply_passes(expr, env) + return _interpreter.evaluate(env, opt_expr) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 6adfaacdc86d..1940c8d60d7f 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -211,3 +211,9 @@ def structural_hash(value): msg = ("found value of type {0} expected" + "relay.Expr or relay.Type").format(type(value)) raise TypeError(msg) + +def fuse_ops(expr, env): + return _ir_pass.FuseOps(env, expr) + +def lower_ops(env, expr, target='llvm'): + return _ir_pass.LowerOps(env, expr, target) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 0bc2054cebdf..6ccb394ef8db 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -1,2 +1,49 @@ -#pylint: disable=invalid-name +#pylint: disable=invalid-name, unused-argument """Backend compiler related feature registration""" +import tvm +import topi +from . import register + +def add_compute(attrs, inputs, output_type, target): + assert len(inputs) == 2 + return [topi.add(inputs[0], inputs[1])] + +def add_schedule(outputs, target): + assert len(outputs) == 1 + return tvm.create_schedule(outputs[0].op) + +register("add", "FTVMCompute", add_compute) +register("add", "FTVMSchedule", add_schedule) + +def subtract_compute(attrs, inputs, output_type, target): + assert len(inputs) == 2 + return [topi.subtract(inputs[0], inputs[1])] + +def subtract_schedule(outputs, target): + assert len(outputs) == 1 + return tvm.create_schedule(outputs[0].op) + +register("subtract", "FTVMCompute", subtract_compute) +register("subtract", "FTVMSchedule", subtract_schedule) + +def multiply_compute(attrs, inputs, output_type, target): + assert len(inputs) == 2 + return [topi.multiply(inputs[0], inputs[1])] + +def multiply_schedule(outputs, target): + assert len(outputs) == 1 + return tvm.create_schedule(outputs[0].op) + +register("multiply", "FTVMCompute", multiply_compute) +register("multiply", "FTVMSchedule", multiply_schedule) + +def equal_compute(attrs, inputs, output_type, target): + assert len(inputs) == 2 + return [topi.equal(inputs[0], inputs[1])] + +def equal_schedule(outputs, target): + assert len(outputs) == 1 + return tvm.create_schedule(outputs[0].op) + +register("equal", "FTVMCompute", equal_compute) +register("equal", "FTVMSchedule", equal_schedule) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py new file mode 100644 index 000000000000..4f5dcd4dd08b --- /dev/null +++ b/python/tvm/relay/op/nn/_nn.py @@ -0,0 +1,16 @@ +#pylint: disable=invalid-name, unused-argument +"""Backend compiler related feature registration""" +import tvm +import topi +from .. import register + +def dense_compiler(attrs, inputs, output_type): + assert len(inputs) == 2 + return [topi.nn.dense(inputs[0], inputs[1])] + +def dense_schedule(outputs, target): + assert len(outputs) == 1 + return tvm.create_schedule(outputs[0].op) + +register("nn.dense", "FTVMCompute", dense_compiler) +register("nn.dense", "FTVMSchedule", dense_schedule) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index f1130b52e7ce..0c09f39a3c83 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -3,7 +3,8 @@ from ..base import register_relay_node from ..expr import Expr - +from ...api import register_func +from ...build_module import lower, build @register_relay_node class Op(Expr): @@ -75,3 +76,11 @@ def _register(v): _init_api("relay.op", __name__) + +@register_func("relay.op.compiler._lower") +def _lower(name, schedule, inputs, outputs): + return lower(schedule, list(inputs) + list(outputs), name=name) + +@register_func("relay.op.compiler._build") +def _build(lowered_funcs): + return build(lowered_funcs, target="llvm") diff --git a/python/tvm/relay/testing/mlp.py b/python/tvm/relay/testing/mlp.py index 67fa0d90c643..7d7d984f7526 100644 --- a/python/tvm/relay/testing/mlp.py +++ b/python/tvm/relay/testing/mlp.py @@ -17,6 +17,7 @@ """ a simple multilayer perceptron """ +from __future__ import absolute_import from tvm import relay from .init import create_workload diff --git a/src/relay/interpreter.cc b/src/relay/interpreter.cc new file mode 100644 index 000000000000..534a2a980e4a --- /dev/null +++ b/src/relay/interpreter.cc @@ -0,0 +1,432 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/interpreter.cc + * \brief An interpreter for the Relay IR. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "./ir/type_functor.h" + +namespace tvm { +namespace relay { + +using namespace runtime; + +inline const PackedFunc& GetPackedFunc(const std::string& name) { + const PackedFunc* pf = tvm::runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; + return *pf; +} + +/* Value Implementation */ +Closure ClosureNode::make(tvm::Map env, Function func) { + NodePtr n = make_node(); + n->env = std::move(env); + n->func = std::move(func); + return Closure(n); +} + +TVM_REGISTER_API("relay._make.Closure") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ClosureNode::make(args[0], args[1]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const ClosureNode* node, tvm::IRPrinter* p) { + p->stream << "ClosureNode(" << node->func << ")"; + }); + +TupleValue TupleValueNode::make(tvm::Array value) { + NodePtr n = make_node(); + n->fields = value; + return TupleValue(n); +} + +TVM_REGISTER_API("relay._make.TupleValue") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = TupleValueNode::make(args[0]); + }); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TupleValueNode* node, + tvm::IRPrinter* p) { + p->stream << "TupleValueNode(" << node->fields << ")"; + }); + +TensorValue TensorValueNode::make(runtime::NDArray data) { + NodePtr n = make_node(); + n->data = std::move(data); + return TensorValue(n); +} + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) + .set_dispatch([](const TensorValueNode* node, + tvm::IRPrinter* p) { + auto to_str = GetPackedFunc("relay._tensor_value_repr"); + std::string data_str = to_str(GetRef(node)); + p->stream << "TensorValueNode(" << data_str << ")"; + }); + +TensorValue TensorValueNode::FromType(const Type& t) { + if (auto tt_node = t.as()) { + std::vector dims; + + for (auto dim : tt_node->shape) { + auto int_node = dim.as(); + CHECK(int_node) << "expected concrete dimensions"; + dims.push_back(int_node->value); + } + + DLDataType dtype; + DLContext context; + + switch (tt_node->dtype.code()) { + case halideir_type_int: + dtype.code = kDLInt; + break; + case halideir_type_uint: + dtype.code = kDLUInt; + break; + case halideir_type_float: + dtype.code = kDLFloat; + break; + default: + throw dmlc::Error("can not convert HalideIR type into DLTensor dtype"); + } + + dtype.bits = tt_node->dtype.bits(); + dtype.lanes = tt_node->dtype.lanes(); + + // TODO(@jroesch): Is this the right place to place the tensor? + context.device_type = DLDeviceType::kDLCPU; + context.device_id = 0; + runtime::NDArray data = NDArray::Empty(dims, dtype, context); + return TensorValueNode::make(data); + } else { + LOG(FATAL) << "expected a tensor type"; + return TensorValue(); + } +} + +TVM_REGISTER_API("relay._make.TensorValue") + .set_body([](TVMArgs args, TVMRetValue* ret) { + runtime::NDArray data = args[0]; + *ret = TensorValueNode::make(data); + }); + +/* Evaluator Implementation. */ +struct EvalError : dmlc::Error { + explicit EvalError(const std::string& msg) : Error(msg) {} +}; + +/*! + * \brief A stack frame in the Relay interpreter. + * + * Contains a mapping from relay::Var to relay::Value. + */ +struct Frame { + /*! \brief The set of local variables and arguments for the frame. */ + tvm::Map locals; + + explicit Frame(tvm::Map locals) : locals(locals) {} +}; + +/*! + * \brief The call stack in the Relay interpreter. + * + * Contains a stack of frames; each corresponding to + * a function call. + */ +struct Stack { + /*! \brief The stack frames. */ + std::vector frames; + Stack() : frames() { frames.push_back(Frame({})); } + + Frame& current_frame() { return frames.back(); } + + Value Lookup(const Var& local) { + for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) { + auto elem = frame->locals.find(local); + if (elem != frame->locals.end()) { + return (*elem).second; + } + } + + LOG(FATAL) << "could not find variable binding for " << local + << "address= " << local.operator->(); + return Value(); + } + /*! + * A wrapper around Frame to add RAII semantics to pushing and popping + * stack frames. + */ + struct LocalFrame { + Stack& st; + explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { + st.frames.push_back(fr); + } + ~LocalFrame() { st.frames.pop_back(); } + }; +}; + +/*! \brief The equal comparator for expressions. */ +struct ExprEqual { + bool operator()(const Expr& a, const Expr& b) const { + return AlphaEqual(a, b); + } +}; + +struct Interpreter : ExprFunctor { + Environment env; + Stack stack; + using JitKey = Function; + + using OpMap = std::unordered_map; + + OpMap operator_map_; + + template + T with_frame(const Frame& fr, const std::function& f) { + Stack::LocalFrame lf(stack, fr); + return f(); + } + + Interpreter(Environment env) : env(env), operator_map_() {} + Interpreter(Environment env, OpMap operator_map) : env(env), operator_map_(operator_map) {} + + void extend(const Var& id, Value v) { + this->stack.current_frame().locals.Set(id, v); + } + + inline Value Lookup(const Var& local) { + return this->stack.Lookup(local); + } + + Value Eval(const Expr& expr) { + return (*this)(expr); + } + + Value VisitExpr(const Expr& expr) override { + RELAY_LOG(INFO) << "VisitExpr: " << expr << std::endl; + auto ret = ExprFunctor::VisitExpr(expr); + return ret; + } + + Value VisitExpr_(const VarNode* var_node) override { + return Lookup(GetRef(var_node)); + } + + Value VisitExpr_(const GlobalVarNode* op) override { + return Eval(this->env->Lookup(GetRef(op))); + } + + Value VisitExpr_(const OpNode* id) override { + // TODO(@jroesch): Eta-expand and return in this case. + throw EvalError( + "internal error, need to wrap intrinsic into call synthetic call node " + "in " + "this case, eta expand"); + } + + Value VisitExpr_(const ConstantNode* op) override { + return TensorValueNode::make(op->data); + } + + Value VisitExpr_(const TupleNode* op) override { + std::vector values; + + for (const auto& field : op->fields) { + Value field_value = Eval(field); + values.push_back(field_value); + } + + return TupleValueNode::make(values); + } + + Value VisitExpr_(const FunctionNode* func_node) override { + auto func = GetRef(func_node); + tvm::Map captured_env; + Array free_vars = FreeVars(func); + + for (const auto& var : free_vars) { + captured_env.Set(var, Eval(var)); + } + + return ClosureNode::make(captured_env, func); + } + + inline Value InvokeCompiledOp(PackedFunc func, const Array& args, + Type ret_type) { + // Marshal the arguments. + auto arg_len = args.size() + 1; + std::vector values(arg_len); + std::vector codes(arg_len); + TVMArgsSetter setter(values.data(), codes.data()); + TVMRetValue ret; + + // We need real type information to properly allocate the structure. + for (size_t i = 0; i < args.size(); i++) { + if (const TensorValueNode* tv = args[i].as()) { + setter(i, tv->data); + } + } + + // TVM's calling convention is that the final argument is the output + // buffer. To preserve the illusion of being a functional language + // we need to allocate space for the output buffer based on the + // return type. + CHECK(ret_type.as()); + + auto out_tensor = TensorValueNode::FromType(ret_type); + + setter(arg_len - 1, out_tensor->data); + func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &ret); + return out_tensor; + } + + Value Invoke(const Closure& closure, const tvm::Array& args) { + // Get a reference to the function inside the closure. + auto func = closure->func; + auto compiled = operator_map_.find(func); + tvm::Array funcs; + for (auto op : operator_map_) { + funcs.push_back(op.first); + } + + // This case we know we have precompiled the operator. + if (compiled != operator_map_.end()) { + auto func_ty = func->func_type_annotation(); + return InvokeCompiledOp(compiled->second, args, func_ty->ret_type); + } + + // Allocate a frame with the parameters and free variables. + tvm::Map locals; + + CHECK_EQ(func->params.size(), args.size()); + + for (size_t i = 0; i < func->params.size(); i++) { + CHECK_EQ(locals.count(func->params[i]), 0); + locals.Set(func->params[i], args[i]); + } + + // Add the var to value mappings from the Closure's environment. + for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { + CHECK_EQ(locals.count((*it).first), 0); + locals.Set((*it).first, (*it).second); + } + + return with_frame(Frame(locals), [&]() { return Eval(func->body); }); + } + + Value VisitExpr_(const CallNode* call) override { + tvm::Array args; + for (auto arg : call->args) { + args.push_back(Eval(arg)); + } + + // We should not find operators after running fusion, + // and operator lowering. + // + // We have some functions cotaining chunks of operators + // which will be loaded into operator map. + if (auto op_node = call->op.as()) { + LOG(FATAL) << "found " << op_node->name + << "; operators should be removed by future passes; try " + "fusing and lowering"; + } + + // Now we just evaluate and expect to find a closure. + Value fn_val = Eval(call->op); + if (const ClosureNode* closure_node = fn_val.as()) { + auto closure = GetRef(closure_node); + return this->Invoke(closure, args); + } else { + throw EvalError( + "internal error: type error, expected function value in the call " + "position"); + } + } + + Value VisitExpr_(const LetNode* op) override { + auto value = Eval(op->value); + this->extend(op->var, value); + return Eval(op->body); + } + + Value VisitExpr_(const TupleGetItemNode* op) override { + Value val = Eval(op->tuple); + auto product_node = val.as(); + CHECK(product_node) + << "interal error: when evaluating TupleGetItem expected a tuple value"; + CHECK_LT(static_cast(op->index), product_node->fields.size()) + << "internal error: index out of bounds"; + return product_node->fields[op->index]; + } + + Value VisitExpr_(const IfNode* op) override { + Value v = Eval(op->cond); + if (const TensorValueNode* bv = v.as()) { + // TODO(@jroesch, @MK): Refactor code into helper from DCE. + if (reinterpret_cast(bv->data->data)[0]) { + return Eval(op->true_branch); + } else { + return Eval(op->false_branch); + } + } else { + throw EvalError("type error, type system should have caught this"); + } + } +}; + +Interpreter::OpMap CompileOperators(const Environment& env, const Expr& e) { + Interpreter::OpMap op_map; + auto lowered_ops = LowerOps(env, e); + RELAY_LOG(INFO) << "LoweredFuncs: " << lowered_ops << std::endl; + if (lowered_ops.size()) { + const PackedFunc* fbuild_ptr = Registry::Get("relay.op.compiler._build"); + CHECK(fbuild_ptr) << "Could not find registered function: relay.op.compiler._build"; + auto fbuild = *fbuild_ptr; + + // Collect the set of lowered functions to build a module. + Array lowered_funcs; + for (auto lop : lowered_ops) { + lowered_funcs.push_back(lop->lowered_func); + } + + Module module = fbuild(lowered_funcs); + + // Loop over the lowered operations to map them into the operator map. + for (auto lop : lowered_ops) { + Function func = lop->func; + LoweredFunc lf = lop->lowered_func; + + RELAY_LOG(INFO) << "LoweredFunc: " << lf->name << std::endl; + auto op_impl = module.GetFunction(lf->name); + op_map.insert({func, op_impl}); + } + } + + return op_map; +} + +Value Evaluate(Environment env, Expr e) { + auto op_map = CompileOperators(env, e); + Interpreter interp(env, op_map); + return interp.Eval(e); +} + +TVM_REGISTER_API("relay._interpreter.evaluate") + .set_body([](TVMArgs args, TVMRetValue* ret) { + Environment env = args[0]; + Expr expr = args[1]; + *ret = Evaluate(env, expr); + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index a68910e56b71..1f73f297f99a 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -66,3 +66,5 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) } // namespace relay } // namespace tvm + + diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc index dddad82c8afc..262758ba0478 100644 --- a/src/relay/ir/environment.cc +++ b/src/relay/ir/environment.cc @@ -49,9 +49,16 @@ void EnvironmentNode::Add(const GlobalVar& var, << "Environment#update changes type, not possible in this mode."; } this->functions.Set(var, checked_func); - // set gloval var map - CHECK(!global_var_map_.count(var->name_hint)) - << "Duplicate global function name " << var->name_hint; + + auto it = global_var_map_.find(var->name_hint); + if (it != global_var_map_.end()) { + CHECK_EQ((*it).second, var); + } else { + // set global var map + CHECK(!global_var_map_.count(var->name_hint)) + << "Duplicate global function name " << var->name_hint; + } + global_var_map_.Set(var->name_hint, var); } @@ -94,7 +101,7 @@ TVM_REGISTER_API("relay._make.Environment") TVM_REGISTER_API("relay._env.Environment_Add") .set_body([](TVMArgs args, TVMRetValue *ret) { Environment env = args[0]; - env->Add(args[1], args[2], false); + env->Add(args[1], args[2], args[3]); }); TVM_REGISTER_API("relay._env.Environment_GetGlobalVar") diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index c75c414c8ce9..993892a94861 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -26,7 +26,10 @@ TVM_REGISTER_API("relay._make.Constant") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstantNode* node, tvm::IRPrinter* p) { - p->stream << "Constant(TODO)"; + const PackedFunc* fprint = Registry::Get("relay._constant_repr"); + CHECK(fprint) << "unable to find printing function for constants"; + std::string data = (*fprint)(GetRef(node)); + p->stream << "Constant(" << data << ")"; }); TensorType ConstantNode::tensor_type() const { @@ -104,12 +107,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) Function FunctionNode::make(tvm::Array params, Expr body, Type ret_type, - tvm::Array type_params) { + tvm::Array type_params, + tvm::Attrs attrs) { NodePtr n = make_node(); n->params = std::move(params); n->body = std::move(body); n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); + n->attrs = std::move(attrs); return Function(n); } @@ -121,6 +126,39 @@ FuncType FunctionNode::func_type_annotation() const { return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); } +NodeRef FunctionGetAttr(const Function& func, const std::string& key) { + if (!func->attrs.defined()) { return NodeRef(); } + + const DictAttrsNode* dict_attrs = func->attrs.as(); + CHECK(dict_attrs); + auto it = dict_attrs->dict.find(key); + if (it != dict_attrs->dict.end()) { + return (*it).second; + } else { + return NodeRef(); + } +} + +Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data) { + const DictAttrsNode* dattrs = func->attrs.as(); + Attrs func_attrs; + if (dattrs) { + Map dict = dattrs->dict; + dict.Set(key, data); + func_attrs = DictAttrsNode::make(dict); + } else { + Map dict = {{key, data}}; + func_attrs = DictAttrsNode::make(dict); + } + + return FunctionNode::make( + func->params, + func->body, + func->ret_type, + func->type_params, + func_attrs); +} + TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_API("relay._make.Function") @@ -132,7 +170,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionNode* node, tvm::IRPrinter* p) { p->stream << "FunctionNode(" << node->params << ", " << node->ret_type - << ", " << node->body << ", " << node->type_params << ")"; + << ", " << node->body << ", " << node->type_params << ", " + << node->attrs << ")"; }); Call CallNode::make(Expr op, Array args, Attrs attrs, diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index b7a752d43a5c..d0151f870f05 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -92,7 +92,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { body.same_as(op->body)) { return GetRef(op); } else { - return FunctionNode::make(params, body, ret_type, ty_params); + return FunctionNode::make(params, body, ret_type, ty_params, op->attrs); } } @@ -194,6 +194,7 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { void ExprVisitor::VisitExpr_(const CallNode* op) { this->VisitExpr(op->op); + for (auto ty_arg : op->type_args) { this->VisitType(ty_arg); } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index ce2049f269df..4fd91256db9c 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -285,11 +285,11 @@ class RelayHashHandler: int var_counter = 0; }; -size_t StructuralHash(const Type& type) { +size_t StructuralHash::operator()(const Type& type) const { return RelayHashHandler().TypeHash(type); } -size_t StructuralHash(const Expr& expr) { +size_t StructuralHash::operator()(const Expr& expr) const { return RelayHashHandler().ExprHash(expr); } diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc new file mode 100644 index 000000000000..3aea12931649 --- /dev/null +++ b/src/relay/pass/fuse_ops.cc @@ -0,0 +1,86 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file src/tvm/relay/pass/fuse_ops.cc + * + * \brief Fuse Relay eligble sequences of Relay operators into a single one. + * + */ +#include +#include +#include +#include +#include +#include +#include "../ir/type_functor.h" + +namespace tvm { +namespace relay { + +using namespace runtime; + +struct AbstractFusableOps : ExprMutator { + Environment env; + Array fusable_funcs; + int counter = 0; + size_t expr_hash; + + AbstractFusableOps(Environment env, size_t expr_hash) : env(env), expr_hash(expr_hash) {} + + Expr VisitExpr_(const CallNode* call) { + if (auto op_node = call->op.as()) { + // Placeholder fusion algorithm which abstracts + // single definitions into functions only. + Array params; + Array inner_args; + Array args; + + int param_number = 0; + for (auto arg : call->args) { + auto name = std::string("p") + std::to_string(param_number++); + auto type = arg->checked_type(); + auto var = VarNode::make(name, type); + params.push_back(var); + inner_args.push_back(var); + args.push_back(VisitExpr(arg)); + } + + auto body = CallNode::make(call->op, inner_args, call->attrs); + auto func = FunctionNode::make(params, body, call->checked_type(), {}); + func = FunctionSetAttr(func, "Primitive", tvm::Integer(1)); + std::string func_name = "fused_"; + func_name += op_node->name; + func_name += "_"; + func_name += std::to_string(counter++); + func_name += "_"; + func_name += std::to_string(expr_hash); + auto gv = GlobalVarNode::make(func_name); + env->Add(gv, func); + fusable_funcs.push_back(gv); + return CallNode::make(gv, args, Attrs()); + } else { + return ExprMutator::VisitExpr_(call); + } + } +}; + +Expr FuseOps(const Environment& env, const Expr& e) { + // First we convert all chains of fusable ops into + // abstracted functions which we mark as primtive + // then we convert these primtive functions into + // new operators. + auto abstract = AbstractFusableOps(env, StructuralHash()(e)); + auto abstracted_e = abstract.VisitExpr(e); + RELAY_LOG(INFO) << "FuseOps: before=" << e + << "Fuse: after=" << abstracted_e; + return abstracted_e; +} + +TVM_REGISTER_API("relay._ir_pass.FuseOps") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = FuseOps(args[1], args[0]); +}); + + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/lower_ops.cc b/src/relay/pass/lower_ops.cc new file mode 100644 index 000000000000..6bab9a924269 --- /dev/null +++ b/src/relay/pass/lower_ops.cc @@ -0,0 +1,222 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file src/tvm/relay/pass/lower_ops.cc + * + * \brief Lower a Relay program to set of TVM operators. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include "../ir/type_functor.h" + +namespace tvm { +namespace relay { + +using namespace runtime; + +LoweredOp LoweredOpNode::make(Function func, LoweredFunc lowered_func) { + auto node = make_node(); + node->func = func; + node->lowered_func = lowered_func; + return LoweredOp(node); +} + +struct AbstractLocalFunctions : ExprMutator { + Environment env; + size_t expr_hash; + int counter = 0; + std::unordered_set visited_funcs; + explicit AbstractLocalFunctions(Environment env) + : env(env), expr_hash(0), counter(0), visited_funcs() {} + + Expr Abstract(const Expr& e) { + expr_hash = StructuralHash()(e); + return VisitExpr(e); + } + + Expr VisitExpr_(const GlobalVarNode* gvar_node) final { + auto gvar = GetRef(gvar_node); + auto it = visited_funcs.find(gvar); + if (it == visited_funcs.end()) { + auto func = env->Lookup(gvar); + visited_funcs.insert(gvar); + auto new_func = FunctionNode::make( + func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + env->Update(gvar, new_func); + } + return gvar; + } + + Expr VisitExpr_(const FunctionNode* func_node) final { + Function func = GetRef(func_node); + auto free_vars = FreeVars(func); + Array params; + for (auto free_var : free_vars) { + auto var = VarNode::make("free_var", free_var->checked_type()); + params.push_back(var); + } + std::string abs_func = "abstracted_func_"; + abs_func += std::to_string(counter++); + abs_func += std::to_string(expr_hash); + auto gv = GlobalVarNode::make(abs_func); + auto lifted_func = FunctionNode::make(params, func, Type(), {}, {}); + env->Add(gv, lifted_func); + Array args; + for (auto free_var : free_vars) { + args.push_back(free_var); + } + return CallNode::make(gv, args, {}); + } +}; + +struct LiveFunctions : ExprVisitor { + Environment env; + explicit LiveFunctions(Environment env) : env(env), global_funcs() {} + + std::unordered_set visited_funcs; + std::unordered_set global_funcs; + + void Live(const Expr& e) { + CHECK(!e.as()) + << "functions should of been transformed away by previous pass"; + VisitExpr(e); + } + + void VisitExpr_(const FunctionNode* func_node) { + LOG(FATAL) << "functions should of been transformed away by previous pass"; + } + + void VisitExpr_(const GlobalVarNode* var_node) final { + GlobalVar var = GetRef(var_node); + auto it = visited_funcs.find(var); + if (it == visited_funcs.end()) { + auto func = env->Lookup(var); + visited_funcs.insert(var); + // The last pass has trasnformed functions of the form: + // + // let x = fn (p_1, ..., p_n) { ... }; + // ... + // + // into, a top-level declaration: + // + // def abs_f(fv_1, ..., fv_n) { + // return (fn (p_1...,p_N) { ... };) + // } + // + // and: + // + // let x = abs_f(fv_1, ... fv_n); + // + // The only other case we can handle is + // + // fn foo(...) { body } + // + // We just search through the body in this case. + if (auto inner_func = func->body.as()) { + return VisitExpr(inner_func->body); + } else { + return VisitExpr(func->body); + } + } + } + + void VisitExpr_(const CallNode* call) final { + RELAY_LOG(INFO) << "LiveOps: CallNode=" << GetRef(call); + if (auto gv_node = call->op.as()) { + GlobalVar gvar = GetRef(gv_node); + Function func = env->Lookup(gvar); + + auto attr = FunctionGetAttr(func, "Primitive"); + + if (attr.defined() && Downcast(attr)->value == 1) { + global_funcs.insert(gvar); + } else { + VisitExpr(gvar); + } + + // Finally we need to ensure to visit all the args no matter what. + for (auto arg : call->args) { + VisitExpr(arg); + } + } else { + return ExprVisitor::VisitExpr_(call); + } + } +}; + +using FCompute = TypedPackedFunc( + const Attrs&, const Array&, Type, std::string)>; +using FSchedule = TypedPackedFunc&, std::string)>; + +/*! \brief Return the set of operators in their TVM format. */ +Array LowerOps(const Environment& env, const Expr& e, + const std::string& target) { + RELAY_LOG(INFO) << "LowerOps: e=" << e; + auto flower_ptr = Registry::Get("relay.op.compiler._lower"); + CHECK(flower_ptr); + PackedFunc flower = *flower_ptr; + + auto abstracted_e = AbstractLocalFunctions(env).Abstract(e); + auto live_funcs = LiveFunctions(env); + live_funcs.VisitExpr(abstracted_e); + + auto schedule_reg = Op::GetAttr("FTVMSchedule"); + auto compute_reg = Op::GetAttr("FTVMCompute"); + + Array lowered_funcs; + + for (auto func_name : live_funcs.global_funcs) { + auto func = env->Lookup(func_name); + auto call = Downcast(func->body); + auto op_node = call->op.as(); + CHECK(op_node) << "violated invariant that primtiive calls contain a single op call"; + auto op = GetRef(op_node); + RELAY_LOG(INFO) << "LowerOps: Lowering " << op->name; + + CHECK(IsPrimitiveOp(op)) << "failed to lower " + << op->name << "can only lower primitve operations"; + + Array inputs; + std::string input_name = "in"; + int i = 0; + for (auto type_arg : call->type_args) { + auto tt = Downcast(type_arg); + inputs.push_back(PlaceholderOpNode::make(input_name + std::to_string(i), + tt->shape, tt->dtype) + .output(0)); + i++; + } + + auto output_tt = op->op_type->ret_type; + Array outputs = + compute_reg[op](call->attrs, inputs, output_tt, target); + auto schedule = schedule_reg[op](outputs, target); + size_t hash = StructuralHash()(func); + LoweredFunc lf = + flower(op->name + std::to_string(hash), schedule, inputs, outputs); + func = FunctionSetAttr(func, "LoweredFunc", lf); + env->Add(func_name, func, true); + lowered_funcs.push_back(LoweredOpNode::make(func, lf)); + } + + return lowered_funcs; +} + +TVM_REGISTER_API("relay._ir_pass.LowerOps") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = LowerOps(args[0], args[1], args[2]); +}); + + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index c0f1db97b538..6aed93e511ed 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -272,8 +272,8 @@ class TypeInferencer : private ExprFunctor { auto* fn_ty_node = ftype.as(); CHECK(fn_ty_node != nullptr) - << "only expressions with function types can be called, at " - << call->span; + << "only expressions with function types can be called, found " + << ftype << " at " << call->span; Array type_args; FuncType fn_ty = Instantiate(fn_ty_node, &type_args); @@ -479,12 +479,16 @@ Expr TypeInferencer::Infer(Expr expr) { // Step 1: Solve the constraints. solver_.Solve(); // Step 2: Attach resolved types to checked_type field. - return Resolver(type_map_, &solver_).VisitExpr(expr); + auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); + CHECK(WellFormed(resolved_expr)); + return resolved_expr; } Expr InferType(const Expr& expr, const Environment& env) { - return TypeInferencer(env).Infer(expr); + auto e = TypeInferencer(env).Infer(expr); + CHECK(WellFormed(e)); + return e; } Function InferType(const Function& func, @@ -496,6 +500,7 @@ Function InferType(const Function& func, Expr func_ret = TypeInferencer(env).Infer(func_copy); auto map_node = env->functions.CopyOnWrite(); map_node->data.erase(var.node_); + CHECK(WellFormed(func_ret)); return Downcast(func_ret); } diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index c1f00c7b65e0..51ef0377868f 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -3,7 +3,7 @@ * * \file util.cc * - * \brief simple util for relay. + * \brief Utility functions for Relay. */ #include #include diff --git a/tests/python/relay/test_graph_runtime.py b/tests/python/relay/test_graph_runtime.py new file mode 100644 index 000000000000..1e55f890e514 --- /dev/null +++ b/tests/python/relay/test_graph_runtime.py @@ -0,0 +1,80 @@ +import numpy as np + +from tvm import relay +from tvm.relay.ir_pass import infer_type +from tvm.relay.interpreter import evaluate +from tvm.relay.graph_runtime_codegen import graph_evaluate +from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay.op import add +from tvm.relay.env import Environment + +# @tq, @jr should we put this in testing ns? +def check_rts(env, expr, args, expected_result): + """ + Check that evaluating `expr` applied to the arguments produces + `result` on both the evaluator and TVM runtime. + + Parameters + ---------- + expr: + The expression to evaluate + + args: list of Expr + The arguments to supply the expr. + + expected_result: + The expected result of running the expression. + """ + eval_result = evaluate(env, expr, *args) + rts_result = graph_evaluate(env, expr, *args) + np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy()) + +def test_add_op_scalar(): + """ + Program: + fn (x, y) { + return x + y; + } + """ + env = Environment() + x = relay.var('x', shape=()) + y = relay.var('y', shape=()) + func = relay.Function([x, y], add(x, y)) + x_data = np.array(10.0, dtype='float32') + y_data = np.array(1.0, dtype='float32') + check_rts(env, func, [x_data, y_data], x_data + y_data) + +def test_add_op_tensor(): + """ + Program: + fn (x, y) { + return x + y; + } + """ + env = Environment() + x = relay.var('x', shape=(10, 5)) + y = relay.var('y', shape=(10, 5)) + func = relay.Function([x, y], add(x, y)) + x_data = np.random.rand(10, 5).astype('float32') + y_data = np.random.rand(10, 5).astype('float32') + check_rts(env, func, [x_data, y_data], x_data + y_data) + +def test_add_op_broadcast(): + """ + Program: + fn (x, y) { + return x + y; + } + """ + env = Environment() + x = relay.var('x', shape=(10, 5)) + y = relay.var('y', shape=(1, 5)) + func = relay.Function([x, y], add(x, y)) + x_data = np.random.rand(10, 5).astype('float32') + y_data = np.random.rand(1, 5).astype('float32') + check_rts(env, func, [x_data, y_data], x_data + y_data) + +if __name__ == "__main__": + test_add_op_scalar() + test_add_op_tensor() + test_add_op_broadcast() diff --git a/tests/python/relay/test_interpreter.py b/tests/python/relay/test_interpreter.py new file mode 100644 index 000000000000..9a431b4c9524 --- /dev/null +++ b/tests/python/relay/test_interpreter.py @@ -0,0 +1,142 @@ +import numpy as np +import tvm +from tvm import relay +from tvm.relay.interpreter import Value, TupleValue, evaluate +from tvm.relay import op +from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay import testing + + +def check_eval(expr, args, expected_result, env=None, rtol=1e-07): + if env is None: + env = relay.env.Environment({}) + + result = evaluate(env, expr, *args) + np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) + + +def test_from_scalar(): + np.testing.assert_allclose(Value.from_scalar(1, 'int32').asnumpy(), 1) + np.testing.assert_allclose(Value.from_scalar(10.0, 'float32').asnumpy(), 10.0) + np.testing.assert_allclose(Value.from_scalar(True).asnumpy(), True) + + +def test_tuple_value(): + tv = TupleValue(Value.from_scalar( + 1), Value.from_scalar(2), Value.from_scalar(3)) + np.testing.assert_allclose(tv[0].asnumpy(), 1) + np.testing.assert_allclose(tv[1].asnumpy(), 2) + np.testing.assert_allclose(tv[2].asnumpy(), 3) + + +def test_id(): + x = relay.var('x', 'float32') + ident = relay.Function([x], x) + env = relay.env.Environment({}) + res = evaluate(env, ident, 1.0) + check_eval(ident, [1.0], 1.0) + + +def test_add_const(): + two = op.add(relay.const(1), relay.const(1)) + func = relay.Function([], two) + check_eval(func, [], 2) + + +def test_mul_param(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('y', shape=(1, 10)) + func = relay.Function([x, y], op.multiply(x, y)) + x_data = np.random.rand(10, 10).astype('float32') + y_data = np.random.rand(1, 10).astype('float32') + check_eval(func, [x_data, y_data], x_data * y_data) + + +# failing due to numeric issues + +# def test_dense(): +# x = relay.var('x', shape=(10, 10)) +# w = relay.var('w', shape=(10, 10)) +# y = op.nn.dense(x, w) +# func = relay.Function([x, w], y) +# x_data = np.random.rand(10, 10).astype('float32') +# w_data = np.random.rand(10, 10).astype('float32') +# check_eval(func, [x_data, w_data], x_data @ w_data, rtol=0.1) + +# def test_linear(): +# x = relay.var('x', shape=(10, 10)) +# w = relay.var('w', shape=(10, 10)) +# b = relay.var('b', shape=(10,)) +# y = op.add(op.nn.dense(x, w), b) +# func = relay.Function([x, w, b], y) +# x_data = np.random.rand(10, 10).astype('float32') +# w_data = np.random.rand(10, 10).astype('float32') +# b_data = np.random.rand(10).astype('float32') +# check_eval(func, [x_data, w_data, b_data], x_data @ w_data + b_data) + +def test_equal(): + i = relay.var('i', shape=[], dtype='int32') + j = relay.var('i', shape=[], dtype='int32') + z = op.equal(i, j) + func = relay.Function([i, j], z, ret_type=relay.TensorType([], 'bool')) + i_data = relay.const(0) + j_data = relay.const(0) + check_eval(func, [i_data, j_data], True) + +def test_subtract(): + i = relay.var('i', shape=[], dtype='int32') + sub = op.subtract(i, relay.const(1, dtype='int32')) + func = relay.Function([i], sub, ret_type=relay.TensorType([], 'int32')) + i_data = np.array(1, dtype='int32') + check_eval(func, [i_data], 0) + +def test_simple_loop(): + env = relay.env.Environment({}) + sum_up = relay.GlobalVar('sum_up') + i = relay.var('i', shape=[], dtype='int32') + sb = ScopeBuilder() + with sb.if_scope(op.equal(i, relay.const(0, dtype='int32'))): + sb.ret(i) + with sb.else_scope(): + one_less = op.subtract(i, relay.const(1, dtype='int32')) + rec_call = relay.Call(sum_up, [one_less]) + sb.ret(op.add(rec_call, i)) + func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) + env[sum_up] = func + i_data = np.array(10, dtype='int32') + check_eval(sum_up, [i_data], sum(range(1, 11)), env=env) + +def test_loop(): + env = relay.env.Environment({}) + sum_up = relay.GlobalVar('sum_up') + i = relay.var('i', shape=[], dtype='int32') + accum = relay.var('accum', shape=[], dtype='int32') + sb = ScopeBuilder() + with sb.if_scope(op.equal(i, relay.const(0))): + sb.ret(accum) + with sb.else_scope(): + one_less = op.subtract(i, relay.const(1)) + new_accum = op.add(accum, i) + sb.ret(relay.Call(sum_up, [one_less, new_accum])) + func = relay.Function([i, accum], sb.get()) + env[sum_up] = func + i_data = np.array(10, dtype='int32') + accum_data = np.array(0, dtype='int32') + check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), env=env) + +def test_mlp(): + pass + # net = testing.mlp.get_workload(1) + # import pdb; pdb.set_trace() + +if __name__ == "__main__": + test_id() + test_add_const() + # test_dense() + # test_linear() + test_equal() + test_subtract() + test_simple_loop() + test_loop() + test_mlp() + diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index b1823004022c..31d350dc7ff7 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -5,6 +5,16 @@ import numpy as np from tvm.relay.ir_pass import infer_type from tvm import relay +from tvm.relay import op +from tvm.relay.scope_builder import ScopeBuilder + + +def assert_has_type(expr, typ, env=relay.env.Environment({})): + checked_expr = infer_type(expr, env) + checked_type = checked_expr.checked_type + if checked_type != typ: + raise RuntimeError("Type mismatch %s vs %s" % ( + checked_type, typ)) def test_monomorphic_let(): @@ -16,6 +26,31 @@ def test_monomorphic_let(): assert xchecked.checked_type == relay.scalar_type("float64") +def test_single_op(): + "Program: fn (x : float32) { let t1 = f(x); t1 }" + x = relay.var('x', shape=[]) + func = relay.Function([x], op.log(x)) + ttype = relay.TensorType([], dtype='float32') + assert_has_type(func, relay.FuncType([ttype], ttype)) + + +def test_add_broadcast_op(): + """ + Program: + fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { + return x + y; + } + """ + pass + # x = relay.var('x', shape=(10, 4)) + # y = relay.var('y', shape=(5, 10, 1)) + # z = x + y + # func = relay.Function([x, y], z) + # ttype = relay.TensorType((5, 5, 5), 'float32') + # expected_ty = relay.FuncType([ttype, ttype], ttype) + # assert_has_type(func.to_func(), expected_ty) + + def test_dual_op(): """Program: fn (x : Tensor[f32, (10, 10)]) { @@ -41,7 +76,6 @@ def f(x : Tensor[(10, 10), f32]) { return log(x); } """ - sb = relay.ScopeBuilder() tp = relay.TensorType((10, 10)) x = relay.var("x", tp) f = relay.Function([x], relay.log(x)) @@ -76,6 +110,24 @@ def f(n: i32, data: f32) -> f32 { assert "%3 = @f(%1, %2)" in env.astext() assert env[f].checked_type == relay.FuncType([ti32, tf32], tf32) +# This currently fails and should pass under the type system. +# +# This test is to illustrate problem with our weak form of +# unification. +# + + +def test_incomplete_call(): + sb = ScopeBuilder() + x = relay.var('x', dtype='int32') + f = relay.var('f') + func = relay.Function([x, f], relay.Call(f, [x])) + + try: + relay.ir_pass.infer_type(func) + assert False + except tvm.TVMError as e: + assert True def test_tuple(): tp = relay.TensorType((10,)) @@ -84,13 +136,13 @@ def test_tuple(): assert (relay.ir_pass.infer_type(res).checked_type == relay.TupleType([tp, tp])) - def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) yy = relay.ir_pass.infer_type(y) assert yy.checked_type == relay.scalar_type("float32") + def test_type_args(): x = relay.var("x", shape=(10, 10)) y = relay.var("y", shape=(1, 10)) @@ -107,6 +159,7 @@ def test_type_args(): assert sh2[0].value == 1 assert sh2[1].value == 10 + def test_self_reference(): """ Program: @@ -117,31 +170,41 @@ def f(x) { a = relay.TypeVar("a") x = relay.var("x", a) sb = relay.ScopeBuilder() + f = relay.Function([x], x) fx = relay.Call(f, [x]) assert relay.ir_pass.infer_type(x).checked_type == a assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) assert relay.ir_pass.infer_type(fx).checked_type == a + def test_global_var_cow_issue(): env = relay.env.Environment({}) gv = relay.GlobalVar("foo") x = relay.var('x', shape=[]) - func = relay.Function([x], relay.Call(gv, [x]), relay.TensorType([], 'float32')) + func = relay.Function([x], relay.Call(gv, [x]), + relay.TensorType([], 'float32')) env[gv] = func - # They should both point to the same global variable if global variables are - # stable across type checking. - assert gv == func.body.op + + +def test_equal(): + i = relay.var('i', shape=[], dtype='int32') + eq = op.equal(i, relay.const(0, dtype='int32')) + # This should fail .... + func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32')) + if __name__ == "__main__": test_free_expr() test_dual_op() + test_single_op() test_recursion() test_monomorphic_let() test_decl() test_recursion() test_tuple() + test_incomplete_call() test_free_expr() test_type_args() test_self_reference() - test_global_var_cow_issue() \ No newline at end of file + test_global_var_cow_issue() diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 818376717176..d11dcd5da71a 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -1,5 +1,5 @@ #!/bin/bash -export PYTHONPATH=python:apps/extension/python +export PYTHONPATH=python:topi/python:apps/extension/python export LD_LIBRARY_PATH=build:${LD_LIBRARY_PATH} rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc