
Commit af042e8

support extra cases for broadcasting setitem, fixing a few other bugs in indexing

haojin2 committed Feb 6, 2020
1 parent 9c1789f commit af042e8

Showing 8 changed files with 56 additions and 34 deletions.
5 changes: 4 additions & 1 deletion python/mxnet/ndarray/ndarray.py
@@ -965,7 +965,10 @@ def _set_nd_basic_indexing(self, key, value):
             value_nd.copyto(self)
 
         elif isinstance(value, numeric_types):
-            self._full(value)
+            if isinstance(value, bool):
+                self._full(int(value))
+            else:
+                self._full(value)
 
         elif isinstance(value, (np.ndarray, np.generic)):
             tmp_shape = _shape_for_bcast(
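
A minimal sketch of the path this hunk fixes, assuming an MXNet build that includes this commit (the printed value is the expectation, not a captured log):

    import mxnet as mx

    a = mx.nd.zeros((2, 3))
    a[:] = True            # a Python bool is now filled as int(True) == 1
    print(a.asnumpy())     # expected: all ones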
4 changes: 3 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -485,7 +485,9 @@ def _set_np_boolean_indexing(self, key, value):
             pos -= 1
 
     if isinstance(value, numeric_types):
-        _npi.boolean_mask_assign_scalar(data=data, mask=mask, value=value, start_axis=pos, out=data)
+        _npi.boolean_mask_assign_scalar(data=data, mask=mask,
+                                        value=int(value) if isinstance(value, bool) else value,
+                                        start_axis=pos, out=data)
     elif isinstance(value, ndarray):
         _npi.boolean_mask_assign_tensor(data=data, mask=mask, value=value, start_axis=pos, out=data)
     else:
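
The same bool-scalar fix seen from the numpy front end, sketched under the same assumption of a patched build:

    from mxnet import np, npx
    npx.set_np()

    a = np.zeros((2, 3))
    mask = a < 1              # boolean mask selecting every element
    a[mask] = True            # forwarded to the backend as int(True) == 1
    print(a)                  # expected: all ones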
42 changes: 24 additions & 18 deletions src/operator/numpy/np_boolean_mask_assign.cc
@@ -141,19 +141,17 @@ bool BooleanAssignShape(const nnvm::NodeAttrs& attrs,
   } else {
     const TShape& vshape = in_attrs->at(2);
     if (vshape.Size() > 1) {
-      for (int i = 0; i < dshape.ndim(); ++i) {
-        if (i < start_axis) {
-          CHECK_EQ(dshape[i], vshape[i])
-            << "shape mismatch of value with input at dimension " << i
-            << "; dimension is " << dshape[i] << " but corresponding value dimension is "
-            << vshape[i];
-        }
-        if (i >= start_axis + mshape.ndim()) {
-          CHECK_EQ(dshape[i], vshape[i - mshape.ndim() + 1])
-            << "shape mismatch of value with input at dimension " << i
-            << "; dimension is " << dshape[i] << " but corresponding value dimension is "
-            << vshape[i - mshape.ndim() + 1];
-        }
-      }
+      for (int i = 0; i < start_axis; ++i) {
+        CHECK_EQ(dshape[i], vshape[i])
+          << "shape mismatch of value with input at dimension " << i
+          << "; dimension is " << dshape[i] << " but corresponding value dimension is "
+          << vshape[i];
+      }
+      for (int i = 1; i <= (dshape.ndim() - start_axis - mshape.ndim()); ++i) {
+        CHECK_EQ(dshape[dshape.ndim() - i], vshape[vshape.ndim() - i])
+          << "shape mismatch of value with input at dimension " << (dshape.ndim() - i)
+          << "; dimension is " << dshape[dshape.ndim() - i]
+          << " but corresponding value dimension is " << vshape[vshape.ndim() - i];
+      }
     }
   }
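
A worked example of the shape rule these checks enforce, written in plain NumPy (whose semantics this operator mirrors); the shapes are illustrative:

    import numpy as np

    data = np.zeros((2, 3, 4, 5))
    mask = np.random.rand(3, 4) > 0.5   # covers axes 1-2, so start_axis == 1
    k = int(mask.sum())                 # number of True entries

    # dims before and after the mask axes must match data exactly; the single
    # middle dim must equal k or be broadcastable:
    data[:, mask, :] = np.random.rand(2, k, 5)   # exact middle dim: OK
    data[:, mask, :] = np.random.rand(2, 1, 5)   # middle dim 1 broadcasts: OK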
@@ -225,8 +223,13 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
   const TShape& vshape = inputs[2].shape_;
 
   if (inputs.size() == 3U) {
-    // tensor case
     if (inputs[2].shape_.Size() != 1) {
-      if (vshape[start_axis] != 1) {
+      auto vndim = vshape.ndim();
+      auto dndim = dshape.ndim();
+      auto mndim = mshape.ndim();
+      CHECK(vndim <= (dndim - mndim + 1));
+      if ((vndim == (dndim - mndim + 1)) && (vshape[start_axis] != 1)) {
+        // tensor case, check tensor size equal to or broadcastable with valid_num
         CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
           << "boolean array indexing assignment cannot assign " << vshape
   ...
   }

   if (inputs.size() == 3U) {
-    MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
       if (inputs[2].shape_.Size() == 1) {
         Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
             s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
             leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
       } else {
-        Kernel<BooleanAssignCPUKernel<false>, cpu>::Launch(
+        bool need_broadcast = (vshape.ndim() == (dshape.ndim() - mshape.ndim() + 1)) ?
+                              (vshape[start_axis] == 1) :
+                              true;
+        Kernel<BooleanAssignCPUKernel<false>, cpu>::Launch(
             s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
-            leading, middle, valid_num, trailing, inputs[2].dptr<DType>(), (vshape[start_axis] == 1));
+            leading, middle, valid_num, trailing, inputs[2].dptr<DType>(), need_broadcast);
       }
     });
   } else {
     CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value needs be provided";
-    MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
       Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
           s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
           leading, middle, trailing, static_cast<DType>(std::stod(attrs.dict.at("value"))));
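
What the new need_broadcast path permits, sketched from the numpy front end (again assuming a patched build):

    from mxnet import np, npx
    npx.set_np()

    a = np.zeros((4, 3))
    mask = np.array([True, False, True, False])   # selects 2 of the 4 rows
    a[mask] = np.ones((2, 3))    # exact shape: one row per selected row
    a[mask] = np.array([1.0])    # lower-rank value broadcast over all selected entries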
18 changes: 13 additions & 5 deletions src/operator/numpy/np_boolean_mask_assign.cu
@@ -204,8 +204,13 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs,
   const TShape& vshape = inputs[2].shape_;
 
   if (inputs.size() == 3U) {
-    // tensor case
     if (inputs[2].shape_.Size() != 1) {
-      if (vshape[start_axis] != 1) {
+      auto vndim = vshape.ndim();
+      auto dndim = dshape.ndim();
+      auto mndim = mshape.ndim();
+      CHECK(vndim <= (dndim - mndim + 1));
+      if ((vndim == (dndim - mndim + 1)) && (vshape[start_axis] != 1)) {
+        // tensor case, check tensor size equal to or broadcastable with valid_num
         CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
           << "boolean array indexing assignment cannot assign " << vshape
@@ -230,22 +235,25 @@
 
   if (inputs.size() == 3U) {
     if (inputs[2].shape_.Size() == 1) {
-      MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
         Kernel<BooleanAssignGPUKernel<true>, gpu>::Launch(
             s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
             leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
       });
     } else {
-      MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+      bool need_broadcast = (vshape.ndim() == (dshape.ndim() - mshape.ndim() + 1)) ?
+                            (vshape[start_axis] == 1) :
+                            true;
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
         Kernel<BooleanAssignGPUKernel<false>, gpu>::Launch(
             s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
-            leading, middle, valid_num, trailing, inputs[2].dptr<DType>(), (vshape[start_axis] == 1));
+            leading, middle, valid_num, trailing, inputs[2].dptr<DType>(), need_broadcast);
       });
     }
   } else {
     CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value is not provided";
     double value = std::stod(attrs.dict.at("value"));
-    MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
       Kernel<BooleanAssignGPUKernel<true>, gpu>::Launch(
           s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
           leading, middle, valid_num, trailing, static_cast<DType>(value));
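
The GPU path mirrors the CPU change; a quick sanity sketch, assuming a CUDA build with this commit and a visible GPU:

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()

    ctx = mx.gpu(0)
    a = np.zeros((4, 3), ctx=ctx)
    mask = np.array([True, False, True, False], ctx=ctx)
    a[mask] = np.array([1.0], ctx=ctx)   # broadcast boolean-mask assign on device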
4 changes: 2 additions & 2 deletions src/operator/tensor/indexing_op.cc
@@ -296,8 +296,8 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
   Stream<cpu> *s = ctx.get_stream<cpu>();
   const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
 
-  MSHADOW_TYPE_SWITCH(outputs[take_::kOut].type_flag_, DType, {  // output data type
-    MSHADOW_TYPE_SWITCH(inputs[take_::kIdx].type_flag_, IType, {  // index data type
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[take_::kOut].type_flag_, DType, {  // output data type
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[take_::kIdx].type_flag_, IType, {  // index data type
       if (param.mode == take_::kRaise) {
         IType min = 0;
         IType max = static_cast<IType>(arrshape[actual_axis] - 1);
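
With the WITH_BOOL switches in place, take can gather from boolean tensors; a sketch (patched build assumed):

    from mxnet import np, npx
    npx.set_np()

    flags = np.array([True, False, True], dtype=np.bool)
    idx = np.array([2, 0], dtype=np.int64)
    print(np.take(flags, idx))   # expected: [ True  True]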
4 changes: 2 additions & 2 deletions src/operator/tensor/indexing_op.cu
@@ -550,8 +550,8 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
   Stream<gpu> *s = ctx.get_stream<gpu>();
   const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
 
-  MSHADOW_TYPE_SWITCH(outputs[take_::kOut].type_flag_, DType, {  // output data type
-    MSHADOW_TYPE_SWITCH(inputs[take_::kIdx].type_flag_, IType, {  // index data type
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[take_::kOut].type_flag_, DType, {  // output data type
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[take_::kIdx].type_flag_, IType, {  // index data type
       if (param.mode == take_::kRaise) {
         // check out-of-bound indices
         IType min = 0;
2 changes: 1 addition & 1 deletion src/operator/tensor/init_op.h
@@ -252,7 +252,7 @@ struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
     .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
               "Only used for imperative calls.");
     DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
-      MXNET_ADD_ALL_TYPES
+      MXNET_ADD_ALL_TYPES_WITH_BOOL
       .describe("Target data type.");
     DMLC_DECLARE_FIELD(value)
       .describe("Value with which to fill newly created tensor");
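
Registering the bool dtype on the scalar-init parameter lets initializer ops create boolean tensors directly; a sketch (patched build assumed):

    from mxnet import np, npx
    npx.set_np()

    m = np.full((2, 2), True, dtype=np.bool)
    print(m.dtype)   # expected: bool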
11 changes: 7 additions & 4 deletions tests/python/unittest/test_numpy_op.py
@@ -1301,9 +1301,11 @@ def hybrid_forward(self, F, a, mask, value):
         dshape, mshape, start_axis = config
         test_data = np.random.uniform(size=dshape)
         valid_num = 0
-        while test_data.size != 0 and valid_num == 0:
-            mx_mask = np.random.choice([False, True], size=mshape)
-            valid_num = int(mx_mask.sum())
+        while valid_num == 0:
+            mx_mask = np.random.choice(np.array([False, True], dtype=np.bool), size=mshape)
+            if test_data.size == 0:
+                break
+            valid_num = int(mx_mask.asnumpy().sum())
         np_mask = mx_mask.asnumpy().astype(_np.bool)
         vshape = []
         vshape_broadcast = []
@@ -1318,7 +1320,8 @@ def hybrid_forward(self, F, a, mask, value):
             vshape.append(dshape[i])
             vshape_broadcast.append(dshape[i])
         vshape_broadcast = tuple(vshape_broadcast)
-        for val in [42.0, np.array(42.), np.array([42.]), np.random.uniform(size=vshape), np.random.uniform(size=vshape_broadcast)]:
+        for val in [42.0, _np.array(42.), _np.array([42.]), _np.random.uniform(size=vshape), _np.random.uniform(size=vshape_broadcast)]:
+            mx_val = val if isinstance(val, float) else np.array(val, dtype=np.float32)
             test_block = TestBooleanAssignScalar(val, start_axis) if isinstance(val, float) else TestBooleanAssignTensor(start_axis)
             if hybridize:
                 test_block.hybridize()
