Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Parallelize CPU version and add GPU version of boolean_mask op
Browse files Browse the repository at this point in the history
  • Loading branch information
HyperZealot committed Feb 14, 2019
1 parent df5310b commit 3051c7c
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 70 deletions.
102 changes: 36 additions & 66 deletions src/operator/contrib/boolean_mask-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,83 +50,53 @@ struct BooleanMaskParam : public dmlc::Parameter<BooleanMaskParam> {
}
};

// Per-element forward kernel for boolean_mask.
// Work item `i` addresses one element of `data`, viewed as rows of `col_size`
// elements. The element is copied to `out` only when idx[row] differs from the
// previous entry (or from 0 for the first row); the previous entry is used as
// the destination row.
// NOTE(review): idx is presumably an inclusive running count of the non-zero
// mask entries -- confirm against the caller, which is not visible here.
struct BooleanMaskForwardKernel {
  template<typename DType>
  static void MSHADOW_XINLINE Map(int i,
                                  DType* out,
                                  const DType* data,
                                  const int32_t* idx,
                                  const size_t col_size) {
    const int src_row = i / col_size;
    const int32_t dst_row = (src_row != 0) ? idx[src_row - 1] : 0;
    if (dst_row == idx[src_row]) {
      // Counter did not advance on this row: nothing to copy.
      return;
    }
    const int col = i % col_size;
    out[dst_row * col_size + col] = data[i];
  }
};

// Per-element backward kernel for boolean_mask.
// Work item `i` addresses one element of `igrad`, viewed as rows of `col_size`
// elements. When idx[row] differs from the previous entry (or from 0 for the
// first row), the element is read from row `previous entry` of `ograd`.
// NOTE(review): elements of rows where idx does not change are left untouched,
// so the caller presumably has to initialise igrad beforehand -- confirm.
struct BooleanMaskBackwardKernel {
  template<typename DType>
  static void MSHADOW_XINLINE Map(int i,
                                  DType* igrad,
                                  const DType* ograd,
                                  const int32_t* idx,
                                  const size_t col_size) {
    const int dst_row = i / col_size;
    const int32_t src_row = (dst_row != 0) ? idx[dst_row - 1] : 0;
    if (src_row == idx[dst_row]) {
      // Counter did not advance on this row: no gradient to route here.
      return;
    }
    const int col = i % col_size;
    igrad[i] = ograd[src_row * col_size + col];
  }
};

// Forward pass of boolean_mask (row-by-row reference implementation).
//   inputs:  {data, idx}  -- idx must be 1-d with idx.shape()[0] == data.shape()[axis]
//   outputs: {out}        -- re-initialised here with the number of selected rows
// Only axis == 0 is accepted. Every row of `data` whose idx entry is non-zero
// is copied, in order, into the next row of `out`.
template<typename xpu>
inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
                               const OpContext &ctx,
                               const std::vector<NDArray> &inputs,
                               const std::vector<OpReqType> &req,
                               const std::vector<NDArray> &outputs) {
  // TODO(@junrushao1994): This implementation is a proof-of-concept,
  // hence very slow actually. Performance should be improved in the future.
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  const int axis = nnvm::get<BooleanMaskParam>(attrs.parsed).axis;
  const NDArray &data = inputs[0];
  const NDArray &idx = inputs[1];
  const NDArray &out = outputs[0];
  CHECK_EQ(axis, 0) << "Not supported yet";
  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
  CHECK_EQ(idx.shape().ndim(), 1U);

  // First pass: count the non-zero mask entries to learn the output extent.
  size_t selected_rows = 0;
  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
    DType* mask = idx.data().dptr<DType>();
    const int num_rows = idx.shape()[0];
    for (int row = 0; row < num_rows; ++row) {
      if (mask[row]) {
        selected_rows += 1;
      }
    }
  });

  // The output shape is only known now, so it is set forcefully.
  TShape out_shape = data.shape();
  out_shape[axis] = selected_rows;
  const_cast<NDArray &>(out).Init(out_shape);

  // Second pass: copy each selected row into the next free output row.
  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
    mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
    DType* mask = idx.data().dptr<DType>();
    const int num_rows = idx.shape()[0];
    int next_out_row = 0;
    for (int row = 0; row < num_rows; ++row) {
      if (mask[row]) {
        NDArray src = data.At(row);
        NDArray dst = out.At(next_out_row);
        ++next_out_row;
        CHECK(src.shape() == dst.shape());
        mxnet_op::copy(stream, dst.data(), src.data());
      }
    }
  });
}
const std::vector<NDArray> &outputs);

