Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUG FIX] model prediction and model in python3 #109

Merged
merged 5 commits into from
Sep 21, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 138 additions & 7 deletions example/imagenet/alexnet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
"\n",
"We will show how to train AlexNet in Python with single/multi GPU. All you need is to write a piece of Python code to describe network, then MXNet will help you finish all work without any of your effort. \n",
"\n",
"Notice: This notebook is a basic demo to show MXNet flavor. To train a full state-of-art network, please refer our ```Inception``` example.\n",
"\n",
"Generally, we need \n",
"\n",
"- Declare symbol network\n",
Expand All @@ -27,6 +29,7 @@
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import mxnet as mx"
]
},
Expand Down Expand Up @@ -99,7 +102,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have a AlexNet. The ```softmax``` symbol contains all network structures. We can visualize our network structure. (require ```graphviz``` package)"
"Now we have a AlexNet in symbolic level. The ```softmax``` symbol contains all network structures. By indicate ```data``` for each symbol, the last symbol composite all info we need. We can visualize our network structure. (require ```graphviz``` package)"
]
},
{
Expand All @@ -115,7 +118,7 @@
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.36.0 (20140111.2315)\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
" -->\n",
"<!-- Title: AlexNet Pages: 1 -->\n",
"<svg width=\"102pt\" height=\"2322pt\"\n",
Expand All @@ -125,7 +128,7 @@
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-2318 98,-2318 98,4 -4,4\"/>\n",
"<!-- null_0 -->\n",
"<g id=\"node1\" class=\"node\"><title>null_0</title>\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"94,-58 -7.10543e-15,-58 -7.10543e-15,-3.55271e-15 94,-3.55271e-15 94,-58\"/>\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"94,-58 -7.10543e-15,-58 -7.10543e-15,-0 94,-0 94,-58\"/>\n",
"<text text-anchor=\"middle\" x=\"47\" y=\"-25.3\" font-family=\"Times,serif\" font-size=\"14.00\">data</text>\n",
"</g>\n",
"<!-- Convolution_3 -->\n",
Expand Down Expand Up @@ -390,7 +393,7 @@
"</svg>\n"
],
"text/plain": [
"<graphviz.dot.Digraph at 0x7f01c5b9e2e8>"
"<graphviz.dot.Digraph at 0x7f08a1121198>"
]
},
"execution_count": 4,
Expand All @@ -399,7 +402,7 @@
}
],
"source": [
"mx.visualization.network2dot(\"AlexNet\", softmax)"
"mx.visualization.plot_network(\"AlexNet\", softmax)"
]
},
{
Expand All @@ -408,7 +411,135 @@
"collapsed": true
},
"source": [
"After define our network, we are able to create our model."
"The next step is declare data iterator. We provide high perfomance RecordIO image iterator for ImageNet task. Please pack the images into record file before use. For how to pack image and more details about image data iterator and build-in io iterator, please read [io doc](https://github.com/dmlc/mxnet/blob/master/doc/python/io.md)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# We set batch size for to 256\n",
"batch_size = 256\n",
"# We need to set correct path to image record file\n",
"# For ```mean_image```. if it doesn't exist, the iterator will generate one. Usually on normal HDD, it costs less than 10 minutes\n",
"# the input shape is in format (channel, height, width)\n",
"# rand_crop option make source image randomly cropped to input_shape (3, 224, 224)\n",
"# rand_mirror option make source image randomly mirrored\n",
"# We use 2 threads to processing our data\n",
"train_dataiter = mx.io.ImageRecordIter(\n",
" path_imgrec=\"./Data/ImageNet/train.rec\",\n",
" mean_img=\"./Data/ImageNet/mean_224.bin\",\n",
" rand_crop=True,\n",
" rand_mirror=True,\n",
" input_shape=(3, 224, 224),\n",
" batch_size=batch_size,\n",
" nthread=2)\n",
"# similarly, we can declare our validation iterator\n",
"val_dataiter = mx.io.ImageRecordIter(\n",
" path_imgrec=\"./Data/ImageNet/val.rec\",\n",
" mean_img=\"./Data/ImageNet/mean_224.bin\",\n",
" rand_crop=False,\n",
" rand_mirror=False,\n",
" input_shape=(3, 224, 224),\n",
" batch_size=batch_size,\n",
" nthread=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next step, we will initialize our model from symbol. To run on a single GPU, we need to declare:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# For demo purpose, we just run 1 epoch\n",
"num_round = 1\n",
"# set context to GPU, if you want to use cpu, set it to mx.cpu()\n",
"ctx = mx.gpu() \n",
"# note: for input shape in model, we must contain batch size\n",
"data_shape = (batch_size, 3, 224, 224)\n",
"\n",
"model = mx.model.FeedForward(symbol=softmax, ctx=ctx, input_shape=data_shape, num_round=num_round,\n",
" learning_rate=0.01, momentum=0.9, wd=0.0001)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run on multiply GPU, we need to declare"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# For demo purpose, we just run 1 epoch\n",
"num_round = 1\n",
"# Assume we have 4 GPU, we can make a context list contains 4 device\n",
"num_devs = 4\n",
"ctx = [mx.gpu(i) for i in range(num_devs)]\n",
"# note: for input shape in model, we must contain batch size\n",
"data_shape = (batch_size, 3, 224, 224)\n",
"\n",
"model = mx.model.FeedForward(symbol=softmax, ctx=ctx, input_shape=data_shape, num_round=num_round,\n",
" learning_rate=0.01, momentum=0.9, wd=0.0001)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"ename": "MXNetError",
"evalue": "[12:00:28] src/ndarray/ndarray.cc:157: Check failed: from.shape() == to->shape() operands shape mismatch",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mMXNetError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-10-8ca28bf9d513>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;31m# In this case, eval_data is also a data iterator\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;31m# We will use accuracy to measure our model's performace\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrain_dataiter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meval_data\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mval_dataiter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meval_metric\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'acc'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/model.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, X, y, eval_data, eval_metric, verbose)\u001b[0m\n\u001b[0;32m 304\u001b[0m \u001b[0mtrain_data\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0meval_data\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0meval_data\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 305\u001b[0m \u001b[0meval_metric\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0meval_metric\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 306\u001b[1;33m verbose=verbose)\n\u001b[0m",
"\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/model.py\u001b[0m in \u001b[0;36m_train\u001b[1;34m(symbol, ctx, input_shape, arg_params, aux_params, begin_round, end_round, optimizer, train_data, eval_data, eval_metric, iter_end_callback, verbose)\u001b[0m\n\u001b[0;32m 85\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marg_names\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0marg_arrays\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 86\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mkey\u001b[0m \u001b[1;32min\u001b[0m \u001b[0marg_params\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 87\u001b[1;33m \u001b[0marg_params\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcopyto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 88\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mlist\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maux_names\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0maux_arrays\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 89\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mkey\u001b[0m \u001b[1;32min\u001b[0m \u001b[0maux_params\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/ndarray.py\u001b[0m in \u001b[0;36mcopyto\u001b[1;34m(self, other)\u001b[0m\n\u001b[0;32m 306\u001b[0m RuntimeWarning)\n\u001b[0;32m 307\u001b[0m \u001b[1;32mreturn\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 308\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mNDArray\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_copyto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mout\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mother\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 309\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mother\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mContext\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 310\u001b[0m \u001b[0mhret\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mNDArray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_new_alloc_handle\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mother\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/ndarray.py\u001b[0m in \u001b[0;36mgeneric_ndarray_function\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 618\u001b[0m \u001b[0mc_array\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mNDArrayHandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0muse_vars_range\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 619\u001b[0m \u001b[0mc_array\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmx_float\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mscalar_range\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 620\u001b[1;33m c_array(NDArrayHandle, [v.handle for v in mutate_vars])))\n\u001b[0m\u001b[0;32m 621\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mn_mutate_vars\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 622\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mmutate_vars\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m/home/bing/wtf/mxnet/python/mxnet/base.py\u001b[0m in \u001b[0;36mcheck_call\u001b[1;34m(ret)\u001b[0m\n\u001b[0;32m 95\u001b[0m \"\"\"\n\u001b[0;32m 96\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mret\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 97\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mMXNetError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpy_str\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_LIB\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMXGetLastError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 98\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mc_str\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstring\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mMXNetError\u001b[0m: [12:00:28] src/ndarray/ndarray.cc:157: Check failed: from.shape() == to->shape() operands shape mismatch"
]
}
],
"source": [
"# Now we can fit the model with data iterators\n",
"# When we use data iterator, we don't need to set y because label comes from data iterator directly\n",
"# In this case, eval_data is also a data iterator\n",
"# We will use accuracy to measure our model's performace\n",
"model.fit(X=train_dataiter, eval_data=val_dataiter, eval_metric='acc', verbose=True)\n",
"# You need to wait for a while to get the result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's all!"
]
},
{
Expand Down Expand Up @@ -437,7 +568,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.0"
"version": "3.4.2"
}
},
"nbformat": 4,
Expand Down
Loading