Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit c6bcc0e

Browse files
reminisce
authored and haojin2 committed
[numpy] Infra for supporting numpy ops in imperative mode and Gluon APIs (#14758)
* Infra of new ndarray and symbol types for numpy operators * Rename * Fix import problem * Refactor * Remove redundant code * Add docstring * More on numpy ndarray and symbol * Override unimplemented methdos for ndarray and _NumpySymbol * Fix built-in methods of ndarray and _NumpySymbol * Fix test and sanity check * Fix pylint * Address cr comments * Add unit tests for ndarray and _NumpySymbol * Add _true_divide * Fix gpu build * Add future import division * More correct way of checking if an output is from a np compat op * Fix gpu build * Fix output ndarray/symbol types with at least one new ndarray/symbol * Modify true_divide doc * Fix flaky copying zero-size arrays via gpus * Fix zero size in gluon hybridize and zeros/ones symbol not creating new symbol type * Fix doc
1 parent 55a9acd commit c6bcc0e

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+3689
-59
lines changed

include/mxnet/c_api.h

+29
Original file line numberDiff line numberDiff line change
@@ -2902,6 +2902,35 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
29022902
EngineVarHandle mutable_vars_handle, int num_mutable_vars,
29032903
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
29042904
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
2905+
/*!
2906+
* \brief Determines if an op is a Numpy op by its name prefix.
2907+
* Every Numpy op starts with a prefix string "_numpy_".
2908+
* \param creator Operator handle
2909+
* \param is_np_op Indicator of whether creator is a numpy op handle
2910+
*/
2911+
MXNET_DLL int MXIsNumpyCompatOp(AtomicSymbolCreator creator,
2912+
int* is_np_op);
2913+
/*!
2914+
* \brief Create an NDArray from source sharing the same data chunk.
2915+
* \param src source NDArray
2916+
* \param out new NDArray sharing the same data chunck with src
2917+
*/
2918+
MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out);
2919+
/*!
2920+
* \brief Create an Symbol from source sharing the same graph structure.
2921+
* \param src source Symbol
2922+
* \param out new Symbol sharing the same graph structure with src
2923+
*/
2924+
MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out);
2925+
/*!
2926+
* \brief Checks if an output of CachedOp is from a numpy op.
2927+
* \param handle CachedOp shared ptr
2928+
* \param output_idx index of the output of the CachedOp
2929+
* \param is_from_np_op indicator of whether the output is from a numpy op
2930+
*/
2931+
MXNET_DLL int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle,
2932+
int output_idx,
2933+
int* is_from_np_op);
29052934

29062935
/*!
29072936
* \brief Push an asynchronous operation to the engine.

include/mxnet/op_attr_types.h

+9
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,15 @@ using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
319319
using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
320320
size_t index)>;
321321

322+
/*!
323+
* \brief Indicates whether this operator is NumPy compatible.
324+
* It is for distinguishing the operator from classic MXNet operators
325+
* which do not support zero-dim and zero-size tensors.
326+
* In Python, it is used to determine whether to output numpy ndarrays
327+
* or symbols that are NumPy compatible.
328+
*/
329+
using TIsNumpyCompatible = bool;
330+
322331
} // namespace mxnet
323332

324333
#endif // MXNET_OP_ATTR_TYPES_H_

python/mxnet/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
from .base import MXNetError
2727
from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
2828
from . import base
29-
from . import numpy
3029
from . import contrib
3130
from . import ndarray
3231
from . import ndarray as nd
32+
from . import numpy
3333
from . import name
3434
# use mx.sym as short for symbol
3535
from . import symbol as sym

python/mxnet/_ctypes/ndarray.py

+28-10
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..base import _LIB
2727
from ..base import c_str_array, c_handle_array
2828
from ..base import NDArrayHandle, CachedOpHandle
29-
from ..base import check_call
29+
from ..base import check_call, _is_np_compat_op
3030

3131

3232
class NDArrayBase(object):
@@ -55,13 +55,21 @@ def __reduce__(self):
5555

5656

5757
_ndarray_cls = None
58+
_np_ndarray_cls = None
59+
5860

5961
def _set_ndarray_class(cls):
6062
"""Set the symbolic class to be cls"""
6163
global _ndarray_cls
6264
_ndarray_cls = cls
6365

