diff --git a/example/imagenet/alexnet.ipynb b/example/imagenet/alexnet.ipynb
new file mode 100644
index 000000000000..1e9e399d1b5f
--- /dev/null
+++ b/example/imagenet/alexnet.ipynb
@@ -0,0 +1,445 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Basic AlexNet Example\n",
+ "--------\n",
+ "\n",
+ "This notebook shows how to use MXNet construct AlexNet. AlexNet is made by Alex Krizhevsky in 2012.\n",
+ "\n",
+ "We will show how to train AlexNet in Python with single/multi GPU. All you need is to write a piece of Python code to describe network, then MXNet will help you finish all work without any of your effort. \n",
+ "\n",
+ "Generally, we need \n",
+ "\n",
+ "- Declare symbol network\n",
+ "- Declare data iterator\n",
+ "- Bind symbol network to device to model\n",
+ "- Fit the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "import mxnet as mx"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we have successully load MXNet. we will start declare a symbolic network. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "input_data = mx.symbol.Variable(name=\"data\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We use a special symbol ```Variable``` to represent input data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# stage 1\n",
+ "conv1 = mx.symbol.Convolution(data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96)\n",
+ "relu1 = mx.symbol.Activation(data=conv1, act_type=\"relu\")\n",
+ "pool1 = mx.symbol.Pooling(data=relu1, pool_type=\"max\", kernel=(3, 3), stride=(2,2))\n",
+ "lrn1 = mx.symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5)\n",
+ "# stage 2\n",
+ "conv2 = mx.symbol.Convolution(data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256)\n",
+ "relu2 = mx.symbol.Activation(data=conv2, act_type=\"relu\")\n",
+ "pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2))\n",
+ "lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5)\n",
+ "# stage 3\n",
+ "conv3 = mx.symbol.Convolution(data=lrn2, kernel=(3, 3), pad=(1, 1), num_filter=384)\n",
+ "relu3 = mx.symbol.Activation(data=conv3, act_type=\"relu\")\n",
+ "conv4 = mx.symbol.Convolution(data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384)\n",
+ "relu4 = mx.symbol.Activation(data=conv4, act_type=\"relu\")\n",
+ "conv5 = mx.symbol.Convolution(data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)\n",
+ "relu5 = mx.symbol.Activation(data=conv5, act_type=\"relu\")\n",
+ "pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2))\n",
+ "# stage 4\n",
+ "flatten = mx.symbol.Flatten(data=pool3)\n",
+ "fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096)\n",
+ "relu6 = mx.symbol.Activation(data=fc1, act_type=\"relu\")\n",
+ "dropout1 = mx.symbol.Dropout(data=relu6, p=0.5)\n",
+ "# stage 5\n",
+ "fc2 = mx.symbol.FullyConnected(data=dropout1, num_hidden=4096)\n",
+ "relu7 = mx.symbol.Activation(data=fc2, act_type=\"relu\")\n",
+ "dropout2 = mx.symbol.Dropout(data=relu7, p=0.5)\n",
+ "# stage 6\n",
+ "fc3 = mx.symbol.FullyConnected(data=dropout2, num_hidden=1000)\n",
+ "softmax = mx.symbol.Softmax(data=fc3)"
+ ]
+ },
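+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before going further, we can sanity-check the symbol we just declared. ```list_arguments``` and ```infer_shape``` are the same calls used internally by ```mxnet.model```; the batch size and input resolution below (128 RGB images of 224x224) are only illustrative:\n",
+ "\n",
+ "```python\n",
+ "print(softmax.list_arguments())  # names of the data/label variables and of every weight and bias\n",
+ "arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=(128, 3, 224, 224))\n",
+ "print(out_shapes)  # should be [(128, 1000)]: one 1000-way prediction per image\n",
+ "```"
+ ]
+ },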
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we have a AlexNet. The ```softmax``` symbol contains all network structures. We can visualize our network structure. (require ```graphviz``` package)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mx.visualization.network2dot(\"AlexNet\", softmax)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": true
+ },
+ "source": [
+ "After define our network, we are able to create our model."
+ ]
+ },
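+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The remaining steps, declaring a data iterator, binding the network to a device, and fitting the model, are not filled in here. As a minimal low-level sketch of the binding step, the symbol can be bound with ```simple_bind```, which is also what ```mxnet.model``` uses internally; ```mx.cpu()``` and the shapes below are illustrative assumptions:\n",
+ "\n",
+ "```python\n",
+ "# illustrative sketch only: a real run would use a GPU context and a data iterator\n",
+ "executor = softmax.simple_bind(mx.cpu(), data=(128, 3, 224, 224))\n",
+ "print(len(executor.arg_arrays))  # one NDArray allocated per argument of the symbol\n",
+ "```"
+ ]
+ },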
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.4.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index e3b17baa1b31..b5429a7bd816 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -21,6 +21,7 @@
from . import optimizer
from . import model
from . import initializer
+from . import visualization
import atexit
__version__ = "0.1.0"
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index f58e90f4ac52..ff100c12c191 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -1,6 +1,6 @@
+# pylint: disable=invalid-name
"""Online evaluation metric module."""
import numpy as np
-from .ndarray import NDArray
class EvalMetric(object):
"""Base class of all evaluation metrics."""
@@ -8,7 +8,7 @@ def __init__(self, name):
self.name = name
self.reset()
- def update(pred, label):
+ def update(self, pred, label):
"""Update the internal evaluation.
Parameters
@@ -40,6 +40,7 @@ def get(self):
class Accuracy(EvalMetric):
+ """Calculate accuracy"""
def __init__(self):
super(Accuracy, self).__init__('accuracy')
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index f1cda62a1e53..726be0d7eb45 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -1,11 +1,12 @@
-# pylint: skip-file
+# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals, no-member
+# pylint: disable=too-many-branches, too-many-statements, unused-argument, unused-variable
+"""MXNet model module"""
import numpy as np
import time
from . import io
from . import nd
from . import optimizer as opt
from . import metric
-from .symbol import Symbol
from .context import Context
from .initializer import Xavier
@@ -20,7 +21,7 @@
def _train(symbol, ctx, input_shape,
- arg_params, aux_states,
+ arg_params, aux_params,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, verbose=True):
@@ -40,7 +41,7 @@ def _train(symbol, ctx, input_shape,
arg_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's weights.
- aux_states : dict of str to NDArray
+ aux_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's auxiliary states.
begin_round : int
@@ -81,16 +82,16 @@ def _train(symbol, ctx, input_shape,
grad_arrays = train_exec.grad_arrays
aux_arrays = train_exec.aux_arrays
# copy initialized parameters to executor parameters
- for key, weight in zip(arg_names, arg_arrays):
+ for key, weight in list(zip(arg_names, arg_arrays)):
if key in arg_params:
arg_params[key].copyto(weight)
- for key, weight in zip(aux_names, aux_arrays):
+ for key, weight in list(zip(aux_names, aux_arrays)):
if key in aux_params:
aux_params[key].copyto(weight)
# setup helper data structures
label_array = None
data_array = None
- for name, arr in zip(symbol.list_arguments(), arg_arrays):
+ for name, arr in list(zip(symbol.list_arguments(), arg_arrays)):
if name.endswith('label'):
assert label_array is None
label_array = arr
@@ -151,10 +152,10 @@ def _train(symbol, ctx, input_shape,
for key, weight, gard in arg_blocks:
if key in arg_params:
weight.copyto(arg_params[key])
- for key, arr in zip(aux_names, aux_states):
- arr.copyto(aux_states[key])
+ for key, arr in list(zip(aux_names, aux_arrays)):
+ arr.copyto(aux_params[key])
if iter_end_callback:
- iter_end_callback(i, arg_params, aux_states)
+ iter_end_callback(i, arg_params, aux_arrays)
# end of the function
return
@@ -224,11 +225,11 @@ def _init_params(self):
arg_shapes, _, aux_shapes = self.symbol.infer_shape(data=self.input_shape)
if self.arg_params is None:
arg_names = self.symbol.list_arguments()
- self.arg_params = {k : nd.zeros(s) for k, s in zip(arg_names, arg_shapes)
+ self.arg_params = {k : nd.zeros(s) for k, s in list(zip(arg_names, arg_shapes))
if not is_data_arg(k)}
if self.aux_states is None:
aux_names = self.symbol.list_auxiliary_states()
- self.aux_states = {k : nd.zeros(s) for k, s in zip(aux_names, aux_shapes)}
+ self.aux_states = {k : nd.zeros(s) for k, s in list(zip(aux_names, aux_shapes))}
for k, v in self.arg_params.items():
self.initializer(k, v)
for k, v in self.aux_states.items():
@@ -241,7 +242,7 @@ def _init_predictor(self):
# for now only use the first device
pred_exec = self.symbol.simple_bind(
self.ctx[0], grad_req='null', data=self.input_shape)
- for name, value in zip(self.symbol.list_arguments(), pred_exec.arg_arrays):
+ for name, value in list(zip(self.symbol.list_arguments(), pred_exec.arg_arrays)):
if name not in self.arg_datas:
assert name in self.arg_params
self.arg_params[name].copyto(value)
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 682071148c7e..8118e23f2bf6 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1,4 +1,4 @@
-# pylint: skip-file
+# pylint: disable=fixme, invalid-name
"""Common Optimization algorithms with regularizations."""
from .ndarray import NDArray, zeros
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index 78d6d9187c45..006bc66a5223 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -423,7 +423,6 @@ def simple_bind(self, ctx, grad_req='write', **kwargs):
arg_ndarrays = [zeros(shape, ctx) for shape in arg_shapes]
if grad_req != 'null':
- req = {}
grad_ndarrays = {}
for name, shape in zip(self.list_arguments(), arg_shapes):
if not (name.endswith('data') or name.endswith('label')):
diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
new file mode 100644
index 000000000000..bccc4a2b3155
--- /dev/null
+++ b/python/mxnet/visualization.py
@@ -0,0 +1,137 @@
+# coding: utf-8
+# pylint: disable=invalid-name, protected-access, too-many-locals, fixme
+# pylint: disable=unused-argument, too-many-branches, too-many-statements
+"""Visualization module"""
+from .symbol import Symbol
+import json
+import re
+import copy
+
+
+def _str2tuple(string):
+ """convert shape string to list, internal use only
+
+ Parameters
+ ----------
+ string: str
+ shape string
+
+ Returns
+ -------
+ list of str to represent shape
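+
+ Examples
+ --------
+ >>> _str2tuple("(3, 3)")
+ ['3', '3']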
+ """
+ return re.findall(r"\d+", string)
+
+
+def network2dot(title, symbol, shape=None):
+ """convert symbol to dot object for visualization
+
+ Parameters
+ ----------
+ title: str
+ title of the dot graph
+ symbol: Symbol
+ symbol to be visualized
+ shape: TODO
+ TODO
+
+ Returns
+ ------
+ dot: Digraph
+ dot object of symbol
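+
+ Examples
+ --------
+ >>> dot = network2dot("AlexNet", softmax)  # softmax is the output Symbol of a network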
+ """
+ # todo add shape support
+ try:
+ from graphviz import Digraph
+ except ImportError:
+ raise ImportError("Drawing the network requires the graphviz library")
+ if not isinstance(symbol, Symbol):
+ raise TypeError("symbol must be Symbol")
+ conf = json.loads(symbol.tojson())
+ nodes = conf["nodes"]
+ heads = set(conf["heads"][0]) # TODO(xxx): check carefully
+ node_attr = {"shape": "box", "fixedsize": "true",
+ "width": "1.3", "height": "0.8034", "style": "filled"}
+ dot = Digraph(name=title)
+ # make nodes
+ for i in range(len(nodes)):
+ node = nodes[i]
+ op = node["op"]
+ name = "%s_%d" % (op, i)
+ # input data
+ if i in heads and op == "null":
+ label = node["name"]
+ attr = copy.deepcopy(node_attr)
+ dot.node(name=name, label=label, **attr)
+ if op == "null":
+ continue
+ elif op == "Convolution":
+ label = "Convolution\n%sx%s/%s, %s" % (_str2tuple(node["param"]["kernel"])[0],
+ _str2tuple(node["param"]["kernel"])[1],
+ _str2tuple(node["param"]["stride"])[0],
+ node["param"]["num_filter"])
+ attr = copy.deepcopy(node_attr)
+ attr["color"] = "royalblue1"
+ dot.node(name=name, label=label, **attr)
+ elif op == "FullyConnected":
+ label = "FullyConnected\n%s" % node["param"]["num_hidden"]
+ attr = copy.deepcopy(node_attr)
+ attr["color"] = "royalblue1"
+ dot.node(name=name, label=label, **attr)
+ elif op == "BatchNorm":
+ label = "BatchNorm"
+ attr = copy.deepcopy(node_attr)
+ attr["color"] = "orchid1"
+ dot.node(name=name, label=label, **attr)
+ elif op == "Concat":
+ label = "Concat"
+ attr = copy.deepcopy(node_attr)
+ attr["color"] = "seagreen1"
+ dot.node(name=name, label=label, **attr)
+ elif op == "Flatten":
+ label = "Flatten"
+ attr = copy.deepcopy(node_attr)
+ attr["color"] = "seagreen1"
+ dot.node(name=name, label=label, **attr)
+ elif op == "Reshape":
+ label = "Reshape"
+ attr = copy.deepcopy(node_attr)
+ attr["color"] = "seagreen1"
+ dot.node(name=name, label=label, **attr)
+ elif op == "Pooling":
+ label = "Pooling\n%s, %sx%s/%s" % (node["param"]["pool_type"],
+ _str2tuple(node["param"]["kernel"])[0],
+ _str2tuple(node["param"]["kernel"])[1],
+ _str2tuple(node["param"]["stride"])[0])
+ attr = copy.deepcopy(node_attr)
+ attr["color"] = "firebrick2"
+ dot.node(name=name, label=label, **attr)
+ elif op == "Activation" or op == "LeakyReLU":
+ label = "%s\n%s" % (op, node["param"]["act_type"])
+ attr = copy.deepcopy(node_attr)
+ attr["color"] = "salmon"
+ dot.node(name=name, label=label, **attr)
+ else:
+ label = op
+ attr = copy.deepcopy(node_attr)
+ attr["color"] = "olivedrab1"
+ dot.node(name=name, label=label, **attr)
+
+ # add edges
+ for i in range(len(nodes)):
+ node = nodes[i]
+ op = node["op"]
+ name = "%s_%d" % (op, i)
+ if op == "null":
+ continue
+ else:
+ inputs = node["inputs"]
+ for item in inputs:
+ input_node = nodes[item[0]]
+ input_name = "%s_%d" % (input_node["op"], item[0])
+ if input_node["op"] != "null" or item[0] in heads:
+ # add shape into label
+ attr = {"dir": "back"}
+ dot.edge(tail_name=name, head_name=input_name, **attr)
+
+ return dot