From f7574e82c6bee463e7329051472d9489c39fe4f1 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sun, 20 Sep 2015 14:49:05 -0600 Subject: [PATCH 1/2] add composite symbol --- example/imagenet/alexnet.ipynb | 145 +++- example/python-howto/composite_symbol.ipynb | 766 ++++++++++++++++++++ python/mxnet/ndarray.py | 1 - python/mxnet/visualization.py | 2 +- 4 files changed, 905 insertions(+), 9 deletions(-) create mode 100644 example/python-howto/composite_symbol.ipynb diff --git a/example/imagenet/alexnet.ipynb b/example/imagenet/alexnet.ipynb index 1e9e399d1b5f..b7bb6bf266c2 100644 --- a/example/imagenet/alexnet.ipynb +++ b/example/imagenet/alexnet.ipynb @@ -11,6 +11,8 @@ "\n", "We will show how to train AlexNet in Python with single/multi GPU. All you need is to write a piece of Python code to describe network, then MXNet will help you finish all work without any of your effort. \n", "\n", + "Notice: This notebook is a basic demo to show MXNet flavor. To train a full state-of-art network, please refer our ```Inception``` example.\n", + "\n", "Generally, we need \n", "\n", "- Declare symbol network\n", @@ -27,6 +29,7 @@ }, "outputs": [], "source": [ + "%matplotlib inline\n", "import mxnet as mx" ] }, @@ -99,7 +102,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now we have a AlexNet. The ```softmax``` symbol contains all network structures. We can visualize our network structure. (require ```graphviz``` package)" + "Now we have a AlexNet in symbolic level. The ```softmax``` symbol contains all network structures. By indicate ```data``` for each symbol, the last symbol composite all info we need. We can visualize our network structure. (require ```graphviz``` package)" ] }, { @@ -115,7 +118,7 @@ "\n", "\n", - "\n", "\n", "\n", "\n", "null_0\n", - "\n", + "\n", "data\n", "\n", "\n", @@ -390,7 +393,7 @@ "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 4, @@ -399,7 +402,7 @@ } ], "source": [ - "mx.visualization.network2dot(\"AlexNet\", softmax)" + "mx.visualization.plot_network(\"AlexNet\", softmax)" ] }, { @@ -408,7 +411,135 @@ "collapsed": true }, "source": [ - "After define our network, we are able to create our model." + "The next step is declare data iterator. We provide high perfomance RecordIO image iterator for ImageNet task. Please pack the images into record file before use. For how to pack image and more details about image data iterator and build-in io iterator, please read [io doc](https://github.com/dmlc/mxnet/blob/master/doc/python/io.md)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# We set batch size for to 256\n", + "batch_size = 256\n", + "# We need to set correct path to image record file\n", + "# For ```mean_image```. if it doesn't exist, the iterator will generate one. Usually on normal HDD, it costs less than 10 minutes\n", + "# the input shape is in format (channel, height, width)\n", + "# rand_crop option make source image randomly cropped to input_shape (3, 224, 224)\n", + "# rand_mirror option make source image randomly mirrored\n", + "# We use 2 threads to processing our data\n", + "train_dataiter = mx.io.ImageRecordIter(\n", + " path_imgrec=\"./Data/ImageNet/train.rec\",\n", + " mean_img=\"./Data/ImageNet/mean_224.bin\",\n", + " rand_crop=True,\n", + " rand_mirror=True,\n", + " input_shape=(3, 224, 224),\n", + " batch_size=batch_size,\n", + " nthread=2)\n", + "# similarly, we can declare our validation iterator\n", + "val_dataiter = mx.io.ImageRecordIter(\n", + " path_imgrec=\"./Data/ImageNet/val.rec\",\n", + " mean_img=\"./Data/ImageNet/mean_224.bin\",\n", + " rand_crop=False,\n", + " rand_mirror=False,\n", + " input_shape=(3, 224, 224),\n", + " batch_size=batch_size,\n", + " nthread=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next step, we will initialize our model from symbol. To run on a single GPU, we need to declare:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# For demo purpose, we just run 1 epoch\n", + "num_round = 1\n", + "# set context to GPU, if you want to use cpu, set it to mx.cpu()\n", + "ctx = mx.gpu() \n", + "# note: for input shape in model, we must contain batch size\n", + "data_shape = (batch_size, 3, 224, 224)\n", + "\n", + "model = mx.model.FeedForward(symbol=softmax, ctx=ctx, input_shape=data_shape, num_round=num_round,\n", + " learning_rate=0.01, momentum=0.9, wd=0.0001)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To run on multiply GPU, we need to declare" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# For demo purpose, we just run 1 epoch\n", + "num_round = 1\n", + "# Assume we have 4 GPU, we can make a context list contains 4 device\n", + "num_devs = 4\n", + "ctx = [mx.gpu(i) for i in range(num_devs)]\n", + "# note: for input shape in model, we must contain batch size\n", + "data_shape = (batch_size, 3, 224, 224)\n", + "\n", + "model = mx.model.FeedForward(symbol=softmax, ctx=ctx, input_shape=data_shape, num_round=num_round,\n", + " learning_rate=0.01, momentum=0.9, wd=0.0001)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "ename": "MXNetError", + "evalue": "[12:00:28] src/ndarray/ndarray.cc:157: Check failed: from.shape() == to->shape() operands shape mismatch", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mMXNetError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;31m# In this case, eval_data is also a data iterator\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;31m# We will use accuracy to measure our model's performace\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrain_dataiter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meval_data\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mval_dataiter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meval_metric\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'acc'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/model.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, y, eval_data, eval_metric, verbose)\u001b[0m\n\u001b[0;32m 304\u001b[0m \u001b[0mtrain_data\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meval_data\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0meval_data\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 305\u001b[0m \u001b[0meval_metric\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0meval_metric\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 306\u001b[1;33m verbose=verbose)\n\u001b[0m", + "\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/model.py\u001b[0m in \u001b[0;36m_train\u001b[1;34m(symbol, ctx, input_shape, arg_params, aux_params, begin_round, end_round, optimizer, train_data, eval_data, eval_metric, iter_end_callback, verbose)\u001b[0m\n\u001b[0;32m 85\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marg_names\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0marg_arrays\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 86\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mkey\u001b[0m \u001b[1;32min\u001b[0m \u001b[0marg_params\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 87\u001b[1;33m \u001b[0marg_params\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcopyto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 88\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maux_names\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maux_arrays\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 89\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mkey\u001b[0m \u001b[1;32min\u001b[0m \u001b[0maux_params\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/ndarray.py\u001b[0m in \u001b[0;36mcopyto\u001b[1;34m(self, other)\u001b[0m\n\u001b[0;32m 306\u001b[0m RuntimeWarning)\n\u001b[0;32m 307\u001b[0m \u001b[1;32mreturn\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 308\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mNDArray\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_copyto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mout\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mother\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 309\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mother\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mContext\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 310\u001b[0m \u001b[0mhret\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mNDArray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_new_alloc_handle\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mother\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/ndarray.py\u001b[0m in \u001b[0;36mgeneric_ndarray_function\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 618\u001b[0m \u001b[0mc_array\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mNDArrayHandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0muse_vars_range\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 619\u001b[0m \u001b[0mc_array\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmx_float\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mscalar_range\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 620\u001b[1;33m c_array(NDArrayHandle, [v.handle for v in mutate_vars])))\n\u001b[0m\u001b[0;32m 621\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mn_mutate_vars\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 622\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mmutate_vars\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/base.py\u001b[0m in \u001b[0;36mcheck_call\u001b[1;34m(ret)\u001b[0m\n\u001b[0;32m 95\u001b[0m \"\"\"\n\u001b[0;32m 96\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mret\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 97\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mMXNetError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpy_str\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_LIB\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMXGetLastError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 98\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mc_str\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstring\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mMXNetError\u001b[0m: [12:00:28] src/ndarray/ndarray.cc:157: Check failed: from.shape() == to->shape() operands shape mismatch" + ] + } + ], + "source": [ + "# Now we can fit the model with data iterators\n", + "# When we use data iterator, we don't need to set y because label comes from data iterator directly\n", + "# In this case, eval_data is also a data iterator\n", + "# We will use accuracy to measure our model's performace\n", + "model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc', verbose=True)\n", + "# You need to wait for a while to get the result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's all!" ] }, { @@ -437,7 +568,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.4.0" + "version": "3.4.2" } }, "nbformat": 4, diff --git a/example/python-howto/composite_symbol.ipynb b/example/python-howto/composite_symbol.ipynb new file mode 100644 index 000000000000..dc97fa22e5dc --- /dev/null +++ b/example/python-howto/composite_symbol.ipynb @@ -0,0 +1,766 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Composite symbols into component\n", + "In this example we will show how to make an Inception network by forming single symbol into component.\n", + "\n", + "Inception is currently best model. Compared to other models, it has much less parameters, and with best performance. However, it is much more complex than sequence feedforward network.\n", + "\n", + "The Inception network in this example is refer to ```Ioffe, Sergey, and Christian Szegedy. \"Batch normalization: Accelerating deep network training by reducing internal covariate shift.\" arXiv preprint arXiv:1502.03167 (2015).```\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import mxnet as mx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For complex network such as inception network, building from single symbol is painful, we can make simple ```ComponentFactory``` to simplfiy the procedure.\n", + "\n", + "Except difference in number of filter, we find 2 major differences in each Inception module, so we can build two factories plus one basic ```Convolution + BatchNorm + ReLU``` factory to simplfiy the problem.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Basic Conv + BN + ReLU factory\n", + "def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), act_type=\"relu\"):\n", + " conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)\n", + " bn = mx.symbol.BatchNorm(data=conv)\n", + " act = mx.symbol.Activation(data = bn, act_type=act_type)\n", + " return act\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can visualize our basic component" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "conv\n", + "\n", + "\n", + "null_0\n", + "\n", + "Previos Output\n", + "\n", + "\n", + "Convolution_3\n", + "\n", + "Convolution\n", + "7x7/2, 64\n", + "\n", + "\n", + "Convolution_3->null_0\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_6\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_6->Convolution_3\n", + "\n", + "\n", + "\n", + "\n", + "Activation_7\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_7->BatchNorm_6\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prev = mx.symbol.Variable(name=\"Previos Output\")\n", + "conv_comp = ConvFactory(data=prev, num_filter=64, kernel=(7,7), stride=(2, 2))\n", + "mx.visualization.plot_network(title=\"conv\", symbol=conv_comp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The next step is making a component factory with all ```stride=(1, 1)```" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# param mapping to paper:\n", + "# num_1x1 >>> #1x1\n", + "# num_3x3red >>> #3x3 reduce\n", + "# num_3x3 >>> #3x3\n", + "# num_d3x3red >>> double #3x3 reduce\n", + "# num_d3x3 >>> double #3x3\n", + "# pool >>> Pool\n", + "# proj >>> proj\n", + "def InceptionFactoryA(data, num_1x1, num_3x3red, num_3x3, num_d3x3red, num_d3x3, pool, proj):\n", + " # 1x1\n", + " c1x1 = ConvFactory(data=data, num_filter=num_1x1, kernel=(1, 1))\n", + " # 3x3 reduce + 3x3\n", + " c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1))\n", + " c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1))\n", + " # double 3x3 reduce + double 3x3\n", + " cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1))\n", + " cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1))\n", + " # pool + proj\n", + " pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool)\n", + " cproj = ConvFactory(data=pooling, num_filter=proj, kernel=(1, 1))\n", + " # concat\n", + " concat = mx.symbol.Concat(*[c1x1, c3x3, cd3x3, cproj])\n", + " return concat" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "in3a\n", + "\n", + "\n", + "null_0\n", + "\n", + "Previos Output\n", + "\n", + "\n", + "Convolution_3\n", + "\n", + "Convolution\n", + "1x1/1, 64\n", + "\n", + "\n", + "Convolution_3->null_0\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_6\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_6->Convolution_3\n", + "\n", + "\n", + "\n", + "\n", + "Activation_7\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_7->BatchNorm_6\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_10\n", + "\n", + "Convolution\n", + "1x1/1, 64\n", + "\n", + "\n", + "Convolution_10->null_0\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_13\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_13->Convolution_10\n", + "\n", + "\n", + "\n", + "\n", + "Activation_14\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_14->BatchNorm_13\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_17\n", + "\n", + "Convolution\n", + "3x3/1, 64\n", + "\n", + "\n", + "Convolution_17->Activation_14\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_20\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_20->Convolution_17\n", + "\n", + "\n", + "\n", + "\n", + "Activation_21\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_21->BatchNorm_20\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_24\n", + "\n", + "Convolution\n", + "1x1/1, 64\n", + "\n", + "\n", + "Convolution_24->null_0\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_27\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_27->Convolution_24\n", + "\n", + "\n", + "\n", + "\n", + "Activation_28\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_28->BatchNorm_27\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_31\n", + "\n", + "Convolution\n", + "3x3/1, 96\n", + "\n", + "\n", + "Convolution_31->Activation_28\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_34\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_34->Convolution_31\n", + "\n", + "\n", + "\n", + "\n", + "Activation_35\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_35->BatchNorm_34\n", + "\n", + "\n", + "\n", + "\n", + "Pooling_36\n", + "\n", + "Pooling\n", + "avg, 3x3/1\n", + "\n", + "\n", + "Pooling_36->null_0\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_39\n", + "\n", + "Convolution\n", + "1x1/1, 32\n", + "\n", + "\n", + "Convolution_39->Pooling_36\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_42\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_42->Convolution_39\n", + "\n", + "\n", + "\n", + "\n", + "Activation_43\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_43->BatchNorm_42\n", + "\n", + "\n", + "\n", + "\n", + "Concat_44\n", + "\n", + "Concat\n", + "\n", + "\n", + "Concat_44->Activation_7\n", + "\n", + "\n", + "\n", + "\n", + "Concat_44->Activation_21\n", + "\n", + "\n", + "\n", + "\n", + "Concat_44->Activation_35\n", + "\n", + "\n", + "\n", + "\n", + "Concat_44->Activation_43\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prev = mx.symbol.Variable(name=\"Previos Output\")\n", + "in3a = InceptionFactoryA(prev, 64, 64, 64, 64, 96, \"avg\", 32)\n", + "mx.visualization.plot_network(title=\"in3a\", symbol=in3a)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will make the other factory with ```strde=(2, 2)```" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# param mapping to paper:\n", + "# num_1x1 >>> #1x1 (not exist!)\n", + "# num_3x3red >>> #3x3 reduce\n", + "# num_3x3 >>> #3x3\n", + "# num_d3x3red >>> double #3x3 reduce\n", + "# num_d3x3 >>> double #3x3\n", + "# pool >>> Pool (not needed, all are max pooling)\n", + "# proj >>> proj (not exist!)\n", + "def InceptionFactoryB(data, num_3x3red, num_3x3, num_d3x3red, num_d3x3):\n", + " # 3x3 reduce + 3x3\n", + " c3x3r = ConvFactory(data=data, num_filter=num_3x3red, kernel=(1, 1))\n", + " c3x3 = ConvFactory(data=c3x3r, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2))\n", + " # double 3x3 reduce + double 3x3\n", + " cd3x3r = ConvFactory(data=data, num_filter=num_d3x3red, kernel=(1, 1))\n", + " cd3x3 = ConvFactory(data=cd3x3r, num_filter=num_d3x3, kernel=(3, 3), pad=(1, 1), stride=(2, 2))\n", + " # pool + proj\n", + " pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=\"max\")\n", + " # concat\n", + " concat = mx.symbol.Concat(*[c3x3, cd3x3, pooling])\n", + " return concat" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "in3c\n", + "\n", + "\n", + "null_0\n", + "\n", + "Previos Output\n", + "\n", + "\n", + "Convolution_3\n", + "\n", + "Convolution\n", + "1x1/1, 128\n", + "\n", + "\n", + "Convolution_3->null_0\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_6\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_6->Convolution_3\n", + "\n", + "\n", + "\n", + "\n", + "Activation_7\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_7->BatchNorm_6\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_10\n", + "\n", + "Convolution\n", + "3x3/2, 160\n", + "\n", + "\n", + "Convolution_10->Activation_7\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_13\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_13->Convolution_10\n", + "\n", + "\n", + "\n", + "\n", + "Activation_14\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_14->BatchNorm_13\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_17\n", + "\n", + "Convolution\n", + "1x1/1, 64\n", + "\n", + "\n", + "Convolution_17->null_0\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_20\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_20->Convolution_17\n", + "\n", + "\n", + "\n", + "\n", + "Activation_21\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_21->BatchNorm_20\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_24\n", + "\n", + "Convolution\n", + "3x3/2, 96\n", + "\n", + "\n", + "Convolution_24->Activation_21\n", + "\n", + "\n", + "\n", + "\n", + "BatchNorm_27\n", + "\n", + "BatchNorm\n", + "\n", + "\n", + "BatchNorm_27->Convolution_24\n", + "\n", + "\n", + "\n", + "\n", + "Activation_28\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_28->BatchNorm_27\n", + "\n", + "\n", + "\n", + "\n", + "Pooling_29\n", + "\n", + "Pooling\n", + "max, 3x3/2\n", + "\n", + "\n", + "Pooling_29->null_0\n", + "\n", + "\n", + "\n", + "\n", + "Concat_30\n", + "\n", + "Concat\n", + "\n", + "\n", + "Concat_30->Activation_14\n", + "\n", + "\n", + "\n", + "\n", + "Concat_30->Activation_28\n", + "\n", + "\n", + "\n", + "\n", + "Concat_30->Pooling_29\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prev = mx.symbol.Variable(name=\"Previos Output\")\n", + "in3c = InceptionFactoryB(prev, 128, 160, 64, 96)\n", + "mx.visualization.plot_network(title=\"in3c\", symbol=in3c)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can use these factories to build the whole network" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# data\n", + "data = mx.symbol.Variable(name=\"data\")\n", + "# stage 1\n", + "conv1 = ConvFactory(data=data, num_filter=64, kernel=(7, 7), stride=(2, 2), pad=(3, 3))\n", + "pool1 = mx.symbol.Pooling(data=conv1, kernel=(3, 3), stride=(2, 2))\n", + "# stage 2\n", + "conv2red = ConvFactory(data=pool1, num_filter=64, kernel=(1, 1), stride=(1, 1))\n", + "conv2 = ConvFactory(data=conv2red, num_filter=192, kernel=(3, 3), stride=(1, 1), pad=(1, 1))\n", + "pool2 = mx.symbol.Pooling(data=conv2, kernel=(3, 3), stride=(2, 2))\n", + "# stage 3\n", + "in3a = InceptionFactoryA(pool2, 64, 64, 64, 64, 96, \"avg\", 32)\n", + "in3b = InceptionFactoryA(in3a, 64, 64, 96, 64, 96, \"avg\", 64)\n", + "in3c = InceptionFactoryB(in3b, 128, 160, 64, 96)\n", + "# stage 4\n", + "in4a = InceptionFactoryA(in3c, 224, 64, 96, 96, 128, \"avg\", 128)\n", + "in4b = InceptionFactoryA(in4a, 192, 96, 128, 96, 128, \"avg\", 128)\n", + "in4c = InceptionFactoryA(in4b, 160, 128, 160, 128, 160, \"avg\", 128)\n", + "in4d = InceptionFactoryA(in4c, 96, 128, 192, 160, 192, \"avg\", 128)\n", + "in4e = InceptionFactoryB(in4d, 128, 192, 192, 256)\n", + "# stage 5\n", + "in5a = InceptionFactoryA(in4e, 352, 192, 320, 160, 224, \"avg\", 128)\n", + "in5b = InceptionFactoryA(in5a, 352, 192, 320, 192, 224, \"max\", 128)\n", + "# global avg pooling\n", + "avg = mx.symbol.Pooling(data=in5b, kernel=(7, 7), stride=(1, 1))\n", + "# linear classifier\n", + "flatten = mx.symbol.Flatten(data=avg)\n", + "fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=1000)\n", + "softmax = mx.symbol.Softmax(data=fc1)\n", + "\n", + "# if you like, you can visualize full network structure\n", + "# mx.visualization.plot_network(title=\"inception\", symbol=softmax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.4.2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 921c00fccb35..4e99151afccc 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -36,7 +36,6 @@ def _new_alloc_handle(shape, ctx, delay_alloc): a new empty ndarray handle """ hdl = NDArrayHandle() - print ctx.device_typeid check_call(_LIB.MXNDArrayCreate( c_array(mx_uint, shape), len(shape), diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py index bccc4a2b3155..86fc53c37311 100644 --- a/python/mxnet/visualization.py +++ b/python/mxnet/visualization.py @@ -23,7 +23,7 @@ def _str2tuple(string): return re.findall(r"\d+", string) -def network2dot(title, symbol, shape=None): +def plot_network(title, symbol, shape=None): """convert symbol to dot object for visualization Parameters From 510ba1bacae115970b0bca3912ecb330e4a83855 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sun, 20 Sep 2015 22:19:32 -0600 Subject: [PATCH 2/2] [BUG FIX] model prediction and model in python3 --- python/mxnet/initializer.py | 9 +- python/mxnet/model.py | 9 +- src/common/cuda_utils.h | 3 +- src/engine/naive_engine.cc | 11 ++- src/engine/threaded_engine.cc | 2 +- src/engine/threaded_engine_perdevice.cc | 10 ++- src/resource.cc | 22 ++++- tests/python/train/test_conv.py | 108 ++++++------------------ 8 files changed, 80 insertions(+), 94 deletions(-) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 021586719dfa..bd64413ca295 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -30,11 +30,18 @@ def __call__(self, name, arr): self._init_beta(name, arr) elif name.endswith('weight'): self._init_weight(name, arr) + elif name.endswith("moving_mean"): + self._init_zero(name, arr) + elif name.endswith("moving_var"): + self._init_zero(name, arr) else: self._init_default(name, arr) + def _init_zero(self, name, arr): + arr[:] = 0.0 + def _init_bias(self, name, arr): - arr[:] = 0 + arr[:] = 0.0 def _init_gamma(self, name, arr): arr[:] = 1.0 diff --git a/python/mxnet/model.py b/python/mxnet/model.py index d757ce5ded27..2c4614b1dade 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -149,7 +149,7 @@ def _train(symbol, ctx, input_shape, data_array, label_array = arg_arrays[data_index], arg_arrays[label_index] out_array = train_exec.outputs[0] out_cpu_array = nd.zeros(out_array.shape) - arg_blocks = zip(arg_arrays, grad_arrays) + arg_blocks = list(zip(arg_arrays, grad_arrays)) for i in range(begin_round, end_round): # training phase @@ -184,7 +184,7 @@ def _train(symbol, ctx, input_shape, for data, label in eval_data: data.copyto(data_array) # TODO(bing): add is_train=False - train_exec.forward() + train_exec.forward(is_train=False) out_array.copyto(out_cpu_array) eval_metric.update(out_array, label) @@ -400,6 +400,9 @@ def _init_predictor(self, input_shape): if not self._is_data_arg(name): assert name in self.arg_params self.arg_params[name].copyto(value) + for name, value in list(zip(self.symbol.list_auxiliary_states(), pred_exec.aux_arrays)): + assert name in self.aux_params + self.aux_params[name].copyto(value) data_index, _ = _check_arguments(self.symbol) self._pred_exec_input = pred_exec.arg_arrays[data_index] self._pred_exec = pred_exec @@ -423,7 +426,7 @@ def predict(self, X): X.reset() for data, _ in X: data.copyto(self._pred_exec_input) - self._pred_exec.forward() + self._pred_exec.forward(is_train=False) outputs.append(self._pred_exec.outputs[0].asnumpy()) return np.concatenate(outputs) diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h index 51e67bfb0d04..bbc7961c2642 100644 --- a/src/common/cuda_utils.h +++ b/src/common/cuda_utils.h @@ -109,7 +109,8 @@ inline const char* CurandGetErrorString(curandStatus_t status) { #define CUDA_CALL(func) \ { \ cudaError_t e = (func); \ - CHECK_EQ(e, cudaSuccess) << "CUDA: " << cudaGetErrorString(e); \ + CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ + << "CUDA: " << cudaGetErrorString(e); \ } /*! diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index e9d558013c1b..2ad0e6772a0d 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -16,6 +16,7 @@ class NaiveEngine final : public Engine { // virtual destructor virtual ~NaiveEngine() { #if MXNET_USE_CUDA + LOG(INFO) << "Engine shutdown"; for (size_t i = 0; i < streams_.size(); ++i) { if (streams_[i] != nullptr) { // Catch exception for CUDA driver shutdown @@ -62,7 +63,14 @@ class NaiveEngine final : public Engine { if (exec_ctx.dev_mask() == gpu::kDevMask) { #if MXNET_USE_CUDA size_t dev_id = static_cast(exec_ctx.dev_id); - mshadow::SetDevice(exec_ctx.dev_id); + try { + mshadow::SetDevice(exec_ctx.dev_id); + } catch (const dmlc::Error &e) { + std::string what = e.what(); + if (what.find("driver shutting down") == std::string::npos) { + LOG(ERROR) << "Ignore Error " << what << " during worker finalization"; + } + } if (streams_.size() <= dev_id) { streams_.resize(dev_id + 1, nullptr); } @@ -71,7 +79,6 @@ class NaiveEngine final : public Engine { } ctx_.stream = streams_[dev_id]; exec_fun(ctx_, callback); - streams_[dev_id]->Wait(); #else LOG(FATAL) << "GPU is not enabled"; #endif diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 0889482535e9..b1fef5cf5091 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -85,7 +85,7 @@ template bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { // this is lock scope VersionedVarBlock *old_pending_write, *end_of_read_chain; - bool trigger_write = false, to_delete = false; + bool trigger_write = false; { std::lock_guard lock{m_}; assert(ready_to_read_ == false); diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 3bb72606c341..c58352def413 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -48,7 +48,14 @@ class ThreadedEnginePerDevice : public ThreadedEngine { if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) { if (ctx.dev_mask() == gpu::kDevMask) { #if MXNET_USE_CUDA - mshadow::SetDevice(ctx.dev_id); + try { + mshadow::SetDevice(ctx.dev_id); + } catch (const dmlc::Error &e) { + std::string what = e.what(); + if (what.find("driver shutting down") == std::string::npos) { + LOG(ERROR) << "Ignore Error " << what << " during worker finalization"; + } + } #endif } RunContext run_ctx; @@ -170,4 +177,3 @@ Engine *CreateThreadedEnginePerDevice() { } } // namespace engine } // namespace mxnet - diff --git a/src/resource.cc b/src/resource.cc index 2259c642d32e..27a06fd358b5 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -111,7 +111,16 @@ class ResourceManagerImpl : public ResourceManager { ~ResourceRandom() { mshadow::Random *r = prnd; Engine::Get()->DeleteVariable( - [r](RunContext rctx){ delete r; }, ctx, resource.var); + [r](RunContext rctx) { + try { + delete r; + } catch (const dmlc::Error &e) { + std::string what = e.what(); + if (what.find("driver shutting down") == std::string::npos) { + LOG(ERROR) << "Ignore Error " << what << " resource finalization"; + } + } + }, ctx, resource.var); } // set seed to a PRNG inline void Seed(uint32_t global_seed) { @@ -150,7 +159,16 @@ class ResourceManagerImpl : public ResourceManager { for (size_t i = 0; i < space.size(); ++i) { mshadow::TensorContainer* r = space[i]; Engine::Get()->DeleteVariable( - [r](RunContext rctx){ delete r; }, ctx, resource[i].var); + [r](RunContext rctx){ + try { + r->Release(); + } catch (const dmlc::Error &e) { + std::string what = e.what(); + if (what.find("driver shutting down") == std::string::npos) { + LOG(ERROR) << "Ignore Error " << what << " resource finalization"; + } + } + }, ctx, resource[i].var); } } // get next resource in round roubin matter diff --git a/tests/python/train/test_conv.py b/tests/python/train/test_conv.py index e411c9e9c8f6..df2b7b98afb2 100644 --- a/tests/python/train/test_conv.py +++ b/tests/python/train/test_conv.py @@ -1,15 +1,15 @@ # pylint: skip-file +import sys +sys.path.insert(0, '../../python') import mxnet as mx import numpy as np import os, pickle, gzip +import logging from common import get_data -def CalAcc(out, label): - pred = np.argmax(out, axis=1) - return np.sum(pred == label) * 1.0 / out.shape[0] - # symbol net batch_size = 100 + data = mx.symbol.Variable('data') conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2)) bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1") @@ -25,48 +25,12 @@ def CalAcc(out, label): fl = mx.symbol.Flatten(data = mp2, name="flatten") fc2 = mx.symbol.FullyConnected(data = fl, name='fc2', num_hidden=10) softmax = mx.symbol.Softmax(data = fc2, name = 'sm') -args_list = softmax.list_arguments() -# infer shape -#data_shape = (batch_size, 784) - -data_shape = (batch_size, 1, 28, 28) -arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape) -arg_narrays = [mx.nd.empty(shape) for shape in arg_shapes] -grad_narrays = [mx.nd.empty(shape) for shape in arg_shapes] -aux_narrays = [mx.nd.empty(shape) for shape in aux_shapes] - -inputs = dict(zip(args_list, arg_narrays)) -np.random.seed(0) -# set random weight -for name, narray in inputs.items(): - if "weight" in name: - narray[:] = np.random.uniform(-0.07, 0.07, narray.shape) - if "bias" in name: - narray[:] = 0.0 - if "gamma" in name: - narray[:] = 1.0 - if "beta" in name: - narray[:] = 0.0 - -# bind executer -# TODO(bing): think of a better bind interface - -executor = softmax.bind(mx.cpu(), arg_narrays, grad_narrays, 'write', aux_narrays) -# update -print(executor.debug_str()) -out_narray = executor.outputs[0] -grad_narray = mx.nd.empty(out_narray.shape) - -epoch = 1 -momentum = 0.9 -lr = 0.1 -wd = 0.0004 - -def Update(grad, weight): - weight[:] -= lr * grad / batch_size - -block = list(zip(grad_narrays, arg_narrays)) +num_round = 1 +model = mx.model.FeedForward(softmax, mx.cpu(), + num_round=num_round, + learning_rate=0.1, wd=0.0001, + momentum=0.9) # check data get_data.GetMNIST_ubyte() @@ -82,43 +46,23 @@ def Update(grad, weight): batch_size=batch_size, shuffle=True, flat=False, silent=False) def test_mnist(): - acc_train = 0.0 - acc_val = 0.0 - for i in range(epoch): - # train - print("Epoch %d" % i) - train_acc = 0.0 - val_acc = 0.0 - train_nbatch = 0 - val_nbatch = 0 - for data, label in train_dataiter: - label = label.asnumpy().flatten() - inputs["data"][:] = data - inputs["sm_label"][:] = label - executor.forward(is_train = True) - train_acc += CalAcc(out_narray.asnumpy(), label) - train_nbatch += 1 - grad_narray[:] = out_narray - executor.backward([grad_narray]) - - for grad, weight in block: - Update(grad, weight) - - # evaluate - for data, label in val_dataiter: - label = label.asnumpy().flatten() - inputs["data"][:] = data - executor.forward(is_train = False) - val_acc += CalAcc(out_narray.asnumpy(), label) - val_nbatch += 1 - print("Train Acc: ", train_acc / train_nbatch) - print("Valid Acc: ", val_acc / val_nbatch) - acc_train = train_acc / train_nbatch - acc_val = val_acc / val_nbatch - train_dataiter.reset() - val_dataiter.reset() - assert(acc_train > 0.84) - assert(acc_val > 0.96) + # print logging by default + logging.basicConfig(level=logging.DEBUG) + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + logging.getLogger('').addHandler(console) + + model.fit(X=train_dataiter, + eval_data=val_dataiter) + logging.info('Finish fit...') + prob = model.predict(val_dataiter) + logging.info('Finish predict...') + val_dataiter.reset() + y = np.concatenate([label.asnumpy() for _, label in val_dataiter]).astype('int') + py = np.argmax(prob, axis=1) + acc1 = float(np.sum(py == y)) / len(y) + logging.info('final accuracy = %f', acc1) + assert(acc1 > 0.96) if __name__ == "__main__":