6466

67+
def _set_np_ndarray_class(cls):
68+
"""Set the symbolic class to be cls"""
69+
global _np_ndarray_cls
70+
_np_ndarray_cls = cls
71+
72+
6573
def _imperative_invoke(handle, ndargs, keys, vals, out):
6674
"""ctypes implementation of imperative invoke wrapper"""
6775
if out is not None:
@@ -93,18 +101,19 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
93101

94102
if original_output is not None:
95103
return original_output
104+
create_ndarray_fn = _np_ndarray_cls if _is_np_compat_op(handle) else _ndarray_cls
96105
if num_output.value == 1:
97-
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
98-
stype=out_stypes[0])
106+
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
107+
stype=out_stypes[0])
99108
else:
100-
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
101-
stype=out_stypes[i])
102-
for i in range(num_output.value)]
109+
return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
110+
stype=out_stypes[i]) for i in range(num_output.value)]
103111

104112

105113
class CachedOp(object):
106114
"""Cached operator handle."""
107115
__slots__ = ["handle"]
116+
108117
def __init__(self, sym, flags=()):
109118
self.handle = CachedOpHandle()
110119

@@ -118,6 +127,13 @@ def __init__(self, sym, flags=()):
118127
def __del__(self):
119128
check_call(_LIB.MXFreeCachedOp(self.handle))
120129

130+
def _is_from_np_compat_op(self, idx):
131+
"""Check if the CachedOp's idx-th output is directly from a numpy op."""
132+
is_from_np_op = ctypes.c_int(0)
133+
check_call(_LIB.MXIsCachedOpOutputFromNumpyCompatOp(self.handle, ctypes.c_int(idx),
134+
ctypes.byref(is_from_np_op)))
135+
return is_from_np_op.value != 0
136+
121137
def __call__(self, *args, **kwargs):
122138
"""ctypes implementation of imperative invoke wrapper"""
123139
out = kwargs.pop('out', None)
@@ -152,9 +168,11 @@ def __call__(self, *args, **kwargs):
152168
if original_output is not None:
153169
return original_output
154170
if num_output.value == 1:
155-
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
156-
stype=out_stypes[0])
171+
create_ndarray_fn = _np_ndarray_cls if self._is_from_np_compat_op(0) else _ndarray_cls
172+
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
173+
stype=out_stypes[0])
157174
else:
158-
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
159-
stype=out_stypes[i])
175+
return [_np_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
176+
if self._is_from_np_compat_op(i) else
177+
_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
160178
for i in range(num_output.value)]

python/mxnet/_ctypes/symbol.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222

2323
import ctypes
2424
from ..base import _LIB
25-
from ..base import c_str_array, c_handle_array, c_str, mx_uint
25+
from ..base import c_str_array, c_handle_array, c_str, mx_uint, _is_np_compat_op
2626
from ..base import SymbolHandle
2727
from ..base import check_call
2828

2929
_symbol_cls = None
30+
_np_symbol_cls = None
3031

3132
class SymbolBase(object):
3233
"""Symbol is symbolic graph."""
@@ -115,6 +116,12 @@ def _set_symbol_class(cls):
115116
_symbol_cls = cls
116117

117118

