diff --git a/example/imagenet/alexnet.ipynb b/example/imagenet/alexnet.ipynb new file mode 100644 index 000000000000..1e9e399d1b5f --- /dev/null +++ b/example/imagenet/alexnet.ipynb @@ -0,0 +1,445 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Basic AlexNet Example\n", + "--------\n", + "\n", + "This notebook shows how to use MXNet to construct AlexNet. AlexNet was created by Alex Krizhevsky in 2012.\n", + "\n", + "We will show how to train AlexNet in Python on a single GPU or on multiple GPUs. All you need to do is write a piece of Python code that describes the network; MXNet will then handle the rest of the work for you. \n", + "\n", + "Generally, we need to:\n", + "\n", + "- Declare the symbolic network\n", + "- Declare the data iterator\n", + "- Bind the symbolic network to a device to create the model\n", + "- Fit the model" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import mxnet as mx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we have successfully loaded MXNet. We will start by declaring a symbolic network. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "input_data = mx.symbol.Variable(name=\"data\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use a special symbol ```Variable``` to represent input data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# stage 1\n", + "conv1 = mx.symbol.Convolution(data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96)\n", + "relu1 = mx.symbol.Activation(data=conv1, act_type=\"relu\")\n", + "pool1 = mx.symbol.Pooling(data=relu1, pool_type=\"max\", kernel=(3, 3), stride=(2,2))\n", + "lrn1 = mx.symbol.LRN(data=pool1, alpha=0.0001, beta=0.75, knorm=1, nsize=5)\n", + "# stage 2\n", + "conv2 = mx.symbol.Convolution(data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256)\n", + "relu2 = mx.symbol.Activation(data=conv2, act_type=\"relu\")\n", + "pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2))\n", + "lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5)\n", + "# stage 3\n", + "conv3 = mx.symbol.Convolution(data=lrn2, kernel=(3, 3), pad=(1, 1), num_filter=384)\n", + "relu3 = mx.symbol.Activation(data=conv3, act_type=\"relu\")\n", + "conv4 = mx.symbol.Convolution(data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384)\n", + "relu4 = mx.symbol.Activation(data=conv4, act_type=\"relu\")\n", + "conv5 = mx.symbol.Convolution(data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)\n", + "relu5 = mx.symbol.Activation(data=conv5, act_type=\"relu\")\n", + "pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2))\n", + "# stage 4\n", + "flatten = mx.symbol.Flatten(data=pool3)\n", + "fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096)\n", + "relu6 = mx.symbol.Activation(data=fc1, act_type=\"relu\")\n", + "dropout1 = mx.symbol.Dropout(data=relu6, p=0.5)\n", + "# stage 5\n", + "fc2 = mx.symbol.FullyConnected(data=dropout1, num_hidden=4096)\n", + "relu7 = mx.symbol.Activation(data=fc2, act_type=\"relu\")\n", + "dropout2 = mx.symbol.Dropout(data=relu7, p=0.5)\n", + "# stage 6\n", + "fc3 = mx.symbol.FullyConnected(data=dropout2, num_hidden=1000)\n", + "softmax = mx.symbol.Softmax(data=fc3)" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we have a AlexNet. The ```softmax``` symbol contains all network structures. We can visualize our network structure. (require ```graphviz``` package)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "AlexNet\n", + "\n", + "\n", + "null_0\n", + "\n", + "data\n", + "\n", + "\n", + "Convolution_3\n", + "\n", + "Convolution\n", + "11x11/4, 96\n", + "\n", + "\n", + "Convolution_3->null_0\n", + "\n", + "\n", + "\n", + "\n", + "Activation_4\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_4->Convolution_3\n", + "\n", + "\n", + "\n", + "\n", + "Pooling_5\n", + "\n", + "Pooling\n", + "max, 3x3/2\n", + "\n", + "\n", + "Pooling_5->Activation_4\n", + "\n", + "\n", + "\n", + "\n", + "LRN_6\n", + "\n", + "LRN\n", + "\n", + "\n", + "LRN_6->Pooling_5\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_9\n", + "\n", + "Convolution\n", + "5x5/1, 256\n", + "\n", + "\n", + "Convolution_9->LRN_6\n", + "\n", + "\n", + "\n", + "\n", + "Activation_10\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_10->Convolution_9\n", + "\n", + "\n", + "\n", + "\n", + "Pooling_11\n", + "\n", + "Pooling\n", + "max, 3x3/2\n", + "\n", + "\n", + "Pooling_11->Activation_10\n", + "\n", + "\n", + "\n", + "\n", + "LRN_12\n", + "\n", + "LRN\n", + "\n", + "\n", + "LRN_12->Pooling_11\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_15\n", + "\n", + "Convolution\n", + "3x3/1, 384\n", + "\n", + "\n", + "Convolution_15->LRN_12\n", + "\n", + "\n", + "\n", + "\n", + "Activation_16\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_16->Convolution_15\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_19\n", + "\n", + "Convolution\n", + "3x3/1, 384\n", + "\n", + "\n", + "Convolution_19->Activation_16\n", + "\n", + "\n", + "\n", + "\n", + 
"Activation_20\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_20->Convolution_19\n", + "\n", + "\n", + "\n", + "\n", + "Convolution_23\n", + "\n", + "Convolution\n", + "3x3/1, 256\n", + "\n", + "\n", + "Convolution_23->Activation_20\n", + "\n", + "\n", + "\n", + "\n", + "Activation_24\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_24->Convolution_23\n", + "\n", + "\n", + "\n", + "\n", + "Pooling_25\n", + "\n", + "Pooling\n", + "max, 3x3/2\n", + "\n", + "\n", + "Pooling_25->Activation_24\n", + "\n", + "\n", + "\n", + "\n", + "Flatten_26\n", + "\n", + "Flatten\n", + "\n", + "\n", + "Flatten_26->Pooling_25\n", + "\n", + "\n", + "\n", + "\n", + "FullyConnected_29\n", + "\n", + "FullyConnected\n", + "4096\n", + "\n", + "\n", + "FullyConnected_29->Flatten_26\n", + "\n", + "\n", + "\n", + "\n", + "Activation_30\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_30->FullyConnected_29\n", + "\n", + "\n", + "\n", + "\n", + "Dropout_31\n", + "\n", + "Dropout\n", + "\n", + "\n", + "Dropout_31->Activation_30\n", + "\n", + "\n", + "\n", + "\n", + "FullyConnected_34\n", + "\n", + "FullyConnected\n", + "4096\n", + "\n", + "\n", + "FullyConnected_34->Dropout_31\n", + "\n", + "\n", + "\n", + "\n", + "Activation_35\n", + "\n", + "Activation\n", + "relu\n", + "\n", + "\n", + "Activation_35->FullyConnected_34\n", + "\n", + "\n", + "\n", + "\n", + "Dropout_36\n", + "\n", + "Dropout\n", + "\n", + "\n", + "Dropout_36->Activation_35\n", + "\n", + "\n", + "\n", + "\n", + "FullyConnected_39\n", + "\n", + "FullyConnected\n", + "1000\n", + "\n", + "\n", + "FullyConnected_39->Dropout_36\n", + "\n", + "\n", + "\n", + "\n", + "Softmax_41\n", + "\n", + "Softmax\n", + "\n", + "\n", + "Softmax_41->FullyConnected_39\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"mx.visualization.network2dot(\"AlexNet\", softmax)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "After defining our network, we are able to create our model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.4.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index e3b17baa1b31..b5429a7bd816 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -21,6 +21,7 @@ from . import optimizer from . import model from . import initializer +from . import visualization import atexit __version__ = "0.1.0" diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index f58e90f4ac52..ff100c12c191 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -1,6 +1,6 @@ +# pylint: disable=invalid-name """Online evaluation metric module.""" import numpy as np -from .ndarray import NDArray class EvalMetric(object): """Base class of all evaluation metrics.""" @@ -8,7 +8,7 @@ def __init__(self, name): self.name = name self.reset() - def update(pred, label): + def update(self, pred, label): """Update the internal evaluation. 
Parameters @@ -40,6 +40,7 @@ def get(self): class Accuracy(EvalMetric): + """Calculate accuracy""" def __init__(self): super(Accuracy, self).__init__('accuracy') diff --git a/python/mxnet/model.py b/python/mxnet/model.py index f1cda62a1e53..726be0d7eb45 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -1,11 +1,12 @@ -# pylint: skip-file +# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals, no-member +# pylint: disable=too-many-branches, too-many-statements, unused-argument, unused-variable +"""MXNet model module""" import numpy as np import time from . import io from . import nd from . import optimizer as opt from . import metric -from .symbol import Symbol from .context import Context from .initializer import Xavier @@ -20,7 +21,7 @@ def _train(symbol, ctx, input_shape, - arg_params, aux_states, + arg_params, aux_params, begin_round, end_round, optimizer, train_data, eval_data=None, eval_metric=None, iter_end_callback=None, verbose=True): @@ -40,7 +41,7 @@ def _train(symbol, ctx, input_shape, arg_params : dict of str to NDArray Model parameter, dict of name to NDArray of net's weights. - aux_states : dict of str to NDArray + aux_params : dict of str to NDArray Model parameter, dict of name to NDArray of net's auxiliary states. 
begin_round : int @@ -81,16 +82,16 @@ def _train(symbol, ctx, input_shape, grad_arrays = train_exec.grad_arrays aux_arrays = train_exec.aux_arrays # copy initialized parameters to executor parameters - for key, weight in zip(arg_names, arg_arrays): + for key, weight in list(zip(arg_names, arg_arrays)): if key in arg_params: arg_params[key].copyto(weight) - for key, weight in zip(aux_names, aux_arrays): + for key, weight in list(zip(aux_names, aux_arrays)): if key in aux_params: aux_params[key].copyto(weight) # setup helper data structures label_array = None data_array = None - for name, arr in zip(symbol.list_arguments(), arg_arrays): + for name, arr in list(zip(symbol.list_arguments(), arg_arrays)): if name.endswith('label'): assert label_array is None label_array = arr @@ -151,10 +152,10 @@ def _train(symbol, ctx, input_shape, for key, weight, gard in arg_blocks: if key in arg_params: weight.copyto(arg_params[key]) - for key, arr in zip(aux_names, aux_states): - arr.copyto(aux_states[key]) + for key, arr in list(zip(aux_names, aux_arrays)): + arr.copyto(aux_params[key]) if iter_end_callback: - iter_end_callback(i, arg_params, aux_states) + iter_end_callback(i, arg_params, aux_arrays) # end of the function return @@ -224,11 +225,11 @@ def _init_params(self): arg_shapes, _, aux_shapes = self.symbol.infer_shape(data=self.input_shape) if self.arg_params is None: arg_names = self.symbol.list_arguments() - self.arg_params = {k : nd.zeros(s) for k, s in zip(arg_names, arg_shapes) + self.arg_params = {k : nd.zeros(s) for k, s in list(zip(arg_names, arg_shapes)) if not is_data_arg(k)} if self.aux_states is None: aux_names = self.symbol.list_auxiliary_states() - self.aux_states = {k : nd.zeros(s) for k, s in zip(aux_names, aux_shapes)} + self.aux_states = {k : nd.zeros(s) for k, s in list(zip(aux_names, aux_shapes))} for k, v in self.arg_params.items(): self.initializer(k, v) for k, v in self.aux_states.items(): @@ -241,7 +242,7 @@ def _init_predictor(self): # for now only 
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-locals, fixme
# pylint: disable=unused-argument, too-many-branches, too-many-statements
"""Visualization module"""
import json
import re
import copy

# Graphviz attributes applied to every node that gets drawn.
_NODE_ATTR = {"shape": "box", "fixedsize": "true",
              "width": "1.3", "height": "0.8034", "style": "filled"}

# Fill color per operator name; operators not listed use _DEFAULT_COLOR.
_OP_COLORS = {
    "Convolution": "royalblue1",
    "FullyConnected": "royalblue1",
    "BatchNorm": "orchid1",
    "Concat": "seagreen1",
    "Flatten": "seagreen1",
    "Reshape": "seagreen1",
    "Pooling": "firebrick2",
    "Activation": "salmon",
    "LeakyReLU": "salmon",
}
_DEFAULT_COLOR = "olivedrab1"


def _str2tuple(string):
    """convert shape string to list, internal use only

    Parameters
    ----------
    string: str
        shape string

    Returns
    -------
    list of str to represent shape; every run of decimal digits found in
    ``string`` becomes one element (empty list when there are none)
    """
    return re.findall(r"\d+", string)


def _node_name(op, index):
    """build the unique dot identifier ``<op>_<index>`` of a graph node"""
    return "%s_%d" % (op, index)


def _node_attr(color=None):
    """return a fresh copy of the shared node attributes, optionally colored"""
    # Copy so that the module-level template is never mutated.
    attr = copy.copy(_NODE_ATTR)
    if color is not None:
        attr["color"] = color
    return attr


def _node_label(node):
    """compute the text shown inside the box of a non-input node

    Parameters
    ----------
    node: dict
        one entry of the ``"nodes"`` list of the symbol json; must contain
        ``"op"`` and, for parameterized operators, ``"param"``

    Returns
    -------
    label: str
        operator name, followed by its key parameters for Convolution,
        FullyConnected, Pooling, Activation and LeakyReLU
    """
    op = node["op"]
    if op == "Convolution":
        param = node["param"]
        kernel = _str2tuple(param["kernel"])
        stride = _str2tuple(param["stride"])
        return "Convolution\n%sx%s/%s, %s" % (kernel[0], kernel[1],
                                              stride[0], param["num_filter"])
    if op == "FullyConnected":
        return "FullyConnected\n%s" % node["param"]["num_hidden"]
    if op == "Pooling":
        param = node["param"]
        kernel = _str2tuple(param["kernel"])
        stride = _str2tuple(param["stride"])
        return "Pooling\n%s, %sx%s/%s" % (param["pool_type"],
                                          kernel[0], kernel[1], stride[0])
    if op in ("Activation", "LeakyReLU"):
        return "%s\n%s" % (op, node["param"]["act_type"])
    # Every other operator is labelled with its bare name.
    return op


def network2dot(title, symbol, shape=None):
    """convert symbol to dot object for visualization

    Parameters
    ----------
    title: str
        title of the dot graph
    symbol: Symbol
        symbol to be visualized
    shape: TODO
        currently unused, reserved for shape annotation support

    Returns
    ------
    dot: Digraph
        dot object of symbol

    Raises
    ------
    ImportError
        if the graphviz python package cannot be imported
    TypeError
        if ``symbol`` is not a Symbol
    """
    # TODO: add shape support
    try:
        from graphviz import Digraph
    except ImportError:
        # Only a missing dependency is translated; other errors propagate.
        raise ImportError("Draw network requires graphviz library")
    # Imported lazily, like graphviz, so the pure helpers of this module can
    # be used without loading the symbol machinery.
    from .symbol import Symbol
    if not isinstance(symbol, Symbol):
        raise TypeError("symbol must be Symbol")
    conf = json.loads(symbol.tojson())
    nodes = conf["nodes"]
    # NOTE(review): this collects every number of the first "heads" entry and
    # treats them all as node ids; it is what makes "null" input nodes get
    # drawn below. Kept as in the original implementation -- confirm intent.
    heads = set(conf["heads"][0])  # TODO(xxx): check careful
    dot = Digraph(name=title)

    # make nodes
    for i, node in enumerate(nodes):
        op = node["op"]
        name = _node_name(op, i)
        if op == "null":
            # Variables are drawn only when listed in heads (input data);
            # they carry their own name and no fill color.
            if i in heads:
                dot.node(name=name, label=node["name"], **_node_attr())
            continue
        color = _OP_COLORS.get(op, _DEFAULT_COLOR)
        dot.node(name=name, label=_node_label(node), **_node_attr(color))

    # add edges
    for i, node in enumerate(nodes):
        op = node["op"]
        if op == "null":
            continue
        name = _node_name(op, i)
        for item in node["inputs"]:
            input_index = item[0]
            input_op = nodes[input_index]["op"]
            # Skip edges to variables that were not drawn above.
            if input_op != "null" or input_index in heads:
                # TODO: add shape into label
                # dir=back: the edge is declared consumer -> producer but the
                # arrow is rendered pointing the other way.
                dot.edge(tail_name=name,
                         head_name=_node_name(input_op, input_index),
                         dir="back")

    return dot