@@ -561,7 +561,7 @@ def _as_list(obj):
561
561
return [obj ]
562
562
563
563
564
- _OP_NAME_PREFIX_LIST = ['_contrib_' , '_linalg_' , '_sparse_' , '_image_' , '_random_' , '_numpy_' ]
564
+ _OP_NAME_PREFIX_LIST = ['_contrib_' , '_linalg_' , '_sparse_' , '_image_' , '_random_' ]
565
565
566
566
567
567
def _get_op_name_prefix (op_name ):
@@ -607,15 +607,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
607
607
# use mx.nd.contrib or mx.sym.contrib from now on
608
608
contrib_module_name_old = "%s.contrib.%s" % (root_namespace , module_name )
609
609
contrib_module_old = sys .modules [contrib_module_name_old ]
610
- # special handling of registering numpy ops
611
- # only expose mxnet.numpy.op_name to users for imperative mode.
612
- # Symbolic mode should be used in Gluon.
613
- if module_name == 'ndarray' :
614
- numpy_module_name = "%s.numpy" % root_namespace
615
- numpy_module = sys .modules [numpy_module_name ]
616
- else :
617
- numpy_module_name = None
618
- numpy_module = None
619
610
submodule_dict = {}
620
611
for op_name_prefix in _OP_NAME_PREFIX_LIST :
621
612
submodule_dict [op_name_prefix ] = \
@@ -654,16 +645,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
654
645
function .__module__ = contrib_module_name_old
655
646
setattr (contrib_module_old , function .__name__ , function )
656
647
contrib_module_old .__all__ .append (function .__name__ )
657
- elif op_name_prefix == '_numpy_' and numpy_module_name is not None :
658
- # only register numpy ops under mxnet.numpy in imperative mode
659
- hdl = OpHandle ()
660
- check_call (_LIB .NNGetOpHandle (c_str (name ), ctypes .byref (hdl )))
661
- # TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
662
- func_name = name [len (op_name_prefix ):]
663
- function = make_op_func (hdl , name , func_name )
664
- function .__module__ = numpy_module_name
665
- setattr (numpy_module , function .__name__ , function )
666
- numpy_module .__all__ .append (function .__name__ )
667
648
668
649
669
650
def _generate_op_module_signature (root_namespace , module_name , op_code_gen_func ):
@@ -754,7 +735,88 @@ def write_all_str(module_file, module_all_list):
754
735
ctypes .pythonapi .PyCapsule_New .restype = ctypes .py_object
755
736
ctypes .pythonapi .PyCapsule_GetPointer .restype = ctypes .c_void_p
756
737
738
+
757
739
from .runtime import Features
758
740
if Features ().is_enabled ("TVM_OP" ):
759
741
_LIB_TVM_OP = libinfo .find_lib_path ("libtvmop" )
760
742
check_call (_LIB .MXLoadTVMOp (c_str (_LIB_TVM_OP [0 ])))
743
+
744
+
745
+ def _sanity_check_params (func_name , unsupported_params , param_dict ):
746
+ for param_name in unsupported_params :
747
+ if param_name in param_dict :
748
+ raise NotImplementedError ("function {} does not support parameter {}"
749
+ .format (func_name , param_name ))
750
+
751
+
752
+ _NP_OP_SUBMODULE_LIST = ['_random_' , '_linalg_' ]
753
+ _NP_OP_PREFIX = '_numpy_'
754
+
755
+
756
+ def _get_np_op_submodule_name (op_name ):
757
+ assert op_name .startswith (_NP_OP_PREFIX )
758
+ for name in _NP_OP_SUBMODULE_LIST :
759
+ if op_name [len (_NP_OP_PREFIX ):].startswith (name ):
760
+ return name
761
+ return ""
762
+
763
+
764
+ def _init_np_op_module (root_namespace , module_name , make_op_func ):
765
+ """
766
+ Register numpy operators in namespaces `mxnet.numpy`, `mxnet.ndarray.numpy`
767
+ and `mxnet.symbol.numpy`. They are used in imperative mode, Gluon APIs w/o hybridization,
768
+ and Gluon APIs w/ hybridization, respectively. Essentially, operators with the same name
769
+ registered in three namespaces, respectively share the same functionality in C++ backend.
770
+ Different namespaces are needed for dispatching operator calls in Gluon's `HybridBlock` by `F`.
771
+
772
+ Parameters
773
+ ----------
774
+ root_namespace : str
775
+ Top level module name, `mxnet` in the current cases.
776
+ module_name : str
777
+ Second level module name, `ndarray` or `symbol` in the current case.
778
+ make_op_func : function
779
+ Function for creating op functions.
780
+ """
781
+ plist = ctypes .POINTER (ctypes .c_char_p )()
782
+ size = ctypes .c_uint ()
783
+
784
+ check_call (_LIB .MXListAllOpNames (ctypes .byref (size ), ctypes .byref (plist )))
785
+ op_names = []
786
+ for i in range (size .value ):
787
+ name = py_str (plist [i ])
788
+ if name .startswith (_NP_OP_PREFIX ):
789
+ op_names .append (name )
790
+
791
+ if module_name == 'numpy' :
792
+ # register ops for mxnet.numpy
793
+ module_pattern = "%s.%s._op"
794
+ submodule_pattern = "%s.%s.%s"
795
+ else :
796
+ # register ops for mxnet.ndarray.numpy or mxnet.symbol.numpy
797
+ module_pattern = "%s.%s.numpy._op"
798
+ submodule_pattern = "%s.%s.numpy.%s"
799
+ module_np_op = sys .modules [module_pattern % (root_namespace , module_name )]
800
+ submodule_dict = {}
801
+ # TODO(junwu): uncomment the following lines when adding numpy ops in submodules, e.g. np.random
802
+ # for submodule_name in _NP_OP_SUBMODULE_LIST:
803
+ # submodule_dict[submodule_name] = \
804
+ # sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
805
+ for name in op_names :
806
+ hdl = OpHandle ()
807
+ check_call (_LIB .NNGetOpHandle (c_str (name ), ctypes .byref (hdl )))
808
+ submodule_name = _get_np_op_submodule_name (name )
809
+ module_name_local = module_name
810
+ if len (submodule_name ) > 0 :
811
+ func_name = name [(len (_NP_OP_PREFIX ) + len (submodule_name )):]
812
+ cur_module = submodule_dict [submodule_name ]
813
+ module_name_local = submodule_pattern % (root_namespace ,
814
+ module_name , submodule_name [1 :- 1 ])
815
+ else :
816
+ func_name = name [len (_NP_OP_PREFIX ):]
817
+ cur_module = module_np_op
818
+
819
+ function = make_op_func (hdl , name , func_name )
820
+ function .__module__ = module_name_local
821
+ setattr (cur_module , function .__name__ , function )
822
+ cur_module .__all__ .append (function .__name__ )
0 commit comments