Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

merge #104

Merged
merged 3 commits into from
Sep 20, 2015
Merged

merge #104

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
445 changes: 445 additions & 0 deletions example/imagenet/alexnet.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from . import optimizer
from . import model
from . import initializer
from . import visualization
import atexit

__version__ = "0.1.0"
5 changes: 3 additions & 2 deletions python/mxnet/metric.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# pylint: disable=invalid-name
"""Online evaluation metric module."""
import numpy as np
from .ndarray import NDArray

class EvalMetric(object):
"""Base class of all evaluation metrics."""
def __init__(self, name):
self.name = name
self.reset()

def update(pred, label):
def update(self, pred, label):
"""Update the internal evaluation.

Parameters
Expand Down Expand Up @@ -40,6 +40,7 @@ def get(self):


class Accuracy(EvalMetric):
"""Calculate accuracy"""
def __init__(self):
super(Accuracy, self).__init__('accuracy')

Expand Down
27 changes: 14 additions & 13 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# pylint: skip-file
# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals, no-member
# pylint: disable=too-many-branches, too-many-statements, unused-argument, unused-variable
"""MXNet model module"""
import numpy as np
import time
from . import io
from . import nd
from . import optimizer as opt
from . import metric
from .symbol import Symbol
from .context import Context
from .initializer import Xavier

Expand All @@ -20,7 +21,7 @@


def _train(symbol, ctx, input_shape,
arg_params, aux_states,
arg_params, aux_params,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, verbose=True):
Expand All @@ -40,7 +41,7 @@ def _train(symbol, ctx, input_shape,
arg_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's weights.

aux_states : dict of str to NDArray
aux_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's auxiliary states.

begin_round : int
Expand Down Expand Up @@ -81,16 +82,16 @@ def _train(symbol, ctx, input_shape,
grad_arrays = train_exec.grad_arrays
aux_arrays = train_exec.aux_arrays
# copy initialized parameters to executor parameters
for key, weight in zip(arg_names, arg_arrays):
for key, weight in list(zip(arg_names, arg_arrays)):
if key in arg_params:
arg_params[key].copyto(weight)
for key, weight in zip(aux_names, aux_arrays):
for key, weight in list(zip(aux_names, aux_arrays)):
if key in aux_params:
aux_params[key].copyto(weight)
# setup helper data structures
label_array = None
data_array = None
for name, arr in zip(symbol.list_arguments(), arg_arrays):
for name, arr in list(zip(symbol.list_arguments(), arg_arrays)):
if name.endswith('label'):
assert label_array is None
label_array = arr
Expand Down Expand Up @@ -151,10 +152,10 @@ def _train(symbol, ctx, input_shape,
for key, weight, gard in arg_blocks:
if key in arg_params:
weight.copyto(arg_params[key])
for key, arr in zip(aux_names, aux_states):
arr.copyto(aux_states[key])
for key, arr in list(zip(aux_names, aux_arrays)):
arr.copyto(aux_params[key])
if iter_end_callback:
iter_end_callback(i, arg_params, aux_states)
iter_end_callback(i, arg_params, aux_arrays)
# end of the function
return

Expand Down Expand Up @@ -224,11 +225,11 @@ def _init_params(self):
arg_shapes, _, aux_shapes = self.symbol.infer_shape(data=self.input_shape)
if self.arg_params is None:
arg_names = self.symbol.list_arguments()
self.arg_params = {k : nd.zeros(s) for k, s in zip(arg_names, arg_shapes)
self.arg_params = {k : nd.zeros(s) for k, s in list(zip(arg_names, arg_shapes))
if not is_data_arg(k)}
if self.aux_states is None:
aux_names = self.symbol.list_auxiliary_states()
self.aux_states = {k : nd.zeros(s) for k, s in zip(aux_names, aux_shapes)}
self.aux_states = {k : nd.zeros(s) for k, s in list(zip(aux_names, aux_shapes))}
for k, v in self.arg_params.items():
self.initializer(k, v)
for k, v in self.aux_states.items():
Expand All @@ -241,7 +242,7 @@ def _init_predictor(self):
# for now only use the first device
pred_exec = self.symbol.simple_bind(
self.ctx[0], grad_req='null', data=self.input_shape)
for name, value in zip(self.symbol.list_arguments(), pred_exec.arg_arrays):
for name, value in list(zip(self.symbol.list_arguments(), pred_exec.arg_arrays)):
if name not in self.arg_datas:
assert name in self.arg_params
self.arg_params[name].copyto(value)
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: skip-file
# pylint: disable=fixme, invalid-name
"""Common Optimization algorithms with regularizations."""
from .ndarray import NDArray, zeros

Expand Down
1 change: 0 additions & 1 deletion python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ def simple_bind(self, ctx, grad_req='write', **kwargs):
arg_ndarrays = [zeros(shape, ctx) for shape in arg_shapes]

if grad_req != 'null':
req = {}
grad_ndarrays = {}
for name, shape in zip(self.list_arguments(), arg_shapes):
if not (name.endswith('data') or name.endswith('label')):
Expand Down
137 changes: 137 additions & 0 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-locals, fixme
# pylint: disable=unused-argument, too-many-branches, too-many-statements
"""Visualization module"""
from .symbol import Symbol
import json
import re
import copy


def _str2tuple(string):
"""convert shape string to list, internal use only

