Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #105 from tqchen/master
Browse files Browse the repository at this point in the history
Update model to add save/load and period checkpoint
  • Loading branch information
tqchen committed Sep 20, 2015
2 parents a8c5ed1 + bd36bbc commit 3401ac5
Show file tree
Hide file tree
Showing 13 changed files with 434 additions and 260 deletions.
3 changes: 2 additions & 1 deletion python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import model
from . import initializer
from . import visualization
import atexit
# use viz as short for mx.ndarray
from . import visualization as viz

__version__ = "0.1.0"
41 changes: 41 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,44 @@ def ctypes2numpy_shared(cptr, shape):
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)


def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : mx_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args.value):
key = py_str(arg_names[i])
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
15 changes: 4 additions & 11 deletions python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .base import _LIB
from .base import c_array, c_str, mx_uint, py_str
from .base import DataIterHandle, NDArrayHandle
from .base import check_call
from .base import check_call, ctypes2docstring
from .ndarray import NDArray

class DataIter(object):
Expand Down Expand Up @@ -99,24 +99,17 @@ def _make_io_iterator(handle):
ctypes.byref(arg_types), \
ctypes.byref(arg_descs)))
iter_name = py_str(name.value)
param_str = []
for i in range(num_args.value):
ret = '%s : %s' % (arg_names[i], arg_types[i])
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)

doc_str = ('%s\n\n' +
'Parameters\n' +
'----------\n' +
'%s\n' +
'name : string, required.\n' +
' Name of the resulting data iterator.\n\n' +
'Returns\n' +
'-------\n' +
'iterator: Iterator\n'+
'iterator: DataIter\n'+
' The result iterator.')
doc_str = doc_str % (desc.value, '\n'.join(param_str))
doc_str = doc_str % (desc.value, param_str)

def creator(*args, **kwargs):
"""Create an iterator.
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(self):
def update(self, pred, label):
pred = pred.asnumpy()
label = label.asnumpy().astype('int32')
y = np.argmax(pred, axis=1)
self.sum_metric += np.sum(y == label)
py = np.argmax(pred, axis=1)
self.sum_metric += np.sum(py == label)
self.num_inst += label.size


Expand Down
Loading

0 comments on commit 3401ac5

Please sign in to comment.