diff --git a/nnvm/README.md b/nnvm/README.md
index 6872e12120d0..6e4b4568920e 100644
--- a/nnvm/README.md
+++ b/nnvm/README.md
@@ -3,16 +3,54 @@
 [![Build Status](https://travis-ci.org/dmlc/nnvm.svg?branch=master)](https://travis-ci.org/dmlc/nnvm)
 [![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)

-NNVM is a reusable computational graph optimization and compilation stack for deep learning systems.
-NNVM provides modules to:
+NNVM is a reusable computational graph optimization and compilation stack for deep learning systems. It provides modules to:

 - Represent deep learning workloads from front-end frameworks via a graph IR.
 - Optimize computation graphs to improve performance.
 - Compile into executable modules and deploy to different hardware backends with minimum dependency.

-NNVM is designed to add new frontend, operators and graph optimizations in a decentralized fashion without changing the core interface. NNVM is part of [TVM stack](https://github.com/dmlc/tvm), which provides an end to end IR compilation stack for deploying deep learning workloads into different hardware backends
+NNVM is designed so that new frontends, operators and graph optimizations can be added in a decentralized fashion without changing the core interface. NNVM is part of the [TVM stack](https://github.com/dmlc/tvm), and the NNVM compiler toolchain can target the hardware backends supported by TVM.
+The compiled modules can be deployed to servers, mobile and embedded devices, and browsers with minimal dependencies, from languages including C++, Python, JavaScript, Java and Objective-C.
+
+The following code snippet demonstrates the general workflow of the nnvm compiler toolchain.
+
+```python
+import tvm
+from tvm.contrib import graph_runtime, rpc
+import nnvm.frontend
+import nnvm.compiler
+
+# get the model from a front-end framework
+# change xyz to a supported framework name.
+graph, params = nnvm.frontend.from_xyz(...)
+
+# optimize and compile the graph to get a deployable module
+# target can be "opencl", "llvm", "metal" or any target supported by tvm
+target = "cuda"
+graph, lib, params = nnvm.compiler.build(
+    graph, target, shape={"data": data_shape}, params=params)
+
+# deploy and run on gpu(0)
+module = graph_runtime.create(graph, lib, tvm.gpu(0))
+module.set_input(**params)
+output = tvm.nd.empty(out_shape, ctx=tvm.gpu(0))
+for data_array in dataset:
+    module.set_input("data", data_array)
+    module.run()
+    module.get_output(0, output)
+
+# deploy to a remote mobile/rasp/browser device with the minimum tvm rpc runtime
+# useful for quick experiments on mobile devices
+remote = rpc.connect(remote_host, remote_port)
+lib.export_library("mylib.so")
+remote.upload("mylib.so")
+rlib = remote.load_module("mylib.so")
+# run on the remote device
+rmodule = graph_runtime.create(graph, rlib, remote.gpu(0))
+rmodule.set_input(**params)
+rmodule.run()
+```

 ## Links
 - [TinyFlow](https://github.com/tqchen/tinyflow) on how you can use NNVM to build a TensorFlow like API.
 - [Apache MXNet](http://mxnet.io/) uses NNVM as a backend.
-
diff --git a/nnvm/docs/api/python/compiler.rst b/nnvm/docs/api/python/compiler.rst
index 74f3bed07e52..4b995b28cd9e 100644
--- a/nnvm/docs/api/python/compiler.rst
+++ b/nnvm/docs/api/python/compiler.rst
@@ -7,6 +7,10 @@ nnvm.compiler

 .. autofunction:: nnvm.compiler.build_config

+.. autofunction:: nnvm.compiler.save_param_dict
+
+.. autofunction:: nnvm.compiler.load_param_dict
+
 .. autofunction:: nnvm.compiler.optimize

 .. automodule:: nnvm.compiler.graph_util
diff --git a/nnvm/python/nnvm/compiler/__init__.py b/nnvm/python/nnvm/compiler/__init__.py
index 3954c920a67a..1625150a6edc 100644
--- a/nnvm/python/nnvm/compiler/__init__.py
+++ b/nnvm/python/nnvm/compiler/__init__.py
@@ -1,6 +1,7 @@
 """NNVM compiler toolchain.

-User only need to use :any:`build` and :any:`build_config` to do the compilation.
+Users only need to use :any:`build` and :any:`build_config` to do the compilation,
+and :any:`save_param_dict` to save the parameters into bytes.
 The other APIs are for more advanced interaction with the compiler toolchain.
 """
 from __future__ import absolute_import
@@ -10,6 +11,7 @@ from . import build_module
 from . build_module import build, optimize, build_config
 from . compile_engine import engine, graph_key
+from . param_dict import save_param_dict, load_param_dict

 from .. import symbol as _symbol
 from .. import graph as _graph
diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py
index 8c4b1a46f6bd..5f1fbcc0df56 100644
--- a/nnvm/python/nnvm/compiler/build_module.py
+++ b/nnvm/python/nnvm/compiler/build_module.py
@@ -9,7 +9,7 @@
 from .. import graph as _graph

 OPT_PASS_LEVEL = {
-    "SimplifyInference": 2,
+    "SimplifyInference": 0,
     "PrecomputePrune": 2,
     "OpFusion": 1
 }
@@ -26,6 +26,7 @@ class BuildConfig(object):
     current = None
     defaults = {
         "opt_level": 2,
+        "add_pass": None,
     }
     def __init__(self, **kwargs):
         self._old_scope = None
@@ -53,6 +54,23 @@ def __exit__(self, ptype, value, trace):
         assert self._old_scope
         BuildConfig.current = self._old_scope

+    def pass_enabled(self, pass_name):
+        """Get whether pass is enabled.
+
+        Parameters
+        ----------
+        pass_name : str
+            The optimization pass name
+
+        Returns
+        -------
+        enabled : bool
+            Whether pass is enabled.
+        """
+        if self.add_pass and pass_name in self.add_pass:
+            return True
+        return self.opt_level >= OPT_PASS_LEVEL[pass_name]
+

 BuildConfig.current = BuildConfig()

@@ -64,6 +82,9 @@ def build_config(**kwargs):
     opt_level: int, default=2
         Optimization level. See OPT_PASS_LEVEL for level of each pass.

+    add_pass: set of str
+        Optimization passes to be added regardless of optimization level.
+
     Returns
     -------
     config: BuildConfig
@@ -120,7 +141,7 @@ def optimize(graph, shape, dtype="float32"):
     """
     # pylint: disable=unused-argument
     cfg = BuildConfig.current
-    if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyInference"]:
+    if cfg.pass_enabled("SimplifyInference"):
         graph = graph_attr.set_shape_inputs(graph, shape)
         graph = graph.apply(["InferShape", "SimplifyInference"])
     return graph
@@ -182,14 +203,17 @@ def build(graph, target, shape, dtype="float32", params=None):
     # Apply optimization
     graph = optimize(graph, shape, dtype)
     # Precompute prune
-    if params and cfg.opt_level >= OPT_PASS_LEVEL["PrecomputePrune"]:
+    if params and cfg.pass_enabled("PrecomputePrune"):
         graph, params = precompute_prune(graph, params)
         shape, dtype = _update_shape_dtype(shape, dtype, params)
     # Operator Fusion and generation
     graph = graph_attr.set_shape_inputs(graph, shape)
     graph = graph_attr.set_dtype_inputs(graph, dtype)
     graph._set_json_attr("target", target, "str")
-    graph._set_json_attr("opt_level", cfg.opt_level, "int")
+    if cfg.pass_enabled("OpFusion"):
+        graph._set_json_attr("opt_level", 1, "int")
+    else:
+        graph._set_json_attr("opt_level", 0, "int")
     graph = graph.apply("InferShape").apply("InferType")
     graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
     libmod = graph_attr._move_out_module(graph, "module")
diff --git a/nnvm/python/nnvm/compiler/param_dict.py b/nnvm/python/nnvm/compiler/param_dict.py
new file mode 100644
index 000000000000..97c3158c694f
--- /dev/null
+++ b/nnvm/python/nnvm/compiler/param_dict.py
@@ -0,0 +1,66 @@
+"""Helper utility to save parameter dicts."""
+import tvm
+
+_save_param_dict = tvm.get_global_func("nnvm.compiler._save_param_dict")
+_load_param_dict = tvm.get_global_func("nnvm.compiler._load_param_dict")
+
+def save_param_dict(params):
+    """Save parameter dictionary to binary bytes.
+
+    The resulting binary bytes can be loaded by the
+    GraphModule with the API "load_params".
+
+    Parameters
+    ----------
+    params : dict of str to NDArray
+        The parameter dictionary.
+
+    Returns
+    -------
+    param_bytes: bytearray
+        Serialized parameters.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        # compile and save the modules to file.
+        graph, lib, params = nnvm.compiler.build(
+            graph, target, shape={"data": data_shape}, params=params)
+        module = graph_runtime.create(graph, lib, tvm.gpu(0))
+        # save the parameters as a byte array
+        param_bytes = nnvm.compiler.save_param_dict(params)
+        # We can serialize the param_bytes and load it back later.
+        # Pass the byte array to the module to directly set parameters
+        module["load_params"](param_bytes)
+    """
+    args = []
+    for k, v in params.items():
+        args.append(k)
+        args.append(tvm.nd.array(v))
+    return _save_param_dict(*args)
+
+
+def load_param_dict(param_bytes):
+    """Load a parameter dictionary from binary bytes.
+
+    Parameters
+    ----------
+    param_bytes: bytearray
+        Serialized parameters.
+
+    Returns
+    -------
+    params : dict of str to NDArray
+        The parameter dictionary.
+ """ + if isinstance(param_bytes, (bytes, str)): + param_bytes = bytearray(param_bytes) + load_mod = _load_param_dict(param_bytes) + size = load_mod(0) + param_dict = {} + for i in range(size): + key = load_mod(1, i) + dltensor_handle = load_mod(2, i) + param_dict[key] = tvm.nd.NDArray(dltensor_handle, False) + return param_dict diff --git a/nnvm/python/nnvm/graph.py b/nnvm/python/nnvm/graph.py index bfe5251e2bd8..2ea365e67ef4 100644 --- a/nnvm/python/nnvm/graph.py +++ b/nnvm/python/nnvm/graph.py @@ -12,7 +12,7 @@ from ._base import c_array, c_str, nn_uint, py_str, string_types from ._base import GraphHandle, SymbolHandle from ._base import check_call -from .symbol import Symbol, Group as _Group +from .symbol import Variable, Symbol, Group as _Group class GraphIndex(object): """Index for quickly accessing graph attributes. @@ -174,9 +174,19 @@ def symbol(self): check_call(_LIB.NNGraphGetSymbol(self.handle, ctypes.byref(shandle))) return Symbol(shandle) + def json(self): + """Get JSON representation of the graph + + Returns + ------- + json : str + JSON representation of the graph + """ + return self.apply("SaveJSON").json_attr("json") + def _tvm_graph_json(self): """Get TVM graph json""" - return self.apply("SaveJSON").json_attr("json") + return self.json() @property def index(self): @@ -225,6 +235,24 @@ def apply(self, passes): return Graph(ghandle) +def load_json(json_str): + """Create a new graph by loading from json + + Parameters + ---------- + json_str : str + The json string + + Returns + ------- + graph : Graph + The loaded graph + """ + ret = create(Variable("x")) + ret._set_json_attr("json", json_str) + return ret.apply("LoadJSON") + + def create(symbol): """Create a new graph from symbol. diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc index b97bf50b98a2..b8c6a1adc02c 100644 --- a/nnvm/src/compiler/graph_fuse.cc +++ b/nnvm/src/compiler/graph_fuse.cc @@ -15,46 +15,10 @@ #include #include #include "./compile_engine.h" -#include "../../tvm/src/runtime/graph/graph_runtime.h" +#include "./graph_runtime.h" namespace nnvm { namespace compiler { - - -struct TVMOpParam : public dmlc::Parameter { - std::string func_name; - uint32_t num_inputs; - uint32_t num_outputs; - uint32_t flatten_data; - - DMLC_DECLARE_PARAMETER(TVMOpParam) { - DMLC_DECLARE_FIELD(func_name); - DMLC_DECLARE_FIELD(num_inputs).set_default(1); - DMLC_DECLARE_FIELD(num_outputs).set_default(1); - DMLC_DECLARE_FIELD(flatten_data).set_default(0); - } -}; - -DMLC_REGISTER_PARAMETER(TVMOpParam); - -// parser -inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) { - TVMOpParam param; - param.Init(attrs->dict); - attrs->parsed = std::move(param); -} - -NNVM_REGISTER_OP(tvm_op) -.set_attr_parser(TVMOpParamParser) -.set_num_inputs([](const NodeAttrs& attrs) { - const TVMOpParam& param = nnvm::get(attrs.parsed); - return param.num_inputs; - }) -.set_num_outputs([](const NodeAttrs& attrs) { - const TVMOpParam& param = nnvm::get(attrs.parsed); - return param.num_outputs; - }); - using namespace tvm; // The single fuse rule. diff --git a/nnvm/src/compiler/graph_runtime.cc b/nnvm/src/compiler/graph_runtime.cc new file mode 100644 index 000000000000..902ccad25914 --- /dev/null +++ b/nnvm/src/compiler/graph_runtime.cc @@ -0,0 +1,182 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file graph_runtime.cc + * \brief Interface code with TVM graph runtime. 
+*/
+#include
+#include
+#include
+#include
+#include "./graph_runtime.h"
+
+namespace nnvm {
+namespace compiler {
+
+using tvm::runtime::TVMArgs;
+using tvm::runtime::TVMRetValue;
+using tvm::runtime::PackedFunc;
+using tvm::runtime::kTVMNDArrayMagic;
+using tvm::runtime::kTVMNDArrayListMagic;
+
+DMLC_REGISTER_PARAMETER(TVMOpParam);
+
+// parser
+inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) {
+  TVMOpParam param;
+  param.Init(attrs->dict);
+  attrs->parsed = std::move(param);
+}
+
+NNVM_REGISTER_OP(tvm_op)
+.set_attr_parser(TVMOpParamParser)
+.set_num_inputs([](const NodeAttrs& attrs) {
+    const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
+    return param.num_inputs;
+  })
+.set_num_outputs([](const NodeAttrs& attrs) {
+    const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
+    return param.num_outputs;
+  });
+
+bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
+  uint64_t header = kTVMNDArrayMagic, reserved = 0;
+  strm->Write(&header, sizeof(header));
+  strm->Write(&reserved, sizeof(reserved));
+
+  strm->Write(&tensor->ctx, sizeof(tensor->ctx));
+  strm->Write(&tensor->ndim, sizeof(tensor->ndim));
+  strm->Write(&tensor->dtype, sizeof(tensor->dtype));
+
+  int ndim = tensor->ndim;
+  strm->Write(tensor->shape, sizeof(int64_t) * ndim);
+
+  int type_size = tensor->dtype.bits / 8;
+  int64_t size = 1;
+  for (int i = 0; i < ndim; ++i) {
+    size *= tensor->shape[i];
+  }
+  int64_t data_byte_size = type_size * size;
+  strm->Write(&data_byte_size, sizeof(data_byte_size));
+  strm->Write(tensor->data, data_byte_size);
+  return true;
+}
+
+DLTensor* LoadDLTensor(dmlc::Stream* strm) {
+  uint64_t header, reserved;
+  CHECK(strm->Read(&header, sizeof(header)))
+      << "Invalid DLTensor file format";
+  CHECK(strm->Read(&reserved, sizeof(reserved)))
+      << "Invalid DLTensor file format";
+  CHECK(header == kTVMNDArrayMagic)
+      << "Invalid DLTensor file format";
+
+  DLTensor tensor;
+  CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx)))
+      << "Invalid DLTensor file format";
+  CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim)))
+      << "Invalid DLTensor file format";
+  CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
+      << "Invalid DLTensor file format";
+  std::vector<int64_t> shape(tensor.ndim);
+  CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
+      << "Invalid DLTensor file format";
+  DLTensor* ret;
+  CHECK_EQ(TVMArrayAlloc(shape.data(),
+                         tensor.ndim,
+                         tensor.dtype.code,
+                         tensor.dtype.bits,
+                         tensor.dtype.lanes,
+                         static_cast<int>(tensor.ctx.device_type),
+                         tensor.ctx.device_id,
+                         &ret), 0) << TVMGetLastError();
+  int64_t size = 1;
+  int type_size = ret->dtype.bits / 8;
+  for (int i = 0; i < ret->ndim; ++i) {
+    size *= ret->shape[i];
+  }
+  int64_t data_byte_size;
+  CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
+      << "Invalid DLTensor file format";
+  CHECK(data_byte_size == type_size * size)
+      << "Invalid DLTensor file format";
+  CHECK(strm->Read(ret->data, type_size * size))
+      << "Invalid DLTensor file format";
+  return ret;
+}
+
+TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    CHECK_EQ(args.size() % 2, 0u);
+    size_t num_params = args.size() / 2;
+    std::vector<std::string> names;
+    names.reserve(num_params);
+    std::vector<DLTensor*> arrays;
+    arrays.reserve(num_params);
+    for (size_t i = 0; i < num_params * 2; i += 2) {
+      names.emplace_back(args[i].operator std::string());
+      arrays.emplace_back(args[i + 1].operator DLTensor*());
+    }
+    std::string bytes;
+    dmlc::MemoryStringStream strm(&bytes);
+    dmlc::Stream* fo = &strm;
+    uint64_t header = kTVMNDArrayListMagic, reserved = 0;
+    fo->Write(&header, sizeof(header));
+    fo->Write(&reserved, sizeof(reserved));
+    fo->Write(names);
+    {
+      uint64_t sz = static_cast<uint64_t>(arrays.size());
+      fo->Write(&sz, sizeof(sz));
+      for (size_t i = 0; i < sz; ++i) {
+        SaveDLTensor(fo, arrays[i]);
+      }
+    }
+    TVMByteArray arr;
+    arr.data = bytes.c_str();
+    arr.size = bytes.length();
+    *rv = arr;
+  });
+
+
+TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    std::string bytes = args[0];
+    std::vector<DLTensor*> data;
+    std::vector<std::string> names;
+    dmlc::MemoryStringStream memstrm(&bytes);
+    dmlc::Stream* strm = &memstrm;
+
+    uint64_t header, reserved;
+    CHECK(strm->Read(&header))
+        << "Invalid parameters file format";
+    CHECK(header == kTVMNDArrayListMagic)
+        << "Invalid parameters file format";
+    CHECK(strm->Read(&reserved))
+        << "Invalid parameters file format";
+
+    CHECK(strm->Read(&names))
+        << "Invalid parameters file format";
+    uint64_t sz;
+    strm->Read(&sz, sizeof(sz));
+    size_t size = static_cast<size_t>(sz);
+    CHECK(size == names.size())
+        << "Invalid parameters file format";
+    for (size_t i = 0; i < size; ++i) {
+      data.push_back(LoadDLTensor(strm));
+    }
+    auto packed = [data, names](TVMArgs args, TVMRetValue* rv) {
+      int code = args[0];
+      if (code == 0) {
+        *rv = static_cast<int64_t>(data.size());
+      } else if (code == 1) {
+        int index = args[1];
+        *rv = names[index];
+      } else {
+        CHECK_EQ(code, 2);
+        int index = args[1];
+        *rv = static_cast<DLTensor*>(data[index]);
+      }
+    };
+    *rv = PackedFunc(packed);
+  });
+} // namespace compiler
+} // namespace nnvm
diff --git a/nnvm/src/compiler/graph_runtime.h b/nnvm/src/compiler/graph_runtime.h
new file mode 100644
index 000000000000..085e5bbf062f
--- /dev/null
+++ b/nnvm/src/compiler/graph_runtime.h
@@ -0,0 +1,32 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file graph_runtime.h
+ * \brief Interface code with TVM graph runtime.
+*/
+#ifndef NNVM_COMPILER_GRAPH_RUNTIME_H_
+#define NNVM_COMPILER_GRAPH_RUNTIME_H_
+
+#include
+#include
+#include "../../tvm/src/runtime/graph/graph_runtime.h"
+
+namespace nnvm {
+namespace compiler {
+
+struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
+  std::string func_name;
+  uint32_t num_inputs;
+  uint32_t num_outputs;
+  uint32_t flatten_data;
+
+  DMLC_DECLARE_PARAMETER(TVMOpParam) {
+    DMLC_DECLARE_FIELD(func_name);
+    DMLC_DECLARE_FIELD(num_inputs).set_default(1);
+    DMLC_DECLARE_FIELD(num_outputs).set_default(1);
+    DMLC_DECLARE_FIELD(flatten_data).set_default(0);
+  }
+};
+
+} // namespace compiler
+} // namespace nnvm
+#endif // NNVM_COMPILER_GRAPH_RUNTIME_H_
diff --git a/nnvm/tests/python/compiler/test_build.py b/nnvm/tests/python/compiler/test_build.py
index 936631185032..f4e57954f53f 100644
--- a/nnvm/tests/python/compiler/test_build.py
+++ b/nnvm/tests/python/compiler/test_build.py
@@ -70,7 +70,8 @@ def test_precompute_prune():
     m = graph_runtime.create(graph, lib, tvm.cpu(0))
     params["y"] = ny
     res = tvm.nd.empty(shape)
-    m.run(**params)
+    m["load_params"](nnvm.compiler.save_param_dict(params))
+    m.run()
     out = m.get_output(0, out=res)
     np.testing.assert_allclose(
         res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy() + na.asnumpy())
diff --git a/nnvm/tests/python/compiler/test_param_dict.py b/nnvm/tests/python/compiler/test_param_dict.py
new file mode 100644
index 000000000000..667ae4e48cec
--- /dev/null
+++ b/nnvm/tests/python/compiler/test_param_dict.py
@@ -0,0 +1,19 @@
+import numpy as np
+import nnvm.compiler
+
+def test_save_load():
+    x = np.random.uniform(size=(10, 2)).astype("float32")
+    y = np.random.uniform(size=(1, 2, 3)).astype("float32")
+    x[:] = 1
+    y[:] = 1
+    params = {"x": x, "y": y}
+    param_bytes = nnvm.compiler.save_param_dict(params)
+    assert isinstance(param_bytes, bytearray)
+    param2 = nnvm.compiler.load_param_dict(param_bytes)
+    assert len(param2) == 2
+    np.testing.assert_equal(param2["x"].asnumpy(), x)
+    np.testing.assert_equal(param2["y"].asnumpy(), y)
+
+
+if __name__ == "__main__":
+    test_save_load()
diff --git a/nnvm/tests/python/unittest/test_graph.py b/nnvm/tests/python/unittest/test_graph.py
index f41d62538817..60c842f36a6d 100644
--- a/nnvm/tests/python/unittest/test_graph.py
+++ b/nnvm/tests/python/unittest/test_graph.py
@@ -10,6 +10,9 @@ def test_json_pass():
     ret._set_json_attr('json', ret.json_attr('json'))
     g2 = ret.apply('LoadJSON')
     assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json')
+    json = g.json()
+    g2 = graph.load_json(json)
+    assert json == g2.json()


 def test_json_pass_with_attr():
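
Taken together, the pieces added by this patch (`save_param_dict`/`load_param_dict`, `Graph.json`/`graph.load_json`, and the `add_pass` option of `build_config`) support a compile-once, deploy-later flow. The sketch below shows one way they could be combined; the one-layer network, the `fc_weight` parameter name, the file names, shapes and the `llvm` target are illustrative assumptions, not part of the patch itself.

```python
import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.graph
import nnvm.compiler
from tvm.contrib import graph_runtime

# Illustrative one-layer network; "fc_weight" follows the usual
# "<op name>_weight" naming for auto-created parameter variables.
data = sym.Variable("data")
net = sym.dense(data, units=4, use_bias=False, name="fc")
shape = {"data": (1, 8)}
params = {"fc_weight": tvm.nd.array(
    np.random.uniform(size=(4, 8)).astype("float32"))}

# The new add_pass option turns a pass on regardless of opt_level:
# here SimplifyInference runs even though its level is now 0.
with nnvm.compiler.build_config(opt_level=0, add_pass={"SimplifyInference"}):
    graph, lib, params = nnvm.compiler.build(
        net, "llvm", shape=shape, params=params)

# Persist the three compiled artifacts; the parameters go through the
# new save_param_dict, the graph through the new json() method.
lib.export_library("deploy.so")
with open("deploy.json", "w") as fo:
    fo.write(graph.json())
with open("deploy.params", "wb") as fo:
    fo.write(nnvm.compiler.save_param_dict(params))

# Later, possibly in a different process: reload everything and run.
loaded_graph = nnvm.graph.load_json(open("deploy.json").read())
loaded_params = bytearray(open("deploy.params", "rb").read())
module = graph_runtime.create(loaded_graph, tvm.module.load("deploy.so"),
                              tvm.cpu(0))
module["load_params"](loaded_params)
module.set_input("data", np.ones((1, 8), dtype="float32"))
module.run()
out = module.get_output(0, tvm.nd.empty((1, 4)))
```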