diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index feb2caa67b2f..858bff1d077b 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -366,6 +366,9 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin else: ret = broadcast_to(fill_value, shape).astype(dtype) return ret + if isinstance(fill_value, bool): + fill_value = int(fill_value) + dtype = _np.bool if dtype is None else dtype dtype = _np.float32 if dtype is None else dtype return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out) # pylint: enable=too-many-arguments, redefined-outer-name @@ -426,6 +429,8 @@ def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin raise NotImplementedError if ctx is None: ctx = current_context() + if isinstance(fill_value, bool): + fill_value = int(fill_value) return _npi.full_like(a, fill_value=fill_value, dtype=dtype, ctx=ctx, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 46fbc7d1ff7e..afa8766f9fba 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -1255,6 +1255,9 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin else: ret = broadcast_to(fill_value, shape).astype(dtype) return ret + if isinstance(fill_value, bool): + fill_value = int(fill_value) + dtype = _np.bool if dtype is None else dtype dtype = _np.float32 if dtype is None else dtype return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out) @@ -1299,6 +1302,8 @@ def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None): # pyli raise NotImplementedError if ctx is None: ctx = current_context() + if isinstance(fill_value, bool): + fill_value = int(fill_value) return _npi.full_like(a, fill_value=fill_value, ctx=ctx, dtype=dtype, out=out) diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 3ba4f2c0f5df..6261bffc2474 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -86,14 +86,15 @@ struct FullLikeOpParam : public dmlc::Parameter { dmlc::optional dtype; DMLC_DECLARE_PARAMETER(FullLikeOpParam) { DMLC_DECLARE_FIELD(fill_value) - .describe("Value with which to fill newly created tensor"); + .describe("Value with which to fill newly created tensor"); DMLC_DECLARE_FIELD(ctx) - .set_default("") - .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." - "Only used for imperative calls."); - DMLC_DECLARE_FIELD(dtype).set_default(dmlc::optional()) - MXNET_ADD_ALL_TYPES - .describe("Target data type."); + .set_default("") + .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." + "Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype) + .set_default(dmlc::optional()) + MXNET_ADD_ALL_TYPES_WITH_BOOL + .describe("Target data type."); } }; diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 45b64c26bb88..30e3dc317c65 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -5881,6 +5881,8 @@ def hybrid_forward(self, F, a): configs = [ ((3, 4), 2.0), ((0, 3), 2.0), + ((2, 3), True), + ((3, 0), False), ((3, 4), np.array(2.0)), ((0, 3), np.array(2.0)), ((2, 3), np.array([1, 2, 3], dtype=np.float32)), @@ -5890,7 +5892,7 @@ def hybrid_forward(self, F, a): ] rtol, atol = 1e-3, 1e-5 - dtypes = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64'] + dtypes = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool'] for shape, fill_value in configs: for hybridize in [True, False]: for dtype in dtypes: @@ -5931,7 +5933,7 @@ def hybrid_forward(self, F, x, *args, **kwargs): if StrictVersion(platform.python_version()) < StrictVersion('3.0.0'): return - dtypes = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8'] + dtypes = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8', 'bool'] shapes = [ (), (1,), @@ -5944,11 +5946,11 @@ def hybrid_forward(self, F, x, *args, **kwargs): (3, 3, 1, 0), ] # numpy.full_like operator in py2 cannot handle shape like (5, 0, 3) properly - fill_values = [0, 1, 2, 3, 4, 5, 6] + fill_values = [0, 1, 2, 3, 4, 5, 6, True, False] flags = [True, False] for fill_value, dtype, shape, hybridize in itertools.product( fill_values, dtypes, shapes, flags): - param_dtype= _np.random.choice(dtypes) + param_dtype = _np.random.choice(dtypes) a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype) test = TestFullLike(fill_value, param_dtype, npx.current_context()) expected_ret = _np.full_like(a.asnumpy(), fill_value=fill_value, dtype=param_dtype)