diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 81a55c4a0137..7a2032df7580 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -601,6 +601,15 @@ struct clip : public mxnet_op::tunable {
       return x;
     }
   }
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) {
+    if (x > upper_bound) {
+      return upper_bound;
+    } else if (x < lower_bound) {
+      return lower_bound;
+    }
+    return x;
+  }
 };
 /***** gamma ******/
diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh
index 67dc2bbc334c..b2f514e20cd9 100644
--- a/src/operator/tensor/indexing_op-inl.cuh
+++ b/src/operator/tensor/indexing_op-inl.cuh
@@ -28,6 +28,8 @@
 #include
 #include
 #include "../mxnet_op.h"
+#include "../mshadow_op.h"
+#include "./util/tensor_util-inl.cuh"
 #if CUDA_VERSION >= 9000
 #define FULLMASK 0xFFFFFFFF
@@ -272,7 +274,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst,
                                   const mshadow::Tensor& sorted,
                                   const mshadow::Tensor& index,
                                   const mshadow::Tensor &src,
-                                  mshadow::Tensor* workspace) {
+                                  mshadow::Tensor* workspace = NULL) {
   CHECK_EQ(dst.CheckContiguous(), true);
   CHECK_EQ(sorted.CheckContiguous(), true);
   CHECK_EQ(index.CheckContiguous(), true);
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 64c5d86cbd1c..0f96e2cc2f72 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -367,36 +367,46 @@ NNVM_REGISTER_OP(take)
 This function slices the input array along a particular axis with the provided indices.

-Given an input array with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, the output
-will have shape ``(i0, i1, d1, d2)``, computed by::
-
-  output[i,j,:,:] = input[indices[i,j],:,:]
-
-.. note::
-   - `axis`- Only slicing along axis 0 is supported for now.
-   - `mode`- Only `clip` mode is supported for now.
+Given data tensor of rank r >= 1, and indices tensor of rank q, gather entries of the axis
+dimension of data (by default outer-most one as axis=0) indexed by indices, and concatenates them
+in an output tensor of rank q + (r - 1).

 Examples::

   x = [4.  5.  6.]

   // Trivial case, take the second element along the first axis.
+
   take(x, [1]) = [ 5. ]

+  // The other trivial case, axis=-1, take the third element along the first axis
+
+  take(x, [3], axis=-1, mode='clip') = [ 6. ]
+
   x = [[ 1.,  2.],
        [ 3.,  4.],
        [ 5.,  6.]]

   // In this case we will get rows 0 and 1, then 1 and 2. Along axis 0
+
   take(x, [[0,1],[1,2]]) = [[[ 1.,  2.],
                              [ 3.,  4.]],

                             [[ 3.,  4.],
                              [ 5.,  6.]]]
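
The examples above and below can be cross-checked against NumPy, whose np.take uses the same clip/wrap conventions as the new docstring; a minimal, NumPy-only sketch (not part of the patch):

    import numpy as np

    x = np.array([[1., 2.],
                  [3., 4.],
                  [5., 6.]])

    # rows 0,1 and rows 1,2 gathered along axis 0 -> shape (2, 2, 2)
    print(np.take(x, [[0, 1], [1, 2]], axis=0))

    # out-of-range / negative column indices wrap modulo 2 -> shape (3, 2, 2)
    print(np.take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap'))
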
+  // In this case we will get rows 0 and 1, then 1 and 2 (calculated by wrapping around).
+  // Along axis 1
+
+  take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1.,  2.],
+                                                       [ 2.,  1.]],
+
+                                                      [[ 3.,  4.],
+                                                       [ 4.,  3.]],
+
+                                                      [[ 5.,  6.],
+                                                       [ 6.,  5.]]]
+
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
-.set_attr_parser(TakeParamParser<TakeParam>)
+.set_attr_parser(ParamParser<TakeParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"a", "indices"};
@@ -420,6 +430,7 @@ Examples::
 NNVM_REGISTER_OP(_backward_take)
 .set_num_inputs(2)
 .set_num_outputs(2)
+.set_attr_parser(ParamParser<TakeParam>)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 7fb0f6bf514d..3d349c9f4292 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -44,6 +44,9 @@
 #include "./sort_op.h"
 #include "./init_op.h"
 #include "../../engine/openmp.h"
+#ifdef __CUDACC__
+#include "./indexing_op-inl.cuh"
+#endif
 namespace mxnet {
 namespace op {
@@ -135,22 +138,6 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst,
     dst[sorted[y]] += src[index[y]];
   }
 }
-/*!
- * \brief CPU/GPU: Gradient accumulate of embedding matrix.
-                   dst[sorted[i]] += src[index[i]]
-                   Called when the batchsize of src is larger than the featuredim
- * \param dst destination
- * \param sorted the sorted indices
- * \param index original index of the sorted indices
- * \param src source output
- * \param workspace (optional) temporary storage
- */
-template
-inline void AddTakeGradLargeBatch(mshadow::Tensor dst,
-                                  const mshadow::Tensor& sorted,
-                                  const mshadow::Tensor& index,
-                                  const mshadow::Tensor &src,
-                                  mshadow::Tensor* workspace = NULL);
 template
 inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs,
                              std::vector *in_attrs,
@@ -311,6 +298,7 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
 /*! \brief name the struct Take instead of take
  * to avoid conflict with the take function in mshadow
  */
+template<bool clip = true>
 struct Take {
   // assume that idx have been flattened to a 1-D tensor (N,)
   // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
@@ -321,10 +309,53 @@ struct Take {
   MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx,
                                   const int M, const int K) {
     int j = static_cast<int>(idx[i/M]);
-    if (j <= 0) j = 0;
-    else if (j >= K) j = K - 1;
+    if (clip) {
+      if (j <= 0) j = 0;
+      else if (j >= K) j = K - 1;
+    } else {
+      j = j % K;
+      j += (j < 0) ? K : 0;
+    }
     out_data[i] = in_data[j * M + i % M];
   }
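
A pure-Python sketch of the index normalization this axis-0 kernel performs (illustrative only; `normalize_index` is a hypothetical helper, and the extra `+ K` correction mirrors C++ `%`, which can return negative values):

    def normalize_index(j, K, clip=True):
        # clip mode: clamp into [0, K-1]; wrap mode: modulo with negative fix-up
        if clip:
            return min(max(j, 0), K - 1)
        j = j % K            # in C++ this may still be negative, hence the fix-up below
        return j + K if j < 0 else j

    assert normalize_index(5, 3) == 2               # clip: too large -> last row
    assert normalize_index(-1, 3) == 0              # clip: negative -> first row
    assert normalize_index(4, 3, clip=False) == 1   # wrap: 4 -> 1
    assert normalize_index(-1, 3, clip=False) == 2  # wrap: -1 -> K - 1
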
+
+  /*!
+   * \brief Map function for take operator
+   * \param i global thread id
+   * \param out_data ptr to output buffer
+   * \param in_data ptr to input buffer
+   * \param idx ptr to indices buffer
+   * \param in_ndims # of dims of input tensor
+   * \param out_ndims # of dims of output tensor
+   * \param idx_ndims # of dims of indices tensor
+   * \param axis_dim dim size of the axis dimension
+   * \param axis axis id
+   */
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx,
+                                  const mshadow::Shape<10> in_stride,
+                                  const mshadow::Shape<10> out_stride,
+                                  const int in_ndims, const int out_ndims, const int idx_ndims,
+                                  const int axis_dim, const int axis) {
+    // i is the global flattened index in the output
+    const int out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]);
+    const int out_rest_index = (axis == 0) ? i : (i % out_stride[axis - 1]);
+    const int out_mid_index = out_rest_index / in_stride[axis];
+    const int out_tail_index = (axis == in_ndims - 1) ?
+                               0 : (out_rest_index % in_stride[axis]);
+    int idx_index = static_cast<int>(idx[out_mid_index]);
+    if (clip) {
+      idx_index = (idx_index < 0) ? 0 : idx_index;
+      idx_index = (idx_index > axis_dim - 1) ? (axis_dim - 1) : idx_index;
+    }
+    idx_index %= axis_dim;
+    idx_index += (idx_index < 0) ? axis_dim : 0;
+    const int in_tail_index = out_tail_index;
+    const int in_head_index = out_head_index;
+    int in_src_index = in_tail_index + idx_index * in_stride[axis];
+    in_src_index += (axis == 0) ? 0 : in_head_index * in_stride[axis - 1];
+    out_data[i] = in_data[in_src_index];
+  }
 };
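
The head/mid/tail arithmetic in the new Map overload can be mirrored in NumPy to see what it computes; a rough reference (illustrative `take_ref` helper, assuming row-major layout and in-range indices):

    import numpy as np

    def take_ref(data, indices, axis):
        """Pure-Python mirror of the strided index arithmetic in Take::Map
        (head / mid / tail decomposition of the flat output index)."""
        data_flat = data.ravel()
        idx_flat = np.asarray(indices).ravel()
        in_shape, idx_shape = data.shape, np.asarray(indices).shape
        out_shape = in_shape[:axis] + idx_shape + in_shape[axis + 1:]
        in_strides = [int(np.prod(in_shape[k + 1:])) for k in range(len(in_shape))]
        out_strides = [int(np.prod(out_shape[k + 1:])) for k in range(len(out_shape))]
        out = np.empty(int(np.prod(out_shape)), dtype=data.dtype)
        for i in range(out.size):
            head = 0 if axis == 0 else i // out_strides[axis - 1]
            rest = i if axis == 0 else i % out_strides[axis - 1]
            mid = rest // in_strides[axis]                     # which flattened index entry
            tail = 0 if axis == len(in_shape) - 1 else rest % in_strides[axis]
            src = tail + idx_flat[mid] * in_strides[axis]
            if axis > 0:
                src += head * in_strides[axis - 1]
            out[i] = data_flat[src]
        return out.reshape(out_shape)

    data = np.random.rand(2, 3, 4)
    idx = np.array([[0, 2], [1, 1]])
    assert np.array_equal(take_ref(data, idx, axis=1), np.take(data, idx, axis=1))
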
"); + " \"raise\" means to raise an error, not supported yet."); } }; -template -inline void TakeParamParser(nnvm::NodeAttrs *attrs) { - PType param; - param.Init(attrs->dict); - if (param.axis != 0) { - LOG(FATAL) << "Axis other than 0 currently not supported."; - } - if (param.mode != take_::kClip) { - LOG(FATAL) << "Mode other than clip currently not supported."; - } -} - inline bool TakeOpShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { - using namespace mshadow; - const TShape &arrshape = (*in_attrs)[take_::kArr]; - const TShape &idxshape = (*in_attrs)[take_::kIdx]; - if (idxshape.ndim() == 0U || idxshape.Size() == 0U) return false; + using namespace mshadow; + const TShape &arrshape = (*in_attrs)[take_::kArr]; + const TShape &idxshape = (*in_attrs)[take_::kIdx]; + if (idxshape.ndim() == 0U || idxshape.Size() == 0U) return false; + const TakeParam& param = nnvm::get(attrs.parsed); + if (param.mode == take_::kRaise) { + LOG(FATAL) << "Raise is not supported for the time being..."; + } + CHECK(param.axis >= -1 * (int)arrshape.ndim() && param.axis < (int)arrshape.ndim()) + << "Axis should be in the range of [-r, r-1] where r is the rank of input tensor"; - out_attrs->clear(); + out_attrs->clear(); - TShape oshape(idxshape.ndim() + arrshape.ndim() - 1); - for (size_t i = 0; i < idxshape.ndim(); ++i) { - oshape[i] = idxshape[i]; - } - for (size_t i = 0; i < arrshape.ndim() - 1; i++) { - oshape[i + idxshape.ndim()] = arrshape[i + 1]; + const index_t actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); + TShape oshape(idxshape.ndim() + arrshape.ndim() - 1); + for (index_t i = 0; i < idxshape.ndim(); ++i) { + oshape[i + actual_axis] = idxshape[i]; + } + for (index_t i = 0; i < arrshape.ndim(); i++) { + if (i < actual_axis) { + oshape[i] = arrshape[i]; + } else if (i > actual_axis) { + oshape[i + idxshape.ndim() - 1] = arrshape[i]; } - out_attrs->push_back(oshape); - return true; + } + out_attrs->push_back(oshape); + return true; } inline bool TakeOpType(const nnvm::NodeAttrs& attrs, @@ -797,6 +827,7 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mxnet_op; if (req[take_::kOut] == kNullOp) return; + const TakeParam& param = nnvm::get(attrs.parsed); CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); @@ -805,17 +836,258 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const TShape& oshape = outputs[take_::kOut].shape_; Stream *s = ctx.get_stream(); + const int actual_axis = param.axis + ((param.axis < 0) ? 
 inline bool TakeOpType(const nnvm::NodeAttrs& attrs,
@@ -797,6 +827,7 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
                    const std::vector<TBlob>& outputs) {
   using namespace mxnet_op;
   if (req[take_::kOut] == kNullOp) return;
+  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
@@ -805,17 +836,258 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
   const TShape& oshape = outputs[take_::kOut].shape_;
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
+
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
-      Kernel<Take, xpu>::Launch(s, oshape.Size(),
-                                outputs[take_::kOut].dptr<DType>(),
-                                inputs[take_::kArr].dptr<DType>(),
-                                inputs[take_::kIdx].dptr<IType>(),
-                                oshape.Size()/idxshape.Size(), arrshape[0]);
+      if (actual_axis == 0) {
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          oshape.Size()/idxshape.Size(), arrshape[0]);
+        } else {
+          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           oshape.Size()/idxshape.Size(), arrshape[0]);
+        }
+      } else {
+        mshadow::Shape<10> in_strides;
+        int stride = 1;
+        for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
+          in_strides[i] = stride;
+        }
+        mshadow::Shape<10> out_strides;
+        stride = 1;
+        for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
+          out_strides[i] = stride;
+        }
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                          idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        } else if (param.mode == take_::kWrap) {
+          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                           idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        }
+      }
+    });
+  });
+}
+
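Two things the forward path relies on, sketched in NumPy: the row-major strides that the Shape<10> loops compute, and the fact that the axis == 0 case degenerates to a 2-D row gather, which is why the legacy kernel is kept for it (illustrative only):

    import numpy as np

    def row_major_strides(shape):
        """Element strides as computed by the stride loops in TakeOpForward."""
        strides = [1] * len(shape)
        for i in range(len(shape) - 2, -1, -1):
            strides[i] = strides[i + 1] * shape[i + 1]
        return strides

    assert row_major_strides((2, 3, 4)) == [12, 4, 1]

    # For axis == 0 the gather is just picking rows of a (K, M) matrix.
    data = np.random.rand(5, 3, 4)
    idx = np.array([4, 0, 2])
    flat = data.reshape(5, -1)                     # (K, M) with M = 12
    assert np.array_equal(flat[idx].reshape(3, 3, 4), np.take(data, idx, axis=0))
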
+struct TakeGradGeneralKernel {
+  /*!
+   * \brief Map function for general case of take grad
+   * \param tid global thread id
+   * \param arr_grad ptr to in_grad
+   * \param ograd ptr to out_grad
+   * \param src_indptr ptr to indptr to src indices
+   * \param original_idx ptr to original indices of the inputs
+   * \param in_strides strides of inputs
+   * \param out_strides strides of outputs
+   * \param in_ndims # of dims of input tensor
+   * \param out_ndims # of dims of output tensor
+   * \param idx_ndims # of dims of indices tensor
+   * \param axis axis id
+   */
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd,
+                                  const IType* src_indptr, const IType* original_idx,
+                                  mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
+                                  const int in_ndims, const int out_ndims, const int idx_ndims,
+                                  const int axis) {
+    const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
+    const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
+    const int in_mid_index = in_rest_index / in_strides[axis];
+    const int in_tail_index = (axis == in_ndims - 1) ?
+                              0 : (in_rest_index % in_strides[axis]);
+    for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
+      const int out_mid_index = original_idx[i];
+      int target = in_tail_index + out_mid_index * in_strides[axis];
+      target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
+      arr_grad[tid] += ograd[target];
+    }
+  }
+};
+
+template<bool clip = true>
+void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
+                        const OpContext& ctx,
+                        const TBlob& arr,
+                        const TBlob& idx,
+                        const TBlob& ograd,
+                        const int axis) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation";
+  const TShape& arrshape = arr.shape_;
+  const TShape& idxshape = idx.shape_;
+  const TShape& oshape = ograd.shape_;
+  MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
+    // get size of temporary storage for sort
+    int* src_indptr_ptr = nullptr;
+    size_t temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size());
+    size_t original_idx_bytes = idxshape.Size() * sizeof(int);
+    size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int);
+    size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes;
+    Tensor workspace =
+      ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s);
+    int* sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_);
+    int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes);
+    src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes);
+    Tensor temp_storage(
+      workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes, Shape1(temp_storage_bytes), s);
+    // Reset indptr to zero
+    Kernel::Launch(s, arrshape[axis] + 1, src_indptr_ptr);
+    // Fill original_idx
+    Kernel::Launch(s, idxshape.Size(), 1, 0, 1, kWriteTo, original_idx_ptr);
+    // Fill sorted_idx_ptr with unsorted copy of idx
+    Kernel::Launch(
+      s, idxshape.Size(), sorted_idx_ptr, idx.dptr<IType>());
+    if (clip) {
+      Kernel, cpu>::Launch(
+        s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr,
+        0, static_cast<int>(arrshape[axis] - 1));
+    } else {
+      Kernel, cpu>::Launch(
+        s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast<int>(arrshape[axis]));
+    }
+    Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s);
+    int num_bits = ilog2(static_cast(idxshape.Size()) - 1);
+    Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s);
+    SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits);
+    for (size_t i = 0; i < idxshape.Size(); ++i) {
+      src_indptr_ptr[sorted_idx_ptr[i] + 1] += 1;
+    }
+    for (int i = 0; i < arrshape[axis]; ++i) {
+      src_indptr_ptr[i + 1] += src_indptr_ptr[i];
+    }
+    Shape<10> in_strides;
+    int stride = 1;
+    for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
+      in_strides[i] = stride;
+    }
+    Shape<10> out_strides;
+    stride = 1;
+    for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
+      out_strides[i] = stride;
+    }
+    MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, {
+      Kernel<TakeGradGeneralKernel, cpu>::Launch(
+        s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(), src_indptr_ptr,
+        original_idx_ptr, in_strides, out_strides,
+        arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
+    });
+  });
+}
+
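What this backward pass has to produce can be stated compactly in NumPy: every output slice adds its upstream gradient onto the input slice it was gathered from, with repeated indices accumulating. A reference sketch (illustrative `take_grad_ref`, assuming already clipped/wrapped indices):

    import numpy as np

    def take_grad_ref(data_shape, indices, ograd, axis):
        """NumPy reference for the take gradient: scatter-add each output
        slice back onto the input row it was gathered from."""
        grad = np.zeros(data_shape, dtype=ograd.dtype)
        idx = np.asarray(indices).ravel()
        # collapse the index dimensions of ograd so they line up with idx
        og = ograd.reshape(data_shape[:axis] + (idx.size,) + data_shape[axis + 1:])
        np.add.at(grad, (slice(None),) * axis + (idx,), og)
        return grad

    data_shape = (3, 2)
    idx = np.array([[0, 2], [2, 2]])
    ograd = np.ones((2, 2, 2))
    print(take_grad_ref(data_shape, idx, ograd, axis=0))
    # row 2 receives gradient 3.0 because index 2 appears three times
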
+#ifdef __CUDACC__
+template<bool clip = true>
+void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
+                        const OpContext& ctx,
+                        const TBlob& arr,
+                        const TBlob& idx,
+                        const TBlob& ograd,
+                        const int axis) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation";
+  const TShape& arrshape = arr.shape_;
+  const TShape& idxshape = idx.shape_;
+  const TShape& oshape = ograd.shape_;
+  MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
+    // get size of temporary storage for sort
+    char* temp_storage_ptr = nullptr;
+    size_t scan_temp_storage_bytes = 0;
+    int* src_indptr_ptr = nullptr;
+    cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
+                                  scan_temp_storage_bytes,
+                                  src_indptr_ptr,
+                                  src_indptr_ptr,
+                                  arrshape[axis] + 1,
+                                  mshadow::Stream<gpu>::GetStream(s));
+    size_t sort_temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size());
+    size_t histo_temp_storage_bytes = 0;
+    int* sorted_idx_ptr = nullptr;
+    cub::DeviceHistogram::HistogramEven(temp_storage_ptr,
+                                        histo_temp_storage_bytes,
+                                        sorted_idx_ptr,
+                                        src_indptr_ptr,
+                                        static_cast(arrshape[axis] + 1),
+                                        0,
+                                        static_cast(arrshape[axis] + 1),
+                                        static_cast(idxshape.Size()),
+                                        mshadow::Stream<gpu>::GetStream(s));
+    size_t temp_storage_bytes = max(scan_temp_storage_bytes, sort_temp_storage_bytes);
+    temp_storage_bytes = max(temp_storage_bytes, histo_temp_storage_bytes);
+    size_t original_idx_bytes = idxshape.Size() * sizeof(int);
+    size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int);
+    size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes;
+    Tensor workspace =
+      ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s);
+    sorted_idx_ptr = reinterpret_cast<int*>(workspace.dptr_);
+    int* original_idx_ptr = reinterpret_cast<int*>(workspace.dptr_ + original_idx_bytes);
+    src_indptr_ptr = reinterpret_cast<int*>(workspace.dptr_ + 2 * original_idx_bytes);
+    temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes;
+    // Reset indptr to zero
+    Kernel::Launch(s, arrshape[axis] + 1, src_indptr_ptr);
+    // Fill original_idx
+    Kernel::Launch(
+      s, idxshape.Size(), 1, static_cast(0), static_cast(1),
+      kWriteTo, original_idx_ptr);
+    // Fill sorted_idx_ptr with unsorted copy of idx
+    Kernel::Launch(
+      s, idxshape.Size(), sorted_idx_ptr, idx.dptr<IType>());
+    if (clip) {
+      Kernel, gpu>::Launch(
+        s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr,
+        0, static_cast<int>(arrshape[axis] - 1));
+    } else {
+      Kernel, gpu>::Launch(
+        s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast<int>(arrshape[axis]));
+    }
+    Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s);
+    Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);
+    int num_bits = ilog2(static_cast(idxshape.Size()) - 1);
+    Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s);
+    SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits);
+    cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
+                                  temp_storage_bytes,
+                                  src_indptr_ptr,
+                                  src_indptr_ptr,
+                                  arrshape[axis] + 1,
+                                  mshadow::Stream<gpu>::GetStream(s));
+
+    Shape<10> in_strides;
+    int stride = 1;
+    for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
+      in_strides[i] = stride;
+    }
+    Shape<10> out_strides;
+    stride = 1;
+    for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
+      out_strides[i] = stride;
+    }
+    MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, {
+      Kernel<TakeGradGeneralKernel, gpu>::Launch(
+        s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(),
+        src_indptr_ptr, original_idx_ptr, in_strides, out_strides,
+        arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
+    });
+  });
+}
+#endif
+
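Both the CPU and GPU implementations build the same CSR-style bookkeeping before launching TakeGradGeneralKernel: sorted indices, the output positions they came from, and an exclusive prefix sum of per-row counts. A NumPy sketch of that bookkeeping (illustrative; the GPU path obtains the counts and scan from cub instead):

    import numpy as np

    idx = np.array([2, 0, 2, 1, 2])          # gathered rows, K = 3 rows in total
    K = 3

    original_idx = np.argsort(idx, kind='stable')   # positions in the original idx
    sorted_idx = idx[original_idx]
    counts = np.bincount(sorted_idx, minlength=K)
    src_indptr = np.concatenate(([0], np.cumsum(counts)))

    # all output positions that read input row r
    r = 2
    print(original_idx[src_indptr[r]:src_indptr[r + 1]])   # -> [0 2 4]
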
 template<typename xpu>
 void TakeOpBackward(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
@@ -829,20 +1101,24 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(req[take_::kIdx], kNullOp)
     << "take layer doesn't support gradient into index";
-  // inputs are specified in the .cc file, which are the gradients from
-  // the upper layer and the input index
-  // outputs are the gradients of inputs in the feed-forward pass
-  const TShape& idxshape = inputs[1].shape_;
-  const TShape& arrshape = outputs[0].shape_;
-  const TShape& oshape = inputs[0].shape_;
-
-  int idxndim = idxshape.ndim();
+  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
   // grad_out is the gradient of the outputs in the feed-forward
   // grad_in is the gradient of the inputs in the feed-forward
   Stream<xpu> *s = ctx.get_stream<xpu>();
+
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
+      // inputs are specified in the .cc file, which are the gradients from
+      // the upper layer and the input index
+      // outputs are the gradients of inputs in the feed-forward pass
+      const TShape& idxshape = inputs[1].shape_;
+      const TShape& arrshape = outputs[0].shape_;
+      const TShape& oshape = inputs[0].shape_;
+
+      const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
+
+      int idxndim = idxshape.ndim();
       Tensor idx = inputs[1].get_with_shape(
         Shape1(idxshape.ProdShape(0, idxndim)), s);
       Tensor grad_out = inputs[0].get_with_shape(
@@ -850,27 +1126,41 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
       Tensor grad_in = outputs[0].get_with_shape(
         Shape2(arrshape[0], arrshape.ProdShape(1, arrshape.ndim())), s);
-    if (req[take_::kArr] == kWriteTo || req[take_::kArr] == kAddTo) {
-      if (req[take_::kArr] == kWriteTo) {
-        grad_in = scalar<DType>(0.0f);
-      }
-      // shape_out_prod ~= the number of elements loaded in AddTakeGrad
-      // shape_in_prod ~= the number of elements stored in AddTakeGrad
-      // When the number of elements processed is low, use AddTakeGrad.
-      // The approximate cut-off value 16384 was found experimentally on Titan X Pascal
-      uint64_t shape_in_prod =
-        static_cast<uint64_t>(grad_in.shape_[0])*
-        static_cast<uint64_t>(grad_in.shape_[1]);
-      uint64_t shape_out_prod =
-        static_cast<uint64_t>(grad_out.shape_[0])*
-        static_cast<uint64_t>(grad_out.shape_[1]);
-      if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) {
-        AddTakeGrad(grad_in, idx, grad_out);
+      // re-using the previous code for axis = 0 case
+      if (actual_axis == 0) {
+        if (req[take_::kArr] == kWriteTo || req[take_::kArr] == kAddTo) {
+          if (req[take_::kArr] == kWriteTo) {
+            grad_in = scalar<DType>(0.0f);
+          }
+          // shape_out_prod ~= the number of elements loaded in AddTakeGrad
+          // shape_in_prod ~= the number of elements stored in AddTakeGrad
+          // When the number of elements processed is low, use AddTakeGrad.
+          // The approximate cut-off value 16384 was found experimentally on Titan X Pascal
+          uint64_t shape_in_prod =
+            static_cast<uint64_t>(grad_in.shape_[0])*
+            static_cast<uint64_t>(grad_in.shape_[1]);
+          uint64_t shape_out_prod =
+            static_cast<uint64_t>(grad_out.shape_[0])*
+            static_cast<uint64_t>(grad_out.shape_[1]);
+          if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) {
+            AddTakeGrad(grad_in, idx, grad_out);
+          } else {
+            AddTakeGradLargeBatchCaller(ctx, grad_in, idx, grad_out);
+          }
         } else {
-      AddTakeGradLargeBatchCaller(ctx, grad_in, idx, grad_out);
+          LOG(FATAL) << "wrong req";
         }
+        // for all other cases
       } else {
-      LOG(FATAL) << "wrong req";
+        const TBlob& idx = inputs[1];
+        const TBlob& arr = outputs[0];
+        const TBlob& ograd = inputs[0];
+
+        if (param.mode == take_::kClip) {
+          TakeOpBackwardImpl<true>(s, ctx, arr, idx, ograd, actual_axis);
+        } else {
+          TakeOpBackwardImpl<false>(s, ctx, arr, idx, ograd, actual_axis);
+        }
       }
     });
   });
@@ -1328,7 +1618,5 @@ void ScatterSetNDForward(const nnvm::NodeAttrs& attrs,
 }  // namespace op
 }  // namespace mxnet
-#ifdef __CUDACC__
-#include "./indexing_op-inl.cuh"
-#endif
+
 #endif  // MXNET_OPERATOR_TENSOR_INDEXING_OP_H_
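
With a unit upstream gradient, the expected input gradient is just a count of how many times each slice was selected, which is what the grad_helper in the test below reproduces; a one-case NumPy sketch:

    import numpy as np

    x = np.arange(4.0)
    idx = np.array([1, 1, 0, 3])
    grad_in = np.zeros_like(x)
    np.add.at(grad_in, idx, 1.0)   # scatter-add one per selection
    print(grad_in)                 # [1. 2. 0. 1.]
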
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 6c6ff310519d..814266ad9aa3 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3650,39 +3650,69 @@ def test_blockgrad():
 @with_seed()
 def test_take():
-    def check_output_n_grad(data_shape, idx_shape):
+    def grad_helper(grad_in, axis, idx):
+        if axis == 0:
+            if axis == len(grad_in.shape) - 1:
+                grad_in[idx] += 1.0
+            else:
+                grad_in[idx, :] += 1.0
+        elif axis == 1:
+            if axis == len(grad_in.shape) - 1:
+                grad_in[:, idx] += 1.0
+            else:
+                grad_in[:, idx, :] += 1.0
+        elif axis == 2:
+            if axis == len(grad_in.shape) - 1:
+                grad_in[:, :, idx] += 1.0
+            else:
+                grad_in[:, :, idx, :] += 1.0
+        elif axis == 3:
+            if axis == len(grad_in.shape) - 1:
+                grad_in[:, :, :, idx] += 1.0
+            else:
+                grad_in[:, :, :, idx, :] += 1.0
+        elif axis == 4:
+            grad_in[:, :, :, :, idx] += 1.0
+        else:
+            raise ValueError("axis %d is not supported..." % axis)
+
+    def check_output_n_grad(data_shape, idx_shape, axis, mode):
+        data = mx.sym.Variable('a')
+        idx = mx.sym.Variable('indices')
+        idx = mx.sym.BlockGrad(idx)
+        result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode)
         exe = result.simple_bind(default_context(), a=data_shape,
-                                 indices=idx_shape)
+                                 indices=idx_shape, axis=axis, mode=mode)
         data_real = np.random.normal(size=data_shape).astype('float32')
-        idx_real = np.random.randint(low=0, high=data_shape[0], size=idx_shape)
-        grad_out = np.ones(idx_shape + data_shape[1:], dtype='float32')
+        idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape)
+        if axis < 0:
+            axis += len(data_shape)
+
+        grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape +
+                           (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()),
+                           dtype='float32')
         grad_in = np.zeros(data_shape, dtype='float32')
         exe.arg_dict['a'][:] = mx.nd.array(data_real)
         exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
         exe.forward(is_train=True)
-        assert_almost_equal(exe.outputs[0].asnumpy(), data_real[idx_real])
+        assert_almost_equal(exe.outputs[0].asnumpy(),
+                            np.take(data_real, idx_real, axis=axis, mode=mode))
         for i in np.nditer(idx_real):
-            grad_in[i] += 1.0
+            grad_helper(grad_in, axis, i)
         exe.backward([mx.nd.array(grad_out)])
         assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in)

-    data = mx.sym.Variable('a')
-    idx = mx.sym.Variable('indices')
-    idx = mx.sym.BlockGrad(idx)
-    result = mx.sym.take(a=data, indices=idx)
-
-    for data_ndim in range(2, 5):
-        for idx_ndim in range(1, 4):
-            data_shape = ()
-            for _ in range(data_ndim):
-                data_shape += (np.random.randint(low=3, high=6), )
-            idx_shape = ()
-            for _ in range(idx_ndim):
-                idx_shape += (np.random.randint(low=3, high=5), )
-            check_output_n_grad(data_shape, idx_shape)
+    for mode in ['clip', 'wrap']:
+        for data_ndim in range(1, 5):
+            for idx_ndim in range(1, 4):
+                for axis in range(-data_ndim, data_ndim):
+                    data_shape = ()
+                    for _ in range(data_ndim):
+                        data_shape += (np.random.randint(low=1, high=5), )
+                    idx_shape = ()
+                    for _ in range(idx_ndim):
+                        idx_shape += (np.random.randint(low=1, high=5), )
+                    check_output_n_grad(data_shape, idx_shape, axis, mode)

 @with_seed()
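
A quick end-to-end usage sketch against the Python API (assumes an MXNet build that includes this change; mx.nd.take is expected to forward the new axis/mode arguments to the operator):

    import mxnet as mx
    import numpy as np

    x = mx.nd.array([[1., 2.], [3., 4.], [5., 6.]])
    idx = mx.nd.array([[0, 3], [-1, -2]])

    y = mx.nd.take(x, idx, axis=1, mode='wrap')
    np_y = np.take(x.asnumpy(), idx.asnumpy().astype('int64'), axis=1, mode='wrap')
    assert np.allclose(y.asnumpy(), np_y)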