This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

support broadcasting on the indexed axis
haojin2 committed Jan 8, 2020
1 parent c1f6d64 commit d64e1f1
Showing 3 changed files with 43 additions and 23 deletions.
30 changes: 20 additions & 10 deletions src/operator/numpy/np_boolean_mask_assign.cc
@@ -91,14 +91,21 @@ struct BooleanAssignCPUKernel {
                   const size_t middle,
                   const size_t valid_num,
                   const size_t trailing,
-                  DType* tensor) {
+                  DType* tensor,
+                  const bool broadcast = false) {
     // binary search for the turning point
     size_t mid = bin_search(idx, idx_size, i);
     // final answer is in mid
     for (size_t l = 0; l < leading; ++l) {
       for (size_t t = 0; t < trailing; ++t) {
-        data[(l * middle + mid) * trailing + t] =
-          (scalar) ? tensor[0] : tensor[(l * valid_num + i) * trailing + t];
+        if (scalar) {
+          data[(l * middle + mid) * trailing + t] = tensor[0];
+        } else {
+          data[(l * middle + mid) * trailing + t] =
+            (broadcast) ?
+            tensor[l * trailing + t] :
+            tensor[(l * valid_num + i) * trailing + t];
+        }
       }
     }
   }
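This hunk adds a `broadcast` flag to the CPU kernel and splits the scalar and tensor paths. A minimal pure-Python sketch of the addressing it implements follows; the function and argument names are illustrative, not taken from the source, and `true_positions` stands in for the prefix-sum/binary-search machinery that produces `mid`.

def boolean_assign_cpu_sketch(data, true_positions, leading, middle, trailing,
                              tensor, scalar=False, broadcast=False):
    # `data` is the destination buffer viewed as (leading, middle, trailing), row-major.
    # `tensor` is the flat value buffer; `true_positions` lists the True indices on the
    # masked (middle) axis, so valid_num == len(true_positions).
    valid_num = len(true_positions)
    for i, mid in enumerate(true_positions):   # the kernel finds `mid` by binary search
        for l in range(leading):
            for t in range(trailing):
                dst = (l * middle + mid) * trailing + t
                if scalar:
                    data[dst] = tensor[0]
                elif broadcast:
                    # size-1 indexed axis: every True position reuses the same slice
                    data[dst] = tensor[l * trailing + t]
                else:
                    # one value slice per True position
                    data[dst] = tensor[(l * valid_num + i) * trailing + t]
    return data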
@@ -215,14 +222,17 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
   // If there's no True in mask, return directly
   if (valid_num == 0) return;
 
+  const TShape& vshape = inputs[2].shape_;
+
   if (inputs.size() == 3U) {
-    const TShape& vshape = inputs[2].shape_;
     if (inputs[2].shape_.Size() != 1) {
-      // tensor case, check tensor size with the valid_num
-      CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
-        << "boolean array indexing assignment cannot assign " << vshape
-        << " input values to the " << valid_num << " output values where the mask is true"
-        << std::endl;
+      if (vshape[start_axis] != 1) {
+        // tensor case, check tensor size equal to or broadcastable with valid_num
+        CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
+          << "boolean array indexing assignment cannot assign " << vshape
+          << " input values to the " << valid_num << " output values where the mask is true"
+          << std::endl;
+      }
     }
   }
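The relaxed check now accepts a value tensor whose extent on the indexed axis is either valid_num or 1; any other extent still fails the CHECK. As a hedged illustration of the intended semantics, plain NumPy (which mxnet.numpy is designed to follow) treats the two accepted shapes like this:

import numpy as _np

data = _np.zeros((2, 3, 4))
mask = _np.array([True, False, True])          # indexed axis is axis 1, valid_num == 2

# extent valid_num on the indexed axis: one value slice per True position
data[:, mask, :] = _np.random.uniform(size=(2, 2, 4))

# extent 1 on the indexed axis: the same slice is broadcast to every True position
data[:, mask, :] = _np.random.uniform(size=(2, 1, 4))

# any other extent, e.g. (2, 3, 4), raises a shape-mismatch error, matching the CHECK above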

@@ -248,7 +258,7 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
     } else {
       Kernel<BooleanAssignCPUKernel<false>, cpu>::Launch(
         s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
-        leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
+        leading, middle, valid_num, trailing, inputs[2].dptr<DType>(), (vshape[start_axis] == 1));
     }
   });
 } else {
22 changes: 13 additions & 9 deletions src/operator/numpy/np_boolean_mask_assign.cu
@@ -93,7 +93,8 @@ struct BooleanAssignGPUKernel {
                   const size_t middle,
                   const size_t valid_num,
                   const size_t trailing,
-                  DType* tensor) {
+                  DType* tensor,
+                  const bool broadcast = false) {
     // binary search for the turning point
     size_t m = i / trailing % valid_num;
     size_t l = i / trailing / valid_num;
@@ -103,7 +104,7 @@
     if (scalar) {
       data[dst] = tensor[0];
     } else {
-      data[dst] = tensor[i];
+      data[dst] = broadcast ? tensor[l * trailing + i % trailing] : tensor[i];
     }
   }
 };
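On the GPU, one thread is launched per output element instead of looping inside the kernel, so supporting broadcast only changes how the source index is derived from the flat thread id. A rough Python sketch of that mapping; names are illustrative, and `true_positions[m]` stands in for the binary search over the prefix sum performed in the elided context.

def gpu_thread_sketch(i, true_positions, leading, middle, valid_num, trailing,
                      data, tensor, scalar=False, broadcast=False):
    # `i` ranges over leading * valid_num * trailing, one call per GPU thread.
    t = i % trailing
    m = i // trailing % valid_num        # which True position this thread writes
    l = i // trailing // valid_num       # index on the leading axes
    mid = true_positions[m]              # position of that True on the masked axis
    dst = (l * middle + mid) * trailing + t
    if scalar:
        data[dst] = tensor[0]
    elif broadcast:
        # size-1 indexed axis: every True position reads the same (l, 0, t) element
        data[dst] = tensor[l * trailing + t]
    else:
        # values laid out as (leading, valid_num, trailing), so the flat id matches
        data[dst] = tensor[i]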
@@ -200,14 +201,17 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs,
   // If there's no True in mask, return directly
   if (valid_num == 0) return;
 
+  const TShape& vshape = inputs[2].shape_;
+
   if (inputs.size() == 3U) {
-    const TShape& vshape = inputs[2].shape_;
     if (inputs[2].shape_.Size() != 1) {
-      // tensor case, check tensor size with the valid_num
-      CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
-        << "boolean array indexing assignment cannot assign " << vshape
-        << " input values to the " << valid_num << " output values where the mask is true"
-        << std::endl;
+      if (vshape[start_axis] != 1) {
+        // tensor case, check tensor size equal to or broadcastable with valid_num
+        CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
+          << "boolean array indexing assignment cannot assign " << vshape
+          << " input values to the " << valid_num << " output values where the mask is true"
+          << std::endl;
+      }
     }
   }

@@ -235,7 +239,7 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs,
     MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
       Kernel<BooleanAssignGPUKernel<false>, gpu>::Launch(
         s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
-        leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
+        leading, middle, valid_num, trailing, inputs[2].dptr<DType>(), (vshape[start_axis] == 1));
     });
   }
 } else {
14 changes: 10 additions & 4 deletions tests/python/unittest/test_numpy_op.py
@@ -1296,22 +1296,28 @@ def hybrid_forward(self, F, a, mask, value):
 
     for hybridize in [False]:
         for config in configs:
-            print(config)
             dshape, mshape, start_axis = config
             test_data = np.random.uniform(size=dshape)
-            mx_mask = np.around(np.random.uniform(size=mshape))
-            valid_num = int(mx_mask.sum())
+            valid_num = 0
+            while test_data.size != 0 and valid_num == 0:
+                mx_mask = np.around(np.random.uniform(size=mshape))
+                valid_num = int(mx_mask.sum())
             np_mask = mx_mask.asnumpy().astype(_np.bool)
             vshape = []
+            vshape_broadcast = []
             for i in range(len(dshape)):
                 if i < start_axis:
                     vshape.append(dshape[i])
+                    vshape_broadcast.append(dshape[i])
                 elif i == start_axis:
                     vshape.append(valid_num)
+                    vshape_broadcast.append(1)
                 elif i >= start_axis + len(mshape):
                     vshape.append(dshape[i])
+                    vshape_broadcast.append(dshape[i])
             vshape = tuple(vshape)
-            for val in [42.0, np.array(42.), np.array([42.]), np.random.uniform(size=vshape)]:
+            vshape_broadcast = tuple(vshape_broadcast)
+            for val in [42.0, np.array(42.), np.array([42.]), np.random.uniform(size=vshape), np.random.uniform(size=vshape_broadcast)]:
                 test_block = TestBooleanAssignScalar(val, start_axis) if isinstance(val, float) else TestBooleanAssignTensor(start_axis)
                 if hybridize:
                     test_block.hybridize()
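For reference, a worked instance of the vshape / vshape_broadcast construction in this test, using assumed example values rather than one of the committed configs:

# Assumed values: dshape=(3, 4, 5), mshape=(4,), start_axis=1, and a drawn mask with valid_num == 2.
dshape, mshape, start_axis, valid_num = (3, 4, 5), (4,), 1, 2
vshape, vshape_broadcast = [], []
for i in range(len(dshape)):
    if i < start_axis:                       # axes before the mask are kept
        vshape.append(dshape[i])
        vshape_broadcast.append(dshape[i])
    elif i == start_axis:                    # the indexed axis collapses to valid_num (or 1)
        vshape.append(valid_num)
        vshape_broadcast.append(1)
    elif i >= start_axis + len(mshape):      # axes after the mask are kept
        vshape.append(dshape[i])
        vshape_broadcast.append(dshape[i])
print(tuple(vshape))             # (3, 2, 5): one value slice per True position
print(tuple(vshape_broadcast))   # (3, 1, 5): size 1 on the indexed axis, broadcast by the operator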
