diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 9f0939ec7f37..e30a26dbca2e 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -38,8 +38,6 @@
 
 from . import sampler as _sampler
 from ... import nd, context
-from ...util import is_np_shape, is_np_array, set_np
-from ... import numpy as _mx_np  # pylint: disable=reimported
 
 if sys.platform == 'darwin' or sys.platform == 'win32':
     def rebuild_ndarray(*args):
@@ -131,33 +129,27 @@ def __init__(self, *args, **kwargs):
 def default_batchify_fn(data):
     """Collate data into batch."""
     if isinstance(data[0], nd.NDArray):
-        return _mx_np.stack(data) if is_np_array() else nd.stack(*data)
+        return nd.stack(*data)
     elif isinstance(data[0], tuple):
         data = zip(*data)
         return [default_batchify_fn(i) for i in data]
     else:
         data = np.asarray(data)
-        array_fn = _mx_np.array if is_np_array() else nd.array
-        return array_fn(data, dtype=data.dtype)
+        return nd.array(data, dtype=data.dtype)
 
 
 def default_mp_batchify_fn(data):
     """Collate data into batch. Use shared memory for stacking."""
     if isinstance(data[0], nd.NDArray):
-        empty_fn = _mx_np.empty if is_np_array() else nd.empty
-        out = empty_fn((len(data),) + data[0].shape, dtype=data[0].dtype,
+        out = nd.empty((len(data),) + data[0].shape, dtype=data[0].dtype,
                        ctx=context.Context('cpu_shared', 0))
-        if is_np_array():
-            return _mx_np.stack(data, out=out)
-        else:
-            return nd.stack(*data, out=out)
+        return nd.stack(*data, out=out)
     elif isinstance(data[0], tuple):
         data = zip(*data)
         return [default_mp_batchify_fn(i) for i in data]
     else:
         data = np.asarray(data)
-        array_fn = _mx_np.array if is_np_array() else nd.array
-        return array_fn(data, dtype=data.dtype,
+        return nd.array(data, dtype=data.dtype,
                         ctx=context.Context('cpu_shared', 0))
 
 
@@ -393,20 +385,14 @@ def __len__(self):
         return len(self._batch_sampler)
 
 
-def _thread_worker_initializer(active_shape, active_array):
-    """Initializer for ThreadPool."""
-    set_np(shape=active_shape, array=active_array)
-
-
 _worker_dataset = None
-def _worker_initializer(dataset, active_shape, active_array):
+def _worker_initializer(dataset):
     """Initialier for processing pool."""
     # global dataset is per-process based and only available in worker processes
     # this is only necessary to handle MXIndexedRecordIO because otherwise dataset
     # can be passed as argument
     global _worker_dataset
     _worker_dataset = dataset
-    set_np(shape=active_shape, array=active_array)
 
 def _worker_fn(samples, batchify_fn, dataset=None):
     """Function for processing data in worker process."""
@@ -573,13 +559,10 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
         self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
         if self._num_workers > 0:
             if self._thread_pool:
-                self._worker_pool = ThreadPool(self._num_workers,
-                                               initializer=_thread_worker_initializer,
-                                               initargs=(is_np_shape(), is_np_array()))
+                self._worker_pool = ThreadPool(self._num_workers)
             else:
                 self._worker_pool = multiprocessing.Pool(
-                    self._num_workers, initializer=_worker_initializer,
-                    initargs=[self._dataset, is_np_shape(), is_np_array()])
+                    self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
         if batchify_fn is None:
             if num_workers > 0:
                 self._batchify_fn = default_mp_batchify_fn
diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py
index bdcaff52a042..12ef7e16ef49 100644
--- a/python/mxnet/gluon/data/vision/datasets.py
+++ b/python/mxnet/gluon/data/vision/datasets.py
@@ -31,8 +31,6 @@
 from .. import dataset
 from ...utils import download, check_sha1, _get_repo_file_url
 from .... import nd, image, recordio, base
-from .... import numpy as _mx_np  # pylint: disable=reimported
-from ....util import is_np_array
 
 
 class MNIST(dataset._DownloadedDataset):
@@ -83,16 +81,13 @@ def _get_data(self):
         with gzip.open(label_file, 'rb') as fin:
             struct.unpack(">II", fin.read(8))
             label = np.frombuffer(fin.read(), dtype=np.uint8).astype(np.int32)
-            if is_np_array():
-                label = _mx_np.array(label, dtype=label.dtype)
 
         with gzip.open(data_file, 'rb') as fin:
             struct.unpack(">IIII", fin.read(16))
             data = np.frombuffer(fin.read(), dtype=np.uint8)
             data = data.reshape(len(label), 28, 28, 1)
 
-        array_fn = _mx_np.array if is_np_array() else nd.array
-        self._data = array_fn(data, dtype=data.dtype)
+        self._data = nd.array(data, dtype=data.dtype)
         self._label = label
 
 
@@ -188,9 +183,8 @@ def _get_data(self):
         data = np.concatenate(data)
         label = np.concatenate(label)
 
-        array_fn = _mx_np.array if is_np_array() else nd.array
-        self._data = array_fn(data, dtype=data.dtype)
-        self._label = array_fn(label, dtype=label.dtype) if is_np_array() else label
+        self._data = nd.array(data, dtype=data.dtype)
+        self._label = label
 
 
 class CIFAR100(CIFAR10):
diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index ab8f8ab482df..955f2b2e4a66 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -23,7 +23,6 @@
 from ...nn import Sequential, HybridSequential
 from .... import image
 from ....base import numeric_types
-from ....util import is_np_array
 
 
 class Compose(Sequential):
@@ -93,8 +92,6 @@ def __init__(self, dtype='float32'):
         self._dtype = dtype
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.cast(x, self._dtype)
 
 
@@ -137,8 +134,6 @@ def __init__(self):
         super(ToTensor, self).__init__()
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.to_tensor(x)
 
 
@@ -192,8 +187,6 @@ def __init__(self, mean=0.0, std=1.0):
         self._std = std
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.normalize(x, self._mean, self._std)
 
 
@@ -376,8 +369,6 @@ def __init__(self, size, keep_ratio=False, interpolation=1):
         self._interpolation = interpolation
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.resize(x, self._size, self._keep, self._interpolation)
 
 class RandomFlipLeftRight(HybridBlock):
@@ -394,8 +385,6 @@ def __init__(self):
         super(RandomFlipLeftRight, self).__init__()
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.random_flip_left_right(x)
 
 
@@ -413,8 +402,6 @@ def __init__(self):
         super(RandomFlipTopBottom, self).__init__()
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.random_flip_top_bottom(x)
 
 
@@ -440,8 +427,6 @@ def __init__(self, brightness):
         self._args = (max(0, 1-brightness), 1+brightness)
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.random_brightness(x, *self._args)
 
 
@@ -467,8 +452,6 @@ def __init__(self, contrast):
         self._args = (max(0, 1-contrast), 1+contrast)
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.random_contrast(x, *self._args)
 
 
@@ -494,8 +477,6 @@ def __init__(self, saturation):
         self._args = (max(0, 1-saturation), 1+saturation)
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.random_saturation(x, *self._args)
 
 
@@ -521,8 +502,6 @@ def __init__(self, hue):
         self._args = (max(0, 1-hue), 1+hue)
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.random_hue(x, *self._args)
 
 
@@ -557,8 +536,6 @@ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
         self._args = (brightness, contrast, saturation, hue)
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.random_color_jitter(x, *self._args)
 
 
@@ -582,6 +559,4 @@ def __init__(self, alpha):
         self._alpha = alpha
 
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            F = F.npx
         return F.image.random_lighting(x, self._alpha)
diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index d2e2344e253f..e6d4c5bab852 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -29,7 +29,6 @@
 from .. import ndarray
 from ..base import numeric_types
 from .block import HybridBlock
-from ..util import is_np_array
 
 
 def _apply_weighting(F, loss, weight=None, sample_weight=None):
@@ -54,10 +53,7 @@ def _apply_weighting(F, loss, weight=None, sample_weight=None):
         Weighted loss
     """
     if sample_weight is not None:
-        if is_np_array():
-            loss = loss * sample_weight
-        else:
-            loss = F.broadcast_mul(loss, sample_weight)
+        loss = F.broadcast_mul(loss, sample_weight)
 
     if weight is not None:
         assert isinstance(weight, numeric_types), "weight must be a number"
@@ -68,11 +64,7 @@ def _apply_weighting(F, loss, weight=None, sample_weight=None):
 
 def _reshape_like(F, x, y):
     """Reshapes x to the same shape as y."""
-    if F is ndarray:
-        return x.reshape(y.shape)
-    elif is_np_array():
-        F = F.npx
-    return F.reshape_like(x, y)
+    return x.reshape(y.shape) if F is ndarray else F.reshape_like(x, y)
 
 
 class Loss(HybridBlock):
@@ -144,15 +136,9 @@ def __init__(self, weight=1., batch_axis=0, **kwargs):
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
         label = _reshape_like(F, label, pred)
-        loss = F.np.square(label - pred) if is_np_array() else F.square(label - pred)
+        loss = F.square(label - pred)
         loss = _apply_weighting(F, loss, self._weight / 2, sample_weight)
-        if is_np_array():
-            if F is ndarray:
-                return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
-            else:
-                return F.npx.batch_flatten(loss).mean(axis=1)
-        else:
-            return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
 
 
 class L1Loss(Loss):
@@ -258,45 +244,27 @@ def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs):
 
     def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
         label = _reshape_like(F, label, pred)
-        if is_np_array():
-            relu_fn = F.npx.relu
-            act_fn = F.npx.activation
-            abs_fn = F.np.abs
-            mul_fn = F.np.multiply
-            log_fn = F.np.log
-        else:
-            relu_fn = F.relu
-            act_fn = F.Activation
-            abs_fn = F.abs
-            mul_fn = F.broadcast_mul
-            log_fn = F.log
         if not self._from_sigmoid:
             if pos_weight is None:
                 # We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x)))
-                loss = relu_fn(pred) - pred * label + \
-                    act_fn(-abs_fn(pred), act_type='softrelu')
+                loss = F.relu(pred) - pred * label + \
+                    F.Activation(-F.abs(pred), act_type='softrelu')
             else:
                 # We use the stable formula: x - x * z + (1 + z * pos_weight - z) * \
                 #    (log(1 + exp(-abs(x))) + max(-x, 0))
-                log_weight = 1 + mul_fn(pos_weight - 1, label)
-                loss = pred - pred * label + log_weight *\
-                    (act_fn(-abs_fn(pred), act_type='softrelu') + relu_fn(-pred))
+                log_weight = 1 + F.broadcast_mul(pos_weight - 1, label)
+                loss = pred - pred * label + log_weight * \
+                       (F.Activation(-F.abs(pred), act_type='softrelu') + F.relu(-pred))
         else:
             eps = 1e-12
             if pos_weight is None:
-                loss = -(log_fn(pred + eps) * label
-                         + log_fn(1. - pred + eps) * (1. - label))
+                loss = -(F.log(pred + eps) * label
+                         + F.log(1. - pred + eps) * (1. - label))
             else:
-                loss = -(mul_fn(log_fn(pred + eps) * label, pos_weight)
-                         + log_fn(1. - pred + eps) * (1. - label))
+                loss = -(F.broadcast_mul(F.log(pred + eps) * label, pos_weight)
+                         + F.log(1. - pred + eps) * (1. - label))
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        if is_np_array():
-            if F is ndarray:
-                return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
-            else:
-                return F.npx.batch_flatten(loss).mean(axis=1)
-        else:
-            return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
 
 
 SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss
@@ -373,27 +341,15 @@ def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
         self._from_logits = from_logits
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
-        if is_np_array():
-            log_softmax = F.npx.log_softmax
-            pick = F.npx.pick
-        else:
-            log_softmax = F.log_softmax
-            pick = F.pick
         if not self._from_logits:
-            pred = log_softmax(pred, self._axis)
+            pred = F.log_softmax(pred, self._axis)
         if self._sparse_label:
-            loss = -pick(pred, label, axis=self._axis, keepdims=True)
+            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
         else:
             label = _reshape_like(F, label, pred)
-            loss = -(pred * label).sum(axis=self._axis, keepdims=True)
+            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        if is_np_array():
-            if F is ndarray:
-                return loss.mean(axis=tuple(range(1, loss.ndim)))
-            else:
-                return F.npx.batch_flatten(loss).mean(axis=1)
-        else:
-            return loss.mean(axis=self._batch_axis, exclude=True)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
 
 
 SoftmaxCELoss = SoftmaxCrossEntropyLoss
diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py b/python/mxnet/gluon/model_zoo/vision/resnet.py
index 50a65ec8d2da..48390decb11b 100644
--- a/python/mxnet/gluon/model_zoo/vision/resnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/resnet.py
@@ -33,7 +33,6 @@
 from ...block import HybridBlock
 from ... import nn
 from .... import base
-from ....util import is_np_array
 
 # Helpers
 def _conv3x3(channels, stride, in_channels):
@@ -82,8 +81,7 @@ def hybrid_forward(self, F, x):
         if self.downsample:
             residual = self.downsample(residual)
 
-        act = F.npx.activation if is_np_array() else F.Activation
-        x = act(residual+x, act_type='relu')
+        x = F.Activation(residual+x, act_type='relu')
 
         return x
 
@@ -131,8 +129,7 @@ def hybrid_forward(self, F, x):
         if self.downsample:
             residual = self.downsample(residual)
 
-        act = F.npx.activation if is_np_array() else F.Activation
-        x = act(x + residual, act_type='relu')
+        x = F.Activation(x + residual, act_type='relu')
         return x
 
 
@@ -168,14 +165,13 @@ def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
     def hybrid_forward(self, F, x):
         residual = x
         x = self.bn1(x)
-        act = F.npx.activation if is_np_array() else F.Activation
-        x = act(x, act_type='relu')
+        x = F.Activation(x, act_type='relu')
         if self.downsample:
             residual = self.downsample(x)
         x = self.conv1(x)
 
         x = self.bn2(x)
-        x = act(x, act_type='relu')
+        x = F.Activation(x, act_type='relu')
         x = self.conv2(x)
 
         return x + residual
@@ -215,18 +211,17 @@ def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
     def hybrid_forward(self, F, x):
         residual = x
         x = self.bn1(x)
-        act = F.npx.activation if is_np_array() else F.Activation
-        x = act(x, act_type='relu')
+        x = F.Activation(x, act_type='relu')
         if self.downsample:
             residual = self.downsample(x)
         x = self.conv1(x)
 
         x = self.bn2(x)
-        x = act(x, act_type='relu')
+        x = F.Activation(x, act_type='relu')
         x = self.conv2(x)
 
         x = self.bn3(x)
-        x = act(x, act_type='relu')
+        x = F.Activation(x, act_type='relu')
         x = self.conv3(x)
 
         return x + residual
diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index a3baae004311..8c51b0a52592 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -22,7 +22,6 @@
 
 from ... import initializer
 from ..block import HybridBlock
-from ...util import is_np_array
 
 
 class Activation(HybridBlock):
@@ -49,8 +48,7 @@ def _alias(self):
         return self._act_type
 
     def hybrid_forward(self, F, x):
-        act = F.npx.activation if is_np_array() else F.Activation
-        return act(x, act_type=self._act_type, name='fwd')
+        return F.Activation(x, act_type=self._act_type, name='fwd')
 
     def __repr__(self):
         s = '{name}({_act_type})'
@@ -90,8 +88,7 @@ def __init__(self, alpha, **kwargs):
         self._alpha = alpha
 
    def hybrid_forward(self, F, x):
-        leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU
-        return leaky_relu(x, act_type='leaky', slope=self._alpha, name='fwd')
+        return F.LeakyReLU(x, act_type='leaky', slope=self._alpha, name='fwd')
 
     def __repr__(self):
         s = '{name}({alpha})'
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index fb0b62e8a74f..b1482ce6dd82 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -28,7 +28,6 @@
 from ..block import Block, HybridBlock
 from ..utils import _indent
 from ... import nd, sym
-from ...util import is_np_array
 
 
 class Sequential(Block):
@@ -219,9 +218,8 @@ def __init__(self, units, activation=None, use_bias=True, flatten=True,
                 self.act = None
 
     def hybrid_forward(self, F, x, weight, bias=None):
-        fc = F.npx.fully_connected if is_np_array() else F.FullyConnected
-        act = fc(x, weight, bias, no_bias=bias is None, num_hidden=self._units,
-                 flatten=self._flatten, name='fwd')
+        act = F.FullyConnected(x, weight, bias, no_bias=bias is None, num_hidden=self._units,
+                               flatten=self._flatten, name='fwd')
         if self.act is not None:
             act = self.act(act)
         return act
@@ -266,11 +264,9 @@ def __init__(self, rate, axes=(), **kwargs):
 
     def hybrid_forward(self, F, x):
         if self._rate > 0:
-            dropout = F.npx.dropout if is_np_array() else F.Dropout
-            return dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
+            return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
         else:
-            copy = F.np.copy if is_np_array() else F.identity
-            return copy(x)
+            return F.identity(x)
 
     def __repr__(self):
         s = '{name}(p = {_rate}, axes={_axes})'
@@ -361,9 +357,8 @@ def cast(self, dtype):
         super(BatchNorm, self).cast(dtype)
 
     def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
-        batch_norm = F.npx.batch_norm if is_np_array() else F.BatchNorm
-        return batch_norm(x, gamma, beta, running_mean, running_var,
-                          name='fwd', **self._kwargs)
+        return F.BatchNorm(x, gamma, beta, running_mean, running_var,
+                           name='fwd', **self._kwargs)
 
     def __repr__(self):
         s = '{name}({content}'
@@ -415,8 +410,7 @@ def __init__(self, input_dim, output_dim, dtype='float32',
                                       allow_deferred_init=True, grad_stype=grad_stype)
 
     def hybrid_forward(self, F, x, weight):
-        embedding = F.npx.embedding if is_np_array() else F.Embedding
-        return embedding(x, weight, name='fwd', **self._kwargs)
+        return F.Embedding(x, weight, name='fwd', **self._kwargs)
 
     def __repr__(self):
         s = '{block_name}({input_dim} -> {output_dim}, {dtype})'
@@ -437,8 +431,7 @@ def __init__(self, **kwargs):
         super(Flatten, self).__init__(**kwargs)
 
     def hybrid_forward(self, F, x):
-        flatten = F.npx.batch_flatten if is_np_array() else F.flatten
-        return flatten(x)
+        return F.Flatten(x)
 
     def __repr__(self):
         return self.__class__.__name__
@@ -611,8 +604,8 @@ def __init__(self, axis=-1, epsilon=1e-5, center=True, scale=True,
                                     allow_deferred_init=True)
 
     def hybrid_forward(self, F, data, gamma, beta):
-        layer_norm = F.npx.layer_norm if is_np_array() else F.LayerNorm
-        return layer_norm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
+        norm_data = F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
+        return norm_data
 
     def __repr__(self):
         s = '{name}({content}'
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 4682684662cd..9a9d96a48335 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -30,17 +30,11 @@
 from ... import symbol
 from ...base import numeric_types
 from .activations import Activation
-from ...util import is_np_array
 
 
 def _infer_weight_shape(op_name, data_shape, kwargs):
-    data = symbol.var('data', shape=data_shape)
-    if is_np_array():
-        op = getattr(symbol.npx, op_name)
-        data = data.as_np_ndarray()
-    else:
-        op = getattr(symbol, op_name)
-    sym = op(data, **kwargs)
+    op = getattr(symbol, op_name)
+    sym = op(symbol.var('data', shape=data_shape), **kwargs)
     return sym.infer_shape_partial()[0]
 
 
@@ -115,10 +109,7 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
             if adj is not None:
                 self._kwargs['adj'] = adj
 
-            if is_np_array():
-                dshape = [-1]*(len(kernel_size) + 2)
-            else:
-                dshape = [0]*(len(kernel_size) + 2)
+            dshape = [0]*(len(kernel_size) + 2)
             dshape[layout.find('N')] = 1
             dshape[layout.find('C')] = in_channels
 
@@ -139,8 +130,6 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
             self.act = None
 
     def hybrid_forward(self, F, x, weight, bias=None):
-        if is_np_array():
-            F = F.npx
         if bias is None:
             act = getattr(F, self._op_name)(x, weight, name='fwd', **self._kwargs)
         else:
@@ -247,13 +236,9 @@ def __init__(self, channels, kernel_size, strides=1, padding=0, dilation=1,
         if isinstance(kernel_size, numeric_types):
             kernel_size = (kernel_size,)
         assert len(kernel_size) == 1, "kernel_size must be a number or a list of 1 ints"
-        op_name = kwargs.pop('op_name', 'Convolution')
-        if is_np_array():
-            op_name = 'convolution'
         super(Conv1D, self).__init__(
             channels, kernel_size, strides, padding, dilation, groups, layout,
-            in_channels, activation, use_bias, weight_initializer, bias_initializer,
-            op_name, **kwargs)
+            in_channels, activation, use_bias, weight_initializer, bias_initializer, **kwargs)
 
 
 class Conv2D(_Conv):
@@ -331,13 +316,9 @@ def __init__(self, channels, kernel_size, strides=(1, 1), padding=(0, 0),
         if isinstance(kernel_size, numeric_types):
             kernel_size = (kernel_size,)*2
         assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints"
-        op_name = kwargs.pop('op_name', 'Convolution')
-        if is_np_array():
-            op_name = 'convolution'
         super(Conv2D, self).__init__(
             channels, kernel_size, strides, padding, dilation, groups, layout,
-            in_channels, activation, use_bias, weight_initializer, bias_initializer,
-            op_name, **kwargs)
+            in_channels, activation, use_bias, weight_initializer, bias_initializer, **kwargs)
 
 
 class Conv3D(_Conv):
@@ -416,13 +397,9 @@ def __init__(self, channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0),
         if isinstance(kernel_size, numeric_types):
             kernel_size = (kernel_size,)*3
         assert len(kernel_size) == 3, "kernel_size must be a number or a list of 3 ints"
-        op_name = kwargs.pop('op_name', 'Convolution')
-        if is_np_array():
-            op_name = 'convolution'
         super(Conv3D, self).__init__(
             channels, kernel_size, strides, padding, dilation, groups, layout,
-            in_channels, activation, use_bias, weight_initializer, bias_initializer,
-            op_name, **kwargs)
+            in_channels, activation, use_bias, weight_initializer, bias_initializer, **kwargs)
 
 
 class Conv1DTranspose(_Conv):
@@ -504,13 +481,10 @@ def __init__(self, channels, kernel_size, strides=1, padding=0, output_padding=0
             output_padding = (output_padding,)
         assert len(kernel_size) == 1, "kernel_size must be a number or a list of 1 ints"
         assert len(output_padding) == 1, "output_padding must be a number or a list of 1 ints"
-        op_name = kwargs.pop('op_name', 'Deconvolution')
-        if is_np_array():
-            op_name = 'deconvolution'
         super(Conv1DTranspose, self).__init__(
             channels, kernel_size, strides, padding, dilation, groups, layout,
             in_channels, activation, use_bias, weight_initializer,
-            bias_initializer, op_name=op_name, adj=output_padding, **kwargs)
+            bias_initializer, op_name='Deconvolution', adj=output_padding, **kwargs)
         self.outpad = output_padding
 
 
@@ -598,13 +572,10 @@ def __init__(self, channels, kernel_size, strides=(1, 1), padding=(0, 0),
            output_padding = (output_padding,)*2
         assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints"
         assert len(output_padding) == 2, "output_padding must be a number or a list of 2 ints"
-        op_name = kwargs.pop('op_name', 'Deconvolution')
-        if is_np_array():
-            op_name = 'deconvolution'
         super(Conv2DTranspose, self).__init__(
             channels, kernel_size, strides, padding, dilation, groups, layout,
             in_channels, activation, use_bias, weight_initializer,
-            bias_initializer, op_name=op_name, adj=output_padding, **kwargs)
+            bias_initializer, op_name='Deconvolution', adj=output_padding, **kwargs)
         self.outpad = output_padding
 
 
@@ -693,13 +664,10 @@ def __init__(self, channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0),
             output_padding = (output_padding,)*3
         assert len(kernel_size) == 3, "kernel_size must be a number or a list of 3 ints"
         assert len(output_padding) == 3, "output_padding must be a number or a list of 3 ints"
-        op_name = kwargs.pop('op_name', 'Deconvolution')
-        if is_np_array():
-            op_name = 'deconvolution'
         super(Conv3DTranspose, self).__init__(
             channels, kernel_size, strides, padding, dilation, groups, layout,
             in_channels, activation, use_bias, weight_initializer, bias_initializer,
-            op_name=op_name, adj=output_padding, **kwargs)
+            op_name='Deconvolution', adj=output_padding, **kwargs)
         self.outpad = output_padding
 
 
@@ -726,8 +694,7 @@ def _alias(self):
         return 'pool'
 
     def hybrid_forward(self, F, x):
-        pooling = F.npx.pooling if is_np_array() else F.Pooling
-        return pooling(x, name='fwd', **self._kwargs)
+        return F.Pooling(x, name='fwd', **self._kwargs)
 
     def __repr__(self):
         s = '{name}(size={kernel}, stride={stride}, padding={pad}, ceil_mode={ceil_mode}'
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 9807c5e33108..2a9cd88bb214 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -28,8 +28,6 @@
 from ... import ndarray, symbol
 from .. import HybridBlock, tensor_types
 from . import rnn_cell
-from ...util import is_np_array
-
 
 class _RNNLayer(HybridBlock):
     """Implementation of recurrent layers."""
@@ -219,10 +217,7 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
                 info.update(kwargs)
             else:
                 info = kwargs
-            state = func(name='%sh0_%d' % (self.prefix, i), **info)
-            if is_np_array():
-                state = state.as_np_ndarray()
-            states.append(state)
+            states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
         return states
 
     def __call__(self, inputs, states=None, sequence_length=None, **kwargs):
@@ -258,9 +253,8 @@ def hybrid_forward(self, F, inputs, states, sequence_length=None, **kwargs):
 
     def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
        """ forward using CUDNN or CPU kenrel"""
-        swapaxes = F.np.swapaxes if is_np_array() else F.swapaxes
         if self._layout == 'NTC':
-            inputs = swapaxes(inputs, 0, 1)
+            inputs = F.swapaxes(inputs, dim1=0, dim2=1)
         if self._projection_size is None:
             params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
                       for t in ['weight', 'bias']
@@ -275,23 +269,20 @@ def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
                       for g in ['i2h', 'h2h', 'h2r']
                       if g != 'h2r' or t != 'bias')
 
-        rnn_param_concat = F.np._internal.rnn_param_concat if is_np_array()\
-            else F._internal._rnn_param_concat
-        params = rnn_param_concat(*params, dim=0)
+        params = F._internal._rnn_param_concat(*params, dim=0)
 
         if self._use_sequence_length:
             rnn_args = states + [sequence_length]
         else:
             rnn_args = states
 
-        rnn_fn = F.npx.rnn if is_np_array() else F.RNN
-        rnn = rnn_fn(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
-                     state_size=self._hidden_size, projection_size=self._projection_size,
-                     num_layers=self._num_layers, bidirectional=self._dir == 2,
-                     p=self._dropout, state_outputs=True, mode=self._mode,
-                     lstm_state_clip_min=self._lstm_state_clip_min,
-                     lstm_state_clip_max=self._lstm_state_clip_max,
-                     lstm_state_clip_nan=self._lstm_state_clip_nan)
+        rnn = F.RNN(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
+                    state_size=self._hidden_size, projection_size=self._projection_size,
+                    num_layers=self._num_layers, bidirectional=self._dir == 2,
+                    p=self._dropout, state_outputs=True, mode=self._mode,
+                    lstm_state_clip_min=self._lstm_state_clip_min,
+                    lstm_state_clip_max=self._lstm_state_clip_max,
+                    lstm_state_clip_nan=self._lstm_state_clip_nan)
 
         if self._mode == 'lstm':
             outputs, states = rnn[0], [rnn[1], rnn[2]]
@@ -299,7 +290,7 @@ def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
             outputs, states = rnn[0], [rnn[1]]
 
         if self._layout == 'NTC':
-            outputs = swapaxes(outputs, 0, 1)
+            outputs = F.swapaxes(outputs, dim1=0, dim2=1)
 
         return outputs, states
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index c79b5e3ceb6d..4d2aef8e7d91 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -18,8 +18,6 @@
 # coding: utf-8
 # pylint: disable=
 """Parallelization utility optimizer."""
-from __future__ import absolute_import
-
 __all__ = ['split_data', 'split_and_load', 'clip_global_norm',
            'check_sha1', 'download']
 
@@ -40,8 +38,7 @@ class requests_failed_to_import(object):
 import numpy as np
 
 from .. import ndarray
-from ..util import is_np_shape, is_np_array
-from .. import numpy as _mx_np  # pylint: disable=reimported
+from ..util import is_np_shape
 
 
 def split_data(data, num_slice, batch_axis=0, even_split=True):
@@ -86,19 +83,12 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
         slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
                   for i in range(num_slice)]
     elif even_split:
-        if is_np_array():
-            slices = _mx_np.split(data, indices_or_sections=num_slice, axis=batch_axis)
-        else:
-            slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis)
+        slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis)
     else:
-        if is_np_array():
-            indices = [step * i for i in range(1, num_slice)]
-            slices = _mx_np.split(data, indices_or_sections=indices, axis=batch_axis)
-        else:
-            slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step)
-                      if i < num_slice - 1 else
-                      ndarray.slice_axis(data, batch_axis, i*step, size)
-                      for i in range(num_slice)]
+        slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step)
+                  if i < num_slice - 1 else
+                  ndarray.slice_axis(data, batch_axis, i*step, size)
+                  for i in range(num_slice)]
     return slices
 
 
@@ -108,7 +98,7 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
 
     Parameters
    ----------
-    data : NDArray or ndarray
+    data : NDArray
         A batch of data.
     ctx_list : list of Context
         A list of Contexts.
@@ -119,12 +109,11 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
 
     Returns
     -------
-    list of NDArrays or ndarrays
+    list of NDArrays
        Each corresponds to a context in `ctx_list`.
     """
-    array_fn = _mx_np.array if is_np_array() else ndarray.array
     if not isinstance(data, ndarray.NDArray):
-        data = array_fn(data, ctx=ctx_list[0])
+        data = ndarray.array(data, ctx=ctx_list[0])
 
     if len(ctx_list) == 1:
         return [data.as_in_context(ctx_list[0])]