From 5df48b2d7cc67721dcb8c0d27a4223dc7ac6b63b Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Mon, 18 Jun 2018 17:31:12 +0000 Subject: [PATCH] backward of enhanced take op --- src/operator/mshadow_op.h | 3 +- src/operator/tensor/indexing_op-inl.cuh | 107 +------- src/operator/tensor/indexing_op.cc | 22 +- src/operator/tensor/indexing_op.h | 344 ++++++++++++++++++------ tests/python/unittest/test_operator.py | 68 +++-- 5 files changed, 328 insertions(+), 216 deletions(-) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 76bc1a5c5451..4b37c330690f 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -602,8 +602,7 @@ struct clip : public mxnet_op::tunable { } } template - MSHADOW_XINLINE static DType Map(DType x, DType upper_bound, DType lower_bound) { - DType ret = x; + MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) { if (x > upper_bound) { return upper_bound; } else if (x < lower_bound) { diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh index fff8686906e4..cb10ed6c4a34 100644 --- a/src/operator/tensor/indexing_op-inl.cuh +++ b/src/operator/tensor/indexing_op-inl.cuh @@ -27,6 +27,8 @@ #define MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_ #include #include +#include "../mshadow_op.h" +#include "./util/tensor_util-inl.cuh" #if CUDA_VERSION >= 9000 #define FULLMASK 0xFFFFFFFF @@ -271,7 +273,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, const mshadow::Tensor& sorted, const mshadow::Tensor& index, const mshadow::Tensor &src, - mshadow::Tensor* workspace) { + mshadow::Tensor* workspace = NULL) { CHECK_EQ(dst.CheckContiguous(), true); CHECK_EQ(sorted.CheckContiguous(), true); CHECK_EQ(index.CheckContiguous(), true); @@ -293,7 +295,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, (NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream); size_t exclusivesum_bytes = 0; cub::DeviceScan::ExclusiveSum - (NULL, exclusivesum_bytes, NULL, NUsrc_indices_bytesLL, sorted.size(0), stream); + (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream); size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes); // Check that we have enough storage @@ -320,107 +322,6 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, num_runs_ptr, dst.size(0)); } -template -struct TakeGradGeneralKernel { - template - MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd, - const IType* src_indptr, const IType* original_idx, - mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides, - const int in_ndims, const int out_ndims, const int idx_ndims, const int axis) { - const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1]; - const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1]; - const int in_mid_index = in_rest_index / in_stride[axis]; - const int in_tail_index = (axis == in_ndims - 1) ? - 0 : (in_rest_index % in_stride[axis]); - for (int i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) { - const int out_mid_index = original_idx[i]; - int target = in_tail_index + out_mid_index * out_stride[axis + idx_ndims - 1]; - target += (axis == 0) ? 
0 : in_head_index * out_strides[axis - 1]; - arr_grad[tid] += ograd[target]; - } - } -} - -template -inline void TakeOpBackwardImpl(mshadow::Stream* s, - const OpContext& ctx, - const TBlob& arr, - const TBlob& idx, - const TBlob& ograd, - const int axis) { - using namespace mxnet_op; - using namespace mshadow; - CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; - const TShape& arrshape = arr.shape_; - const TShape& idxshape = idx.shape_; - const TShape& oshape = ograd.shape_; - // get size of temporary storage for sort - char* temp_storage_ptr = nullptr; - size_t scan_temp_storage_bytes = 0; - IType* src_indptr_bytes = nullptr; - cub::DeviceScan::ExclusiveSum(temp_storage_ptr, - scan_temp_storage_bytes, - src_indptr_bytes, - src_indptr_bytes, - arrshape[axis] + 1, - mshadow::Stream::GetStream(s)); - size_t sort_temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); - size_t temp_storage_bytes = max(scan_temp_storage_bytes, sort_temp_storage_bytes); - size_t original_idx_bytes = idxshape.Size() * sizeof(IType); - size_t src_indptr_bytes = (arrshape[actual_axis] + 1) * sizeof(IType); - size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); - IType* sorted_idx_ptr = reinterpret_cast(workspace.dptr_); - IType* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); - src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); - char* temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes; - // Reset indptr to zero - Kernel::Launch(s, arrshape[actual_axis] + 1, src_indptr_ptr); - // Fill original_idx - Kernel::Launch( - s, idxshape.Size(), 1, IType(0), IType(1), kWriteTo, original_idx_ptr); - // Fill sorted_idx_ptr with unsorted copy of idx - Kernel, gpu>::Launch( - s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); - if (clip) { - Kernel, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr, - sorted_idx_ptr, IType(0), IType(arrshape[axis])); - } else { - Kernel, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr, - sorted_idx_ptr, IType(arrshape[axis])); - } - Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); - Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); - int num_bits = ilog2(static_cast(idxshape.Size()) - 1); - Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); - SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits); - Kernel::Launch( - s, idxshape.Size(), src_indptr_ptr, idx.dptr(), idxshape.Size()); - cub::DeviceScan::ExclusiveSum(temp_storage_ptr, - temp_storage_bytes, - src_indptr_bytes, - src_indptr_bytes, - arrshape[actual_axis] + 1, - mshadow::Stream::GetStream(s)); - - Shape<10> in_strides; - int stride = 1; - for (int i = arrshape.ndim() - 1; i > 0; stride *= arrshape[i], --i) { - in_strides[i] = stride; - } - Shape<10> out_strides; - stride = 1; - for (int i = oshape.ndim() - 1; i > 0; stride *= oshape[i], --i) { - out_strides[i] = stride; - } - MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { - Kernel::Launch( - s, arrshape.Size(), arr.dptr(), ograd.dptr(), src_indptr_ptr, original_idx_ptr, - in_strides, out_strides, arrshape.ndim(), oshape.ndim(), idxshape.ndim(), actual_axis); - }); -} - } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_ diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index fb2f7d992c6e..e6c729830dfa 100644 --- 
a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -360,14 +360,14 @@ NNVM_REGISTER_OP(take) This function slices the input array along a particular axis with the provided indices. -Given an input array with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, the output -will have shape ``(i0, i1, d1, d2)``, computed by:: +Given an input tensor with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, and axis=1, +the output will have shape ``(d0, i0, i1, d2)``, computed by:: - output[i,j,:,:] = input[indices[i,j],:,:] + output[:,i,j,:] = input[:,indices[i,j],:] .. note:: - - `axis`- Only slicing along axis 0 is supported for now. - - `mode`- Only `clip` mode is supported for now. + - `axis`- Could be from -r to r-1 where r is the rank of input tensor + - `mode`- Could be either `clip` or `wrap`. Examples:: x = [4. 5. 6.] @@ -375,6 +375,9 @@ Examples:: // Trivial case, take the second element along the first axis. take(x, [1]) = [ 5. ] + // The other trivial case, axis=-1, take the third element along the first axis + take(x, [3], axis=-1, mode='clip') = [ 6. ] + x = [[ 1., 2.], [ 3., 4.], [ 5., 6.]] @@ -386,6 +389,14 @@ Examples:: [[ 3., 4.], [ 5., 6.]]] + // In this case we will get rows 0 and 1, then 1 and 2 (calculated by wrapping around). + // Along axis 1 + take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1., 2.], + [ 3., 4.]], + + [[ 3., 4.], + [ 5., 6.]]] + )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -413,6 +424,7 @@ Examples:: NNVM_REGISTER_OP(_backward_take) .set_num_inputs(2) .set_num_outputs(2) +.set_attr_parser(ParamParser) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 4102d9895fba..a8f263841ccf 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -44,6 +44,9 @@ #include "./sort_op.h" #include "./init_op.h" #include "../../engine/openmp.h" +#ifdef __CUDACC__ +#include "./indexing_op-inl.cuh" +#endif namespace mxnet { namespace op { @@ -135,22 +138,6 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, dst[sorted[y]] += src[index[y]]; } } -/*! - * \brief CPU/GPU: Gradient accumulate of embedding matrix. - dst[sorted[i]] += src[index[i]] - Called when the batchsize of src is larger than the featuredim - * \param dst destination - * \param sorted the sorted indices - * \param index original index of the sorted indices - * \param src source output - * \param workspace (optional) temporary storage - */ -template -inline void AddTakeGradLargeBatch(mshadow::Tensor dst, - const mshadow::Tensor& sorted, - const mshadow::Tensor& index, - const mshadow::Tensor &src, - mshadow::Tensor* workspace = NULL); template inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, @@ -322,8 +309,13 @@ struct Take { MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, const int M, const int K) { int j = static_cast(idx[i/M]); - if (j <= 0) j = 0; - else if (j >= K) j = K - 1; + if (clip) { + if (j <= 0) j = 0; + else if (j >= K) j = K - 1; + } else { + j = j % K; + j += (j < 0) ? K : 0; + } out_data[i] = in_data[j * M + i % M]; } @@ -353,10 +345,11 @@ struct Take { 0 : (out_rest_index % in_stride[axis]); int idx_index = static_cast(idx[out_mid_index]); if (clip) { - idx_index = (idx_index < -axis_dim) ? 0 : idx_index; + idx_index = (idx_index < 0) ? 
0 : idx_index; idx_index = (idx_index > axis_dim - 1) ? (axis_dim - 1) : idx_index; } idx_index %= axis_dim; + idx_index += (idx_index < 0) ? axis_dim : 0; const int in_tail_index = out_tail_index; const int in_head_index = out_head_index; int in_src_index = in_tail_index + idx_index * in_stride[axis]; @@ -798,12 +791,12 @@ inline bool TakeOpShape(const nnvm::NodeAttrs& attrs, out_attrs->clear(); - const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); + const index_t actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); TShape oshape(idxshape.ndim() + arrshape.ndim() - 1); - for (int i = 0; i < (int)idxshape.ndim(); ++i) { + for (index_t i = 0; i < idxshape.ndim(); ++i) { oshape[i + actual_axis] = idxshape[i]; } - for (int i = 0; i < (int)arrshape.ndim(); i++) { + for (index_t i = 0; i < arrshape.ndim(); i++) { if (i < actual_axis) { oshape[i] = arrshape[i]; } else if (i > actual_axis) { @@ -847,73 +840,254 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type - mshadow::Shape<10> in_strides; - int stride = 1; - for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { - in_strides[i] = stride; - } - mshadow::Shape<10> out_strides; - stride = 1; - for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { - out_strides[i] = stride; - } - if (param.mode == take_::kClip) { - Kernel, xpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - in_strides, out_strides, arrshape.ndim(), oshape.ndim(), - idxshape.ndim(), arrshape[actual_axis], actual_axis); - } else if (param.mode == take_::kWrap) { - Kernel, xpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - in_strides, out_strides, arrshape.ndim(), oshape.ndim(), - idxshape.ndim(), arrshape[actual_axis], actual_axis); + if (actual_axis == 0) { + if (param.mode == take_::kClip) { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + oshape.Size()/idxshape.Size(), arrshape[0]); + } else { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + oshape.Size()/idxshape.Size(), arrshape[0]); + } + } else { + mshadow::Shape<10> in_strides; + int stride = 1; + for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { + in_strides[i] = stride; + } + mshadow::Shape<10> out_strides; + stride = 1; + for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { + out_strides[i] = stride; + } + if (param.mode == take_::kClip) { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + in_strides, out_strides, arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], actual_axis); + } else if (param.mode == take_::kWrap) { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + in_strides, out_strides, arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], actual_axis); + } } }); }); } +struct TakeGradGeneralKernel { + /*! 
+ * \brief Map function for general case of take grad + * \param tid global thread id + * \param arr_grad ptr to in_grad + * \param ograd ptr to out_grad + * \param src_indptr ptr to indptr to src indices + * \param original_idx ptr to original indices of the inputs + * \param in_strides strides of inputs + * \param out_strides strides of outputs + * \param in_ndims # of dims of input tensor + * \param out_ndims # of dims of output tensor + * \param idx_ndims # of dims of indices tensor + * \param axis_dim dim size of the axis dimension + * \param axis axis id + */ + template + MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd, + const IType* src_indptr, const IType* original_idx, + mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides, + const int in_ndims, const int out_ndims, const int idx_ndims, + const int axis) { + const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1]; + const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1]; + const int in_mid_index = in_rest_index / in_strides[axis]; + const int in_tail_index = (axis == in_ndims - 1) ? + 0 : (in_rest_index % in_strides[axis]); + for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) { + const int out_mid_index = original_idx[i]; + int target = in_tail_index + out_mid_index * in_strides[axis]; + target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1]; + arr_grad[tid] += ograd[target]; + } + } +}; + +template +void TakeOpBackwardImpl(mshadow::Stream* s, + const OpContext& ctx, + const TBlob& arr, + const TBlob& idx, + const TBlob& ograd, + const int axis) { + using namespace mxnet_op; + using namespace mshadow; + CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; + const TShape& arrshape = arr.shape_; + const TShape& idxshape = idx.shape_; + const TShape& oshape = ograd.shape_; + MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, { + // get size of temporary storage for sort + char* temp_storage_ptr = nullptr; + int* src_indptr_ptr = nullptr; + size_t temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); + size_t original_idx_bytes = idxshape.Size() * sizeof(int); + size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int); + size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); + int* sorted_idx_ptr = reinterpret_cast(workspace.dptr_); + int* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); + src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); + temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes; + // Reset indptr to zero + Kernel::Launch(s, arrshape[axis] + 1, src_indptr_ptr); + // Fill original_idx + Kernel::Launch(s, idxshape.Size(), 1, 0, 1, kWriteTo, original_idx_ptr); + // Fill sorted_idx_ptr with unsorted copy of idx + Kernel::Launch( + s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); + if (clip) { + Kernel, cpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, + 0, static_cast(arrshape[axis] - 1)); + } else { + Kernel, cpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast(arrshape[axis])); + } + Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); + Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); + int num_bits = ilog2(static_cast(idxshape.Size()) - 1); + Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), 
s); + SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits); + for (size_t i = 0; i < idxshape.Size(); ++i) { + src_indptr_ptr[sorted_idx_ptr[i] + 1] += 1; + } + for (int i = 0; i < arrshape[axis]; ++i) { + src_indptr_ptr[i + 1] += src_indptr_ptr[i]; + } + Shape<10> in_strides; + int stride = 1; + for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { + in_strides[i] = stride; + } + Shape<10> out_strides; + stride = 1; + for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { + out_strides[i] = stride; + } + MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { + Kernel::Launch( + s, arrshape.Size(), arr.dptr(), ograd.dptr(), src_indptr_ptr, + original_idx_ptr, in_strides, out_strides, + arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis); + }); + }); +} + +#ifdef __CUDACC__ template -inline void TakeOpBackwardImpl(mshadow::Stream* s, - const OpContext& ctx, - const TBlob& arr, - const TBlob& idx, - const TBlob& ograd, - const int axis) { - return; - // CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; - // const TShape& arrshape = arr.shape_; - // const TShape& idxshape = idx.shape_; - // const TShape& oshape = ograd.shape_; - // // get size of temporary storage for sort - // size_t temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); - // size_t original_idx_bytes = idxshape.Size() * sizeof(IType); - // size_t src_indices_bytes = arrshape[actual_axis] * sizeof(IType); - // size_t workspace_bytes = src_indices_bytes + 2 * original_idx_bytes + temp_storage_bytes; - // Tensor workspace = - // ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); - // IType* sorted_idx_ptr = reinterpret_cast(workspace.dptr_); - // IType* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); - // IType* src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); - // char* temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indices_bytes; - // // Reset indptr to zero - // mxnet_op::Kernel::Launch(s, arrshape[actual_axis], src_indptr_ptr); - // // Fill original_idx - // mxnet_op::Kernel::Launch( - // s, idxshape.Size(), 1, IType(0), IType(1), kWriteTo, original_idx_ptr); - // // Fill sorted_idx_ptr with unsorted copy of idx - // mxnet_op::Kernel, xpu>::Launch( - // s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); - // Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); - // Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); - // int num_bits = ilog2(static_cast(idxshape.Size()) - 1); - // SortByKey(idx.dptr, original_idx, true, &temp_storage, 0, num_bits); - // Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); - // Tensor src_indptr(src_indptr_ptr, Shape1(arrshape[actual_axis]), s); +void TakeOpBackwardImpl(mshadow::Stream* s, + const OpContext& ctx, + const TBlob& arr, + const TBlob& idx, + const TBlob& ograd, + const int axis) { + using namespace mxnet_op; + using namespace mshadow; + CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; + const TShape& arrshape = arr.shape_; + const TShape& idxshape = idx.shape_; + const TShape& oshape = ograd.shape_; + MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, { + // get size of temporary storage for sort + char* temp_storage_ptr = nullptr; + size_t scan_temp_storage_bytes = 0; + int* src_indptr_ptr = nullptr; + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, + scan_temp_storage_bytes, + src_indptr_ptr, + src_indptr_ptr, + arrshape[axis] + 1, + 
mshadow::Stream::GetStream(s)); + size_t sort_temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); + size_t histo_temp_storage_bytes = 0; + int* sorted_idx_ptr = nullptr; + cub::DeviceHistogram::HistogramEven(temp_storage_ptr, + histo_temp_storage_bytes, + sorted_idx_ptr, + src_indptr_ptr, + static_cast(arrshape[axis] + 1), + 0, + static_cast(arrshape[axis] + 1), + static_cast(idxshape.Size()), + mshadow::Stream::GetStream(s)); + size_t temp_storage_bytes = max(scan_temp_storage_bytes, sort_temp_storage_bytes); + temp_storage_bytes = max(temp_storage_bytes, histo_temp_storage_bytes); + size_t original_idx_bytes = idxshape.Size() * sizeof(int); + size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int); + size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); + sorted_idx_ptr = reinterpret_cast(workspace.dptr_); + int* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); + src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); + temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes; + // Reset indptr to zero + Kernel::Launch(s, arrshape[axis] + 1, src_indptr_ptr); + // Fill original_idx + Kernel::Launch( + s, idxshape.Size(), 1, static_cast(0), static_cast(1), + kWriteTo, original_idx_ptr); + // Fill sorted_idx_ptr with unsorted copy of idx + Kernel::Launch( + s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); + if (clip) { + Kernel, gpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, 0, static_cast(arrshape[axis])); + } else { + Kernel, gpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast(arrshape[axis])); + } + Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); + Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); + int num_bits = ilog2(static_cast(idxshape.Size()) - 1); + Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); + SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits); + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, + temp_storage_bytes, + src_indptr_ptr, + src_indptr_ptr, + arrshape[axis] + 1, + mshadow::Stream::GetStream(s)); + + Shape<10> in_strides; + int stride = 1; + for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { + in_strides[i] = stride; + } + Shape<10> out_strides; + stride = 1; + for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { + out_strides[i] = stride; + } + MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { + Kernel::Launch( + s, arrshape.Size(), arr.dptr(), ograd.dptr(), + src_indptr_ptr, original_idx_ptr, in_strides, out_strides, + arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis); + }); + }); } +#endif template void TakeOpBackward(const nnvm::NodeAttrs& attrs, @@ -943,7 +1117,7 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs, const TShape& arrshape = outputs[0].shape_; const TShape& oshape = inputs[0].shape_; - const int actual_axis = param.axis + ((param.axis < 0) ? oshape.ndim() : 0); + const int actual_axis = param.axis + ((param.axis < 0) ? 
arrshape.ndim() : 0); int idxndim = idxshape.ndim(); Tensor idx = inputs[1].get_with_shape( @@ -1445,7 +1619,5 @@ void ScatterSetNDForward(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet -#ifdef __CUDACC__ -#include "./indexing_op-inl.cuh" -#endif + #endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_H_ diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7893499a556b..0e43e5d98970 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3549,6 +3549,32 @@ def test_blockgrad(): @with_seed() def test_take(): + def grad_helper(grad_in, axis, idx): + if axis == 0: + if axis == len(grad_in.shape) - 1: + grad_in[idx] += 1.0 + else: + grad_in[idx, :] += 1.0 + elif axis == 1: + if axis == len(grad_in.shape) - 1: + grad_in[:, idx] += 1.0 + else: + grad_in[:, idx, :] += 1.0 + elif axis == 2: + if axis == len(grad_in.shape) - 1: + grad_in[:, :, idx] += 1.0 + else: + grad_in[:, :, idx, :] += 1.0 + elif axis == 3: + if axis == len(grad_in.shape) - 1: + grad_in[:, :, :, idx] += 1.0 + else: + grad_in[:, :, :, idx, :] += 1.0 + elif axis == 4: + grad_in[:, :, :, :, idx] += 1.0 + else: + raise ValueError("axis %d is not supported..." % axis) + def check_output_n_grad(data_shape, idx_shape, axis, mode): data = mx.sym.Variable('a') idx = mx.sym.Variable('indices') @@ -3558,32 +3584,34 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode): indices=idx_shape, axis=axis, mode=mode) data_real = np.random.normal(size=data_shape).astype('float32') idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape) - # grad_out = np.ones(idx_shape + data_shape[1:], dtype='float32') - # grad_in = np.zeros(data_shape, dtype='float32') + if axis < 0: + axis += len(data_shape) + + grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32') + grad_in = np.zeros(data_shape, dtype='float32') exe.arg_dict['a'][:] = mx.nd.array(data_real) exe.arg_dict['indices'][:] = mx.nd.array(idx_real) exe.forward(is_train=True) assert_almost_equal(exe.outputs[0].asnumpy(), np.take(data_real, idx_real, axis=axis, mode=mode)) - # for i in np.nditer(idx_real): - # grad_in[i] += 1.0 - # - # exe.backward([mx.nd.array(grad_out)]) - # assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in) - - - for data_ndim in range(1, 5): - for idx_ndim in range(1, 4): - for axis in range(-data_ndim, data_ndim): - data_shape = () - for _ in range(data_ndim): - data_shape += (np.random.randint(low=1, high=5), ) - idx_shape = () - for _ in range(idx_ndim): - idx_shape += (np.random.randint(low=1, high=5), ) - check_output_n_grad(data_shape, idx_shape, axis, 'clip') - check_output_n_grad(data_shape, idx_shape, axis, 'wrap') + for i in np.nditer(idx_real): + grad_helper(grad_in, axis, i) + + exe.backward([mx.nd.array(grad_out)]) + assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in) + + for mode in ['clip', 'wrap']: + for data_ndim in range(1, 5): + for idx_ndim in range(1, 4): + for axis in range(-data_ndim, data_ndim): + data_shape = () + for _ in range(data_ndim): + data_shape += (np.random.randint(low=1, high=5), ) + idx_shape = () + for _ in range(idx_ndim): + idx_shape += (np.random.randint(low=1, high=5), ) + check_output_n_grad(data_shape, idx_shape, axis, mode) @with_seed()
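A minimal usage sketch (not part of the patch itself) of what this change enables: a forward take along a non-zero axis with mode='wrap', plus the backward pass added here. It mirrors the docstring example and the updated unit test above. It assumes an MXNet Python build that already contains this patch, and the wrapped-count gradient check at the end is an illustrative reference rather than the test from the patch. For example::

    import mxnet as mx
    import numpy as np

    x = mx.nd.array([[1., 2.],
                     [3., 4.],
                     [5., 6.]])
    idx = mx.nd.array([[0, 3], [-1, -2]])

    x.attach_grad()
    with mx.autograd.record():
        # axis=1 with mode='wrap': out-of-range indices wrap around axis 1 (size 2)
        y = mx.nd.take(x, idx, axis=1, mode='wrap')
    y.backward()  # head gradient defaults to all ones

    # Forward result should agree with numpy's take for the same axis/mode.
    ref = np.take(x.asnumpy(), idx.asnumpy().astype('int64'), axis=1, mode='wrap')
    np.testing.assert_allclose(y.asnumpy(), ref)

    # With an all-ones ograd, d(out)/dx[i, j] is the number of times column j
    # was selected after wrapping, broadcast over the untouched axes.
    wrapped = np.mod(idx.asnumpy().astype('int64'), x.shape[1])
    counts = np.bincount(wrapped.ravel(), minlength=x.shape[1]).astype('float32')
    np.testing.assert_allclose(x.grad.asnumpy(), np.broadcast_to(counts, x.shape))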