diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 16419e737568..168d3c67b499 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -256,14 +256,6 @@ def _from_mxnet_impl(symbol, graph): nnvm.sym.Symbol Converted symbol """ - try: - from mxnet import sym as mx_sym # pylint: disable=import-self - except ImportError as e: - raise ImportError('{}. MXNet is required to parse symbols.'.format(e)) - - if not isinstance(symbol, mx_sym.Symbol): - raise ValueError("Provided {}, while MXNet symbol is expected", type(symbol)) - if _is_mxnet_group_symbol(symbol): return [_from_mxnet_impl(s, graph) for s in symbol] @@ -294,7 +286,7 @@ def from_mxnet(symbol, arg_params=None, aux_params=None): Parameters ---------- - symbol : mxnet.Symbol + symbol : mxnet.Symbol or mxnet.gluon.HybridBlock MXNet symbol arg_params : dict of str to mx.NDArray @@ -305,18 +297,36 @@ def from_mxnet(symbol, arg_params=None, aux_params=None): Returns ------- - net: nnvm.Symbol + sym : nnvm.Symbol Compatible nnvm symbol params : dict of str to tvm.NDArray The parameter dict to be used by nnvm """ - sym = _from_mxnet_impl(symbol, {}) - params = {} - arg_params = arg_params if arg_params else {} - aux_params = aux_params if aux_params else {} - for k, v in arg_params.items(): - params[k] = tvm.nd.array(v.asnumpy()) - for k, v in aux_params.items(): - params[k] = tvm.nd.array(v.asnumpy()) + try: + import mxnet as mx # pylint: disable=import-self + except ImportError as e: + raise ImportError('{}. MXNet is required to parse symbols.'.format(e)) + + if isinstance(symbol, mx.sym.Symbol): + sym = _from_mxnet_impl(symbol, {}) + params = {} + arg_params = arg_params if arg_params else {} + aux_params = aux_params if aux_params else {} + for k, v in arg_params.items(): + params[k] = tvm.nd.array(v.asnumpy()) + for k, v in aux_params.items(): + params[k] = tvm.nd.array(v.asnumpy()) + elif isinstance(symbol, mx.gluon.HybridBlock): + data = mx.sym.Variable('data') + sym = symbol(data) + sym = _from_mxnet_impl(sym, {}) + params = {} + for k, v in symbol.collect_params().items(): + params[k] = tvm.nd.array(v.data().asnumpy()) + elif isinstance(symbol, mx.gluon.Block): + raise NotImplementedError("The dynamic Block is not supported yet.") + else: + msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol)) + raise ValueError(msg) return sym, params diff --git a/nnvm/python/nnvm/testing/resnet.py b/nnvm/python/nnvm/testing/resnet.py index 0e9c81232138..76b5c1d893b8 100644 --- a/nnvm/python/nnvm/testing/resnet.py +++ b/nnvm/python/nnvm/testing/resnet.py @@ -23,6 +23,7 @@ Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks" ''' +# pylint: disable=unused-argument import numpy as np from .. import symbol as sym from . utils import create_workload diff --git a/nnvm/tutorials/from_mxnet.py b/nnvm/tutorials/from_mxnet.py new file mode 100644 index 000000000000..0b6b70ef2d09 --- /dev/null +++ b/nnvm/tutorials/from_mxnet.py @@ -0,0 +1,114 @@ +""" +Compiling MXNet Models with NNVM +================================ +**Author**: `Joshua Z. Zhang `_ + +This article is an introductory tutorial to deploy mxnet models with NNVM. + +For us to begin with, mxnet module is required to be installed. + +A quick solution is +``` +pip install mxnet --user +``` +or please refer to offical installation guide. +https://mxnet.incubator.apache.org/versions/master/install/index.html +""" +# some standard imports +import mxnet as mx +import nnvm +import tvm +import numpy as np + +###################################################################### +# Download Resnet18 model from Gluon Model Zoo +# --------------------------------------------- +# In this section, we download a pretrained imagenet model and classify an image. +from mxnet.gluon.model_zoo.vision import get_model +from mxnet.gluon.utils import download +import Image +from matplotlib import pyplot as plt +block = get_model('resnet18_v1', pretrained=True) +img_name = 'cat.jpg' +synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) +synset_name = 'synset.txt' +download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name) +download(synset_url, synset_name) +with open(synset_name) as f: + synset = eval(f.read()) +image = Image.open(img_name).resize((224, 224)) +plt.imshow(image) +plt.show() + +def transform_image(image): + image = np.array(image) - np.array([123., 117., 104.]) + image /= np.array([58.395, 57.12, 57.375]) + image = image.transpose((2, 0, 1)) + image = image[np.newaxis, :] + return image + +x = transform_image(image) +print('x', x.shape) + +###################################################################### +# Compile the Graph +# ----------------- +# Now we would like to port the Gluon model to a portable computational graph. +# It's as easy as several lines. +# We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon +sym, params = nnvm.frontend.from_mxnet(block) +# we want a probability so add a softmax operator +sym = nnvm.sym.softmax(sym) + +###################################################################### +# now compile the graph +import nnvm.compiler +target = 'cuda' +shape_dict = {'data': x.shape} +graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params) + +###################################################################### +# Execute the portable graph on TVM +# --------------------------------- +# Now, we would like to reproduce the same forward computation using TVM. +from tvm.contrib import graph_runtime +ctx = tvm.gpu(0) +dtype = 'float32' +m = graph_runtime.create(graph, lib, ctx) +# set inputs +m.set_input('data', tvm.nd.array(x.astype(dtype))) +m.set_input(**params) +# execute +m.run() +# get outputs +tvm_output = m.get_output(0, tvm.nd.empty((1000,), dtype)) +top1 = np.argmax(tvm_output) +print('TVM prediction top-1:', top1, synset[top1]) + +###################################################################### +# Use MXNet symbol with pretrained weights +# ---------------------------------------- +# MXNet often use `arg_prams` and `aux_params` to store network parameters +# separately, here we show how to use these weights with existing API +def block2symbol(block): + data = mx.sym.Variable('data') + sym = block(data) + args = {} + auxs = {} + for k, v in block.collect_params().items(): + args[k] = mx.nd.array(v.data().asnumpy()) + return sym, args, auxs +mx_sym, args, auxs = block2symbol(block) +# usually we would save/load it as checkpoint +mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs) +# there are 'resnet18_v1-0000.params' and 'resnet18_v1-symbol.json' on disk + +###################################################################### +# for a normal mxnet model, we start from here +mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0) +# now we use the same API to get NNVM compatible symbol +nnvm_sym, nnvm_params = nnvm.frontend.from_mxnet(mx_sym, args, auxs) +# repeat the same steps to run this model using TVM