119+
def _set_np_symbol_class(cls):
120+
"""Set the symbolic class to be cls"""
121+
global _np_symbol_cls
122+
_np_symbol_cls = cls
123+
124+
118125
def _symbol_creator(handle, args, kwargs, keys, vals, name):
119126
sym_handle = SymbolHandle()
120127
check_call(_LIB.MXSymbolCreateAtomicSymbol(
@@ -128,7 +135,10 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name):
128135
raise TypeError(
129136
'Operators with variable length input can only accept input'
130137
'Symbols either as positional or keyword arguments, not both')
131-
s = _symbol_cls(sym_handle)
138+
if _is_np_compat_op(handle):
139+
s = _np_symbol_cls(sym_handle)
140+
else:
141+
s = _symbol_cls(sym_handle)
132142
if args:
133143
s._compose(*args, name=name)
134144
elif kwargs:

python/mxnet/base.py

+82-20
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def _as_list(obj):
561561
return [obj]
562562

563563

564-
_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']
564+
_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_']
565565

566566

567567
def _get_op_name_prefix(op_name):
@@ -607,15 +607,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
607607
# use mx.nd.contrib or mx.sym.contrib from now on
608608
contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
609609
contrib_module_old = sys.modules[contrib_module_name_old]
610-
# special handling of registering numpy ops
611-
# only expose mxnet.numpy.op_name to users for imperative mode.
612-
# Symbolic mode should be used in Gluon.
613-
if module_name == 'ndarray':
614-
numpy_module_name = "%s.numpy" % root_namespace
615-
numpy_module = sys.modules[numpy_module_name]
616-
else:
617-
numpy_module_name = None
618-
numpy_module = None
619610
submodule_dict = {}
620611
for op_name_prefix in _OP_NAME_PREFIX_LIST:
621612
submodule_dict[op_name_prefix] =\
@@ -654,16 +645,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
654645
function.__module__ = contrib_module_name_old
655646
setattr(contrib_module_old, function.__name__, function)
656647
contrib_module_old.__all__.append(function.__name__)
657-
elif op_name_prefix == '_numpy_' and numpy_module_name is not None:
658-
# only register numpy ops under mxnet.numpy in imperative mode
659-
hdl = OpHandle()
660-
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
661-
# TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
662-
func_name = name[len(op_name_prefix):]
663-
function = make_op_func(hdl, name, func_name)
664-
function.__module__ = numpy_module_name
665-
setattr(numpy_module, function.__name__, function)
666-
numpy_module.__all__.append(function.__name__)
667648

668649

669650
def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
@@ -754,7 +735,88 @@ def write_all_str(module_file, module_all_list):
754735
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
755736
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
756737

738+
757739
from .runtime import Features
758740
if Features().is_enabled("TVM_OP"):
759741
_LIB_TVM_OP = libinfo.find_lib_path("libtvmop")
760742
check_call(_LIB.MXLoadTVMOp(c_str(_LIB_TVM_OP[0])))
743+
744+
745+
def _sanity_check_params(func_name, unsupported_params, param_dict):
746+
for param_name in unsupported_params:
747+
if param_name in param_dict:
748+
raise NotImplementedError("function {} does not support parameter {}"
749+
.format(func_name, param_name))
750+
751+
752+
_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
753+
_NP_OP_PREFIX = '_numpy_'
754+
755+
756+
def _get_np_op_submodule_name(op_name):
757+
assert op_name.startswith(_NP_OP_PREFIX)
758+
for name in _NP_OP_SUBMODULE_LIST:
759+
if op_name[len(_NP_OP_PREFIX):].startswith(name):
760+
return name
761+
return ""
762+
763+
764+
def _init_np_op_module(root_namespace, module_name, make_op_func):
765+
"""
766+
Register numpy operators in namespaces `mxnet.numpy`, `mxnet.ndarray.numpy`
767+
and `mxnet.symbol.numpy`. They are used in imperative mode, Gluon APIs w/o hybridization,
768+
and Gluon APIs w/ hybridization, respectively. Essentially, operators with the same name
769+
registered in three namespaces, respectively share the same functionality in C++ backend.
770+
Different namespaces are needed for dispatching operator calls in Gluon's `HybridBlock` by `F`.
771+
772+
Parameters
773+
----------
774+
root_namespace : str
775+
Top level module name, `mxnet` in the current cases.
776+
module_name : str
777+
Second level module name, `ndarray` or `symbol` in the current case.
778+
make_op_func : function
779+
Function for creating op functions.
780+
"""
781+
plist = ctypes.POINTER(ctypes.c_char_p)()
782+
size = ctypes.c_uint()
783+
784+
check_call(_LIB.MXListAllOpNames(ctypes.byref(size), ctypes.byref(plist)))
785+
op_names = []
786+
for i in range(size.value):
787+
name = py_str(plist[i])
788+
if name.startswith(_NP_OP_PREFIX):
789+
op_names.append(name)
790+
791+
if module_name == 'numpy':
792+
# register ops for mxnet.numpy
793+
module_pattern = "%s.%s._op"
794+
submodule_pattern = "%s.%s.%s"
795+
else:
796+
# register ops for mxnet.ndarray.numpy or mxnet.symbol.numpy
797+
module_pattern = "%s.%s.numpy._op"
798+
submodule_pattern = "%s.%s.numpy.%s"
799+
module_np_op = sys.modules[module_pattern % (root_namespace, module_name)]
800+
submodule_dict = {}
801+
# TODO(junwu): uncomment the following lines when adding numpy ops in submodules, e.g. np.random
802+
# for submodule_name in _NP_OP_SUBMODULE_LIST:
803+
# submodule_dict[submodule_name] = \
804+
# sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
805+
for name in op_names:
806+
hdl = OpHandle()
807+
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
808+
submodule_name = _get_np_op_submodule_name(name)
809+
module_name_local = module_name
810+
if len(submodule_name) > 0:
811+
func_name = name[(len(_NP_OP_PREFIX) + len(submodule_name)):]
812+
cur_module = submodule_dict[submodule_name]
813+
module_name_local = submodule_pattern % (root_namespace,
814+
module_name, submodule_name[1:-1])
815+
else:
816+
func_name = name[len(_NP_OP_PREFIX):]
817+
cur_module = module_np_op
818+
819+
function = make_op_func(hdl, name, func_name)
820+
function.__module__ = module_name_local
821+
setattr(cur_module, function.__name__, function)
822+
cur_module.__all__.append(function.__name__)

python/mxnet/gluon/block.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from .. import name as _name
3535
from .parameter import Parameter, ParameterDict, DeferredInitializationError
3636
from .utils import _indent, _brief_print_list, HookHandle
37+
from .. import numpy as _mx_np
3738

3839

3940
class _BlockScope(object):
@@ -739,9 +740,13 @@ def _get_graph(self, *args):
739740
if not self._cached_graph:
740741
args, self._in_format = _flatten(args, "input")
741742
if len(args) > 1:
742-
inputs = [symbol.var('data%d'%i) for i in range(len(args))]
743+
inputs = [symbol.var('data%d' % i).as_np_ndarray()
744+
if isinstance(args[i], _mx_np.ndarray)
745+
else symbol.var('data%d' % i) for i in range(len(args))]
743746
else:
744-
inputs = [symbol.var('data')]
747+
inputs = [symbol.var('data').as_np_ndarray()
748+
if isinstance(args[0], _mx_np.ndarray)
749+
else symbol.var('data')]
745750
grouped_inputs = _regroup(inputs, self._in_format)[0]
746751

747752
params = {i: j.var() for i, j in self._reg_params.items()}

python/mxnet/ndarray/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .utils import load, load_frombuffer, save, zeros, empty, array
3131
from .sparse import _ndarray_cls
3232
from .ndarray import _GRAD_REQ_MAP, _DTYPE_MX_TO_NP, _DTYPE_NP_TO_MX, _new_empty_handle
33+
from . import numpy as np
3334

3435
__all__ = op.__all__ + ndarray.__all__ + utils.__all__ + \
3536
['contrib', 'linalg', 'random', 'sparse', 'image']

python/mxnet/ndarray/_internal.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,24 @@
2323
try:
2424
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
2525
from .._ctypes.ndarray import NDArrayBase, CachedOp
26-
from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke
26+
from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
2727
elif _sys.version_info >= (3, 0):
2828
from .._cy3.ndarray import NDArrayBase, CachedOp
29-
from .._cy3.ndarray import _set_ndarray_class, _imperative_invoke
29+
from .._cy3.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
3030
else:
3131
from .._cy2.ndarray import NDArrayBase, CachedOp
32-
from .._cy2.ndarray import _set_ndarray_class, _imperative_invoke
32+
from .._cy2.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
3333
except ImportError:
3434
if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
3535
raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
3636
from .._ctypes.ndarray import NDArrayBase, CachedOp
37-
from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke
37+
from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
3838

3939
from ..base import _Null
4040
try:
4141
from .gen__internal import * # pylint: disable=unused-wildcard-import
4242
except ImportError:
4343
pass
4444

45-
__all__ = ['NDArrayBase', 'CachedOp', '_imperative_invoke', '_set_ndarray_class']
45+
__all__ = ['NDArrayBase', 'CachedOp', '_imperative_invoke', '_set_ndarray_class',
46+
'_set_np_ndarray_class']

0 commit comments

Comments
 (0)