Parameters
----------
string: str
shape string

Returns
-------
list of str to represent shape
"""
return re.findall(r"\d+", string)


def network2dot(title, symbol, shape=None):
"""convert symbol to dot object for visualization

Parameters
----------
title: str
title of the dot graph
symbol: Symbol
symbol to be visualized
shape: TODO
TODO

Returns
------
dot: Diagraph
dot object of symbol
"""
# todo add shape support
try:
from graphviz import Digraph
except:
raise ImportError("Draw network requires graphviz library")
if not isinstance(symbol, Symbol):
raise TypeError("symbol must be Symbol")
conf = json.loads(symbol.tojson())
nodes = conf["nodes"]
heads = set(conf["heads"][0]) # TODO(xxx): check careful
node_attr = {"shape": "box", "fixedsize": "true",
"width": "1.3", "height": "0.8034", "style": "filled"}
dot = Digraph(name=title)
# make nodes
for i in range(len(nodes)):
node = nodes[i]
op = node["op"]
name = "%s_%d" % (op, i)
# input data
if i in heads and op == "null":
label = node["name"]
attr = copy.deepcopy(node_attr)
dot.node(name=name, label=label, **attr)
if op == "null":
continue
elif op == "Convolution":
label = "Convolution\n%sx%s/%s, %s" % (_str2tuple(node["param"]["kernel"])[0],
_str2tuple(node["param"]["kernel"])[1],
_str2tuple(node["param"]["stride"])[0],
node["param"]["num_filter"])
attr = copy.deepcopy(node_attr)
attr["color"] = "royalblue1"
dot.node(name=name, label=label, **attr)
elif op == "FullyConnected":
label = "FullyConnected\n%s" % node["param"]["num_hidden"]
attr = copy.deepcopy(node_attr)
attr["color"] = "royalblue1"
dot.node(name=name, label=label, **attr)
elif op == "BatchNorm":
label = "BatchNorm"
attr = copy.deepcopy(node_attr)
attr["color"] = "orchid1"
dot.node(name=name, label=label, **attr)
elif op == "Concat":
label = "Concat"
attr = copy.deepcopy(node_attr)
attr["color"] = "seagreen1"
dot.node(name=name, label=label, **attr)
elif op == "Flatten":
label = "Flatten"
attr = copy.deepcopy(node_attr)
attr["color"] = "seagreen1"
dot.node(name=name, label=label, **attr)
elif op == "Reshape":
label = "Reshape"
attr = copy.deepcopy(node_attr)
attr["color"] = "seagreen1"
dot.node(name=name, label=label, **attr)
elif op == "Pooling":
label = "Pooling\n%s, %sx%s/%s" % (node["param"]["pool_type"],
_str2tuple(node["param"]["kernel"])[0],
_str2tuple(node["param"]["kernel"])[1],
_str2tuple(node["param"]["stride"])[0])
attr = copy.deepcopy(node_attr)
attr["color"] = "firebrick2"
dot.node(name=name, label=label, **attr)
elif op == "Activation" or op == "LeakyReLU":
label = "%s\n%s" % (op, node["param"]["act_type"])
attr = copy.deepcopy(node_attr)
attr["color"] = "salmon"
dot.node(name=name, label=label, **attr)
else:
label = op
attr = copy.deepcopy(node_attr)
attr["color"] = "olivedrab1"
dot.node(name=name, label=label, **attr)

# add edges
for i in range(len(nodes)):
node = nodes[i]
op = node["op"]
name = "%s_%d" % (op, i)
if op == "null":
continue
else:
inputs = node["inputs"]
for item in inputs:
input_node = nodes[item[0]]
input_name = "%s_%d" % (input_node["op"], item[0])
if input_node["op"] != "null" or item[0] in heads:
# add shape into label
attr = {"dir": "back"}
dot.edge(tail_name=name, head_name=input_name, **attr)

return dot