// Backward pass of boolean_mask (row-by-row reference implementation).
//   inputs:  {ograd, data, idx}
//   outputs: {igrad_data, igrad_idx}
// igrad_data is first filled with 0, then the k-th row of ograd is copied into
// the row of igrad_data that corresponds to the k-th non-zero entry of idx.
template<typename xpu>
inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs,
                                const OpContext &ctx,
                                const std::vector<NDArray> &inputs,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &outputs) {
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 2U);
  const NDArray& ograd = inputs[0];
  const NDArray& idx = inputs[2];
  const NDArray& igrad_data = outputs[0];
  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
    mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
    DType* mask = idx.data().dptr<DType>();
    const int num_rows = idx.shape()[0];
    // Rows that were masked out receive a zero gradient.
    Fill<false>(stream, igrad_data.data(), req[0], 0);
    int next_ograd_row = 0;
    for (int row = 0; row < num_rows; ++row) {
      if (mask[row]) {
        NDArray src = ograd.At(next_ograd_row);
        ++next_ograd_row;
        NDArray dst = igrad_data.At(row);
        CHECK(src.shape() == dst.shape());
        mxnet_op::copy(stream, dst.data(), src.data());
      }
    }
  });
}
const std::vector<NDArray> &outputs);

} // namespace op
} // namespace mxnet
Expand Down
114 changes: 112 additions & 2 deletions src/operator/contrib/boolean_mask.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ namespace op {

DMLC_REGISTER_PARAMETER(BooleanMaskParam);


bool BooleanMaskType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
Expand Down Expand Up @@ -75,9 +74,116 @@ bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs,
return true;
}

// Row-wise forward kernel for boolean_mask on CPU.
// Work item `i` is a row index of `data` (rows are `col_size` elements long).
// `idx` is read as a running count: when idx[i] differs from the previous
// entry (0 for the first row), row i of `data` is copied into row
// `previous entry` of `out`; otherwise nothing is written.
struct BooleanMaskForwardCPUKernel {
  template<typename DType>
  static void Map(int i,
                  DType* out,
                  const DType* data,
                  const int32_t* idx,
                  const size_t col_size) {
    const int32_t dst_row = (i != 0) ? idx[i - 1] : 0;
    if (dst_row == idx[i]) {
      // Counter did not advance: this row is masked out.
      return;
    }
    const DType* src = data + i * col_size;
    DType* dst = out + dst_row * col_size;
    std::memcpy(dst, src, col_size * sizeof(DType));
  }
};

// Row-wise backward kernel for boolean_mask on CPU.
// Work item `i` is a row index of `igrad` (rows are `col_size` elements long).
// `idx` is read as a running count: when idx[i] differs from the previous
// entry (0 for the first row), row `previous entry` of `ograd` is copied into
// row i of `igrad`. Otherwise row i was masked out in the forward pass and is
// set to zero, so that every row of `igrad` is written and no stale buffer
// contents leak into the gradient.
struct BooleanMaskBackwardCPUKernel {
  template<typename DType>
  static void Map(int i,
                  DType* igrad,
                  const DType* ograd,
                  const int32_t* idx,
                  const size_t col_size) {
    // i is row id already
    const int32_t prev = (i == 0) ? 0 : idx[i - 1];
    const int32_t curr = idx[i];
    DType* dst = igrad + i * col_size;
    const size_t row_bytes = col_size * sizeof(DType);
    if (prev != curr) {
      std::memcpy(dst, ograd + prev * col_size, row_bytes);
    } else {
      // Masked-out row: its gradient is zero (all-zero bytes).
      std::memset(dst, 0, row_bytes);
    }
  }
};

