Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix full and full_like when input is boolean #17668

Merged
merged 1 commit into from
Feb 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
else:
ret = broadcast_to(fill_value, shape).astype(dtype)
return ret
if isinstance(fill_value, bool):
fill_value = int(fill_value)
dtype = _np.bool if dtype is None else dtype
dtype = _np.float32 if dtype is None else dtype
return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out)
# pylint: enable=too-many-arguments, redefined-outer-name
Expand Down Expand Up @@ -426,6 +429,8 @@ def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
raise NotImplementedError
if ctx is None:
ctx = current_context()
if isinstance(fill_value, bool):
fill_value = int(fill_value)
return _npi.full_like(a, fill_value=fill_value, dtype=dtype, ctx=ctx, out=out)


Expand Down
5 changes: 5 additions & 0 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,9 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
else:
ret = broadcast_to(fill_value, shape).astype(dtype)
return ret
if isinstance(fill_value, bool):
fill_value = int(fill_value)
dtype = _np.bool if dtype is None else dtype
dtype = _np.float32 if dtype is None else dtype
return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out)

Expand Down Expand Up @@ -1299,6 +1302,8 @@ def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None): # pyli
raise NotImplementedError
if ctx is None:
ctx = current_context()
if isinstance(fill_value, bool):
fill_value = int(fill_value)
return _npi.full_like(a, fill_value=fill_value, ctx=ctx, dtype=dtype, out=out)


Expand Down
15 changes: 8 additions & 7 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,15 @@ struct FullLikeOpParam : public dmlc::Parameter<FullLikeOpParam> {
dmlc::optional<int> dtype;
DMLC_DECLARE_PARAMETER(FullLikeOpParam) {
DMLC_DECLARE_FIELD(fill_value)
.describe("Value with which to fill newly created tensor");
.describe("Value with which to fill newly created tensor");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype).set_default(dmlc::optional<int>())
MXNET_ADD_ALL_TYPES
.describe("Target data type.");
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype)
.set_default(dmlc::optional<int>())
MXNET_ADD_ALL_TYPES_WITH_BOOL
.describe("Target data type.");
}
};

Expand Down
10 changes: 6 additions & 4 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5881,6 +5881,8 @@ def hybrid_forward(self, F, a):
configs = [
((3, 4), 2.0),
((0, 3), 2.0),
((2, 3), True),
((3, 0), False),
((3, 4), np.array(2.0)),
((0, 3), np.array(2.0)),
((2, 3), np.array([1, 2, 3], dtype=np.float32)),
Expand All @@ -5890,7 +5892,7 @@ def hybrid_forward(self, F, a):
]

rtol, atol = 1e-3, 1e-5
dtypes = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
dtypes = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool']
for shape, fill_value in configs:
for hybridize in [True, False]:
for dtype in dtypes:
Expand Down Expand Up @@ -5931,7 +5933,7 @@ def hybrid_forward(self, F, x, *args, **kwargs):
if StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
return

dtypes = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8']
dtypes = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8', 'bool']
shapes = [
(),
(1,),
Expand All @@ -5944,11 +5946,11 @@ def hybrid_forward(self, F, x, *args, **kwargs):
(3, 3, 1, 0),
]
# numpy.full_like operator in py2 cannot handle shape like (5, 0, 3) properly
fill_values = [0, 1, 2, 3, 4, 5, 6]
fill_values = [0, 1, 2, 3, 4, 5, 6, True, False]
flags = [True, False]
for fill_value, dtype, shape, hybridize in itertools.product(
fill_values, dtypes, shapes, flags):
param_dtype= _np.random.choice(dtypes)
param_dtype = _np.random.choice(dtypes)
a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype)
test = TestFullLike(fill_value, param_dtype, npx.current_context())
expected_ret = _np.full_like(a.asnumpy(), fill_value=fill_value, dtype=param_dtype)
Expand Down