// CPU forward pass of boolean_mask.
//   inputs:  {data, idx}  -- idx must be 1-d with idx.shape()[0] == data.shape()[axis]
//   outputs: {out}        -- re-initialised here with the number of selected rows
// Only axis == 0 is accepted. An inclusive running count of the non-zero idx
// entries is built on the host; its last value is the output extent, and the
// row kernel uses it to place each selected row of `data` in `out`.
// An empty idx yields an output with zero rows and launches no kernel.
template<>
inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
                                    const OpContext &ctx,
                                    const std::vector<NDArray> &inputs,
                                    const std::vector<OpReqType> &req,
                                    const std::vector<NDArray> &outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
  const int axis = param.axis;
  const NDArray &data = inputs[0];
  const NDArray &idx = inputs[1];
  const NDArray &out = outputs[0];
  CHECK_EQ(axis, 0) << "Not supported yet";
  // Validate the rank of idx before indexing into its shape.
  CHECK_EQ(idx.shape().ndim(), 1U);
  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
  // count the number of 1s in `idx`, so that we could know the output dimension
  const size_t idx_size = idx.shape()[0];
  std::vector<int32_t> prefix_sum(idx_size, 0);
  // Calculate prefix sum
  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
    DType* idx_dptr = idx.data().dptr<DType>();
    for (size_t i = 0; i < idx_size; i++) {
      prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
      prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
    }
  });
  // Guard the empty case: there is no last element to read.
  const size_t valid_num = prefix_sum.empty() ? 0 : prefix_sum.back();
  // set the output shape forcefully
  TShape s = data.shape();
  s[axis] = valid_num;
  const_cast<NDArray &>(out).Init(s);
  if (idx_size == 0) {
    // Nothing to copy, and col_size below would divide by zero.
    return;
  }
  // do the copy
  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
    const size_t input_size = data.shape().Size();
    const size_t col_size = input_size / idx_size;
    mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
    mxnet_op::Kernel<BooleanMaskForwardCPUKernel, cpu>::Launch(
      stream, idx_size, out.data().dptr<DType>(), data.data().dptr<DType>(),
      prefix_sum.data(), col_size);
  });
}

// CPU backward pass of boolean_mask.
//   inputs:  {ograd, data, idx}
//   outputs: {igrad_data, igrad_idx}
// An inclusive running count of the non-zero idx entries is built on the host
// and handed to the row kernel, which is launched once per row of igrad_data.
// This function does not pre-initialise igrad_data; whatever each row ends up
// holding is decided by the kernel. An empty idx is a no-op.
template<>
inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
                                     const OpContext &ctx,
                                     const std::vector<NDArray> &inputs,
                                     const std::vector<OpReqType> &req,
                                     const std::vector<NDArray> &outputs) {
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 2U);
  // inputs: {ograd, data, idx}
  // outputs: {igrad_data, igrad_idx}
  const NDArray& ograd = inputs[0];
  const NDArray& idx = inputs[2];
  const NDArray& igrad_data = outputs[0];
  const size_t idx_size = idx.shape()[0];
  if (idx_size == 0) {
    // No rows to process, and col_size below would divide by zero.
    return;
  }
  MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
    MSHADOW_TYPE_SWITCH(idx.dtype(), IType, {
      const size_t input_size = igrad_data.shape().Size();
      const size_t col_size = input_size / idx_size;
      std::vector<int32_t> prefix_sum(idx_size, 0);
      IType* idx_dptr = idx.data().dptr<IType>();
      for (size_t i = 0; i < idx_size; i++) {
        prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
        prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
      }
      mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
      mxnet_op::Kernel<BooleanMaskBackwardCPUKernel, cpu>::Launch(
        stream, idx_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(),
        prefix_sum.data(), col_size);
    });
  });
}

NNVM_REGISTER_OP(_contrib_boolean_mask)
.describe(R"code(
Experimental CPU-only support for boolean masking.
Given an n-d NDArray data, and a 1-d NDArray index,
the operator produces an un-predeterminable shaped n-d NDArray out,
which stands for the rows in x where the corresonding element in index is non-zero.
Expand All @@ -94,6 +200,10 @@ which stands for the rows in x where the corresonding element in index is non-ze
.set_attr_parser(ParamParser<BooleanMaskParam>)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "index"};
})
.set_attr<nnvm::FInferType>("FInferType", BooleanMaskType)
.set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>)
.set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
Expand Down
Loading

0 comments on commit 3051c7c

Please sign in to comment.