From b82991375cba63082038ba100f491a7a9ee1513d Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Mon, 11 Jun 2018 05:58:09 +0000 Subject: [PATCH 1/3] take forward for any axis with enhanced test --- src/operator/tensor/indexing_op.cc | 2 +- src/operator/tensor/indexing_op.h | 126 ++++++++++++++++++------- tests/python/unittest/test_operator.py | 48 +++++----- 3 files changed, 119 insertions(+), 57 deletions(-) diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 64c5d86cbd1c..8c2416e89ab6 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -396,7 +396,7 @@ Examples:: )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) -.set_attr_parser(TakeParamParser) +.set_attr_parser(ParamParser) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"a", "indices"}; diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 5f9e59dfa538..0b028b975bbf 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -311,6 +311,7 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs, /*! \brief name the struct Take instead of take * to avoid conflict with the take function in mshadow */ +template struct Take { // assume that idx have been flattened to a 1-D tensor (N,) // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M) @@ -325,6 +326,43 @@ struct Take { else if (j >= K) j = K - 1; out_data[i] = in_data[j * M + i % M]; } + + /*! + * \brief Map function for take operator + * \param i global thread id + * \param out_data ptr to output buffer + * \param in_data ptr to input buffer + * \param idx ptr to indices buffer + * \param in_ndims # of dims of input tensor + * \param out_ndims # of dims of output tensor + * \param idx_ndims # of dims of indices tensor + * \param axis_dim dim size of the axis dimension + * \param axis axis id + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, + const mshadow::Shape<10> in_stride, + const mshadow::Shape<10> out_stride, + const int in_ndims, const int out_ndims, const int idx_ndims, + const int axis_dim, const int axis) { + // i is the global flattened index in the output + const int out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]); + const int out_rest_index = (axis == 0) ? i : (i % out_stride[axis - 1]); + const int out_mid_index = out_rest_index / in_stride[axis]; + const int out_tail_index = (axis == in_ndims - 1) ? + 0 : (out_rest_index % in_stride[axis]); + int idx_index = static_cast(idx[out_mid_index]); + if (clip) { + idx_index = (idx_index < -axis_dim) ? 0 : idx_index; + idx_index = (idx_index > axis_dim - 1) ? (axis_dim - 1) : idx_index; + } + idx_index %= axis_dim; + const int in_tail_index = out_tail_index; + const int in_head_index = out_head_index; + int in_src_index = in_tail_index + idx_index * in_stride[axis]; + in_src_index += (axis == 0) ? 
0 : in_head_index * in_stride[axis - 1]; + out_data[i] = in_data[in_src_index]; + } }; // Embedding forward implementation with dense weight @@ -345,7 +383,7 @@ void EmbeddingOpForwardDnsImpl(mshadow::Stream* s, Tensor wmat = weight.get(s); Tensor out = output.get_with_shape( Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s); - Kernel::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_, + Kernel, xpu>::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_, idx.dptr_, wmat.shape_[1], wmat.shape_[0]); }); }); @@ -728,9 +766,9 @@ struct TakeParam: public dmlc::Parameter { int mode; DMLC_DECLARE_PARAMETER(TakeParam) { DMLC_DECLARE_FIELD(axis) - .set_lower_bound(0) .set_default(0) - .describe("The axis of input array to be taken."); + .describe("The axis of input array to be taken." + "For input tensor of rank r, it could be in the range of [-r, r-1]"); DMLC_DECLARE_FIELD(mode) .add_enum("raise", take_::kRaise) .add_enum("wrap", take_::kWrap) @@ -744,37 +782,36 @@ struct TakeParam: public dmlc::Parameter { } }; -template -inline void TakeParamParser(nnvm::NodeAttrs *attrs) { - PType param; - param.Init(attrs->dict); - if (param.axis != 0) { - LOG(FATAL) << "Axis other than 0 currently not supported."; - } - if (param.mode != take_::kClip) { - LOG(FATAL) << "Mode other than clip currently not supported."; - } -} - inline bool TakeOpShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { - using namespace mshadow; - const TShape &arrshape = (*in_attrs)[take_::kArr]; - const TShape &idxshape = (*in_attrs)[take_::kIdx]; - if (idxshape.ndim() == 0U || idxshape.Size() == 0U) return false; + using namespace mshadow; + const TShape &arrshape = (*in_attrs)[take_::kArr]; + const TShape &idxshape = (*in_attrs)[take_::kIdx]; + if (idxshape.ndim() == 0U || idxshape.Size() == 0U) return false; + const TakeParam& param = nnvm::get(attrs.parsed); + if (param.mode == take_::kRaise) { + LOG(FATAL) << "Raise is not supported for the time being..."; + } + CHECK(param.axis >= -1 * (int)arrshape.ndim() && param.axis < (int)arrshape.ndim()) + << "Axis should be in the range of [-r, r-1] where r is the rank of input tensor"; - out_attrs->clear(); + out_attrs->clear(); - TShape oshape(idxshape.ndim() + arrshape.ndim() - 1); - for (size_t i = 0; i < idxshape.ndim(); ++i) { - oshape[i] = idxshape[i]; - } - for (size_t i = 0; i < arrshape.ndim() - 1; i++) { - oshape[i + idxshape.ndim()] = arrshape[i + 1]; + const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); + TShape oshape(idxshape.ndim() + arrshape.ndim() - 1); + for (int i = 0; i < (int)idxshape.ndim(); ++i) { + oshape[i + actual_axis] = idxshape[i]; + } + for (int i = 0; i < (int)arrshape.ndim(); i++) { + if (i < actual_axis) { + oshape[i] = arrshape[i]; + } else if (i > actual_axis) { + oshape[i + idxshape.ndim() - 1] = arrshape[i]; } - out_attrs->push_back(oshape); - return true; + } + out_attrs->push_back(oshape); + return true; } inline bool TakeOpType(const nnvm::NodeAttrs& attrs, @@ -797,6 +834,7 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { using namespace mxnet_op; if (req[take_::kOut] == kNullOp) return; + const TakeParam& param = nnvm::get(attrs.parsed); CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); @@ -805,13 +843,35 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const TShape& oshape = outputs[take_::kOut].shape_; Stream *s = ctx.get_stream(); + const int actual_axis = param.axis + ((param.axis < 0) ? 
arrshape.ndim() : 0); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type - Kernel::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - oshape.Size()/idxshape.Size(), arrshape[0]); + mshadow::Shape<10> in_strides; + int stride = 1; + for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { + in_strides[i] = stride; + } + mshadow::Shape<10> out_strides; + stride = 1; + for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { + out_strides[i] = stride; + } + if (param.mode == take_::kClip) { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + in_strides, out_strides, arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], actual_axis); + } else if (param.mode == take_::kWrap) { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + in_strides, out_strides, arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], actual_axis); + } }); }); } diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index f9dde2e6d245..d7721f6aa2b0 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3652,39 +3652,41 @@ def test_blockgrad(): @with_seed() def test_take(): - def check_output_n_grad(data_shape, idx_shape): + def check_output_n_grad(data_shape, idx_shape, axis, mode): + data = mx.sym.Variable('a') + idx = mx.sym.Variable('indices') + idx = mx.sym.BlockGrad(idx) + result = mx.sym.take(a=data, indices=idx, axis=axis, mode=mode) exe = result.simple_bind(default_context(), a=data_shape, - indices=idx_shape) + indices=idx_shape, axis=axis, mode=mode) data_real = np.random.normal(size=data_shape).astype('float32') - idx_real = np.random.randint(low=0, high=data_shape[0], size=idx_shape) - grad_out = np.ones(idx_shape + data_shape[1:], dtype='float32') - grad_in = np.zeros(data_shape, dtype='float32') + idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape) + # grad_out = np.ones(idx_shape + data_shape[1:], dtype='float32') + # grad_in = np.zeros(data_shape, dtype='float32') exe.arg_dict['a'][:] = mx.nd.array(data_real) exe.arg_dict['indices'][:] = mx.nd.array(idx_real) exe.forward(is_train=True) - assert_almost_equal(exe.outputs[0].asnumpy(), data_real[idx_real]) + assert_almost_equal(exe.outputs[0].asnumpy(), np.take(data_real, idx_real, axis=axis, mode=mode)) - for i in np.nditer(idx_real): - grad_in[i] += 1.0 + # for i in np.nditer(idx_real): + # grad_in[i] += 1.0 + # + # exe.backward([mx.nd.array(grad_out)]) + # assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in) - exe.backward([mx.nd.array(grad_out)]) - assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in) - data = mx.sym.Variable('a') - idx = mx.sym.Variable('indices') - idx = mx.sym.BlockGrad(idx) - result = mx.sym.take(a=data, indices=idx) - - for data_ndim in range(2, 5): + for data_ndim in range(1, 5): for idx_ndim in range(1, 4): - data_shape = () - for _ in range(data_ndim): - data_shape += (np.random.randint(low=3, high=6), ) - idx_shape = () - for _ in range(idx_ndim): - idx_shape += (np.random.randint(low=3, high=5), ) - check_output_n_grad(data_shape, idx_shape) + for axis in range(-data_ndim, data_ndim): + data_shape = () + for _ in range(data_ndim): + 
data_shape += (np.random.randint(low=1, high=5), ) + idx_shape = () + for _ in range(idx_ndim): + idx_shape += (np.random.randint(low=1, high=5), ) + check_output_n_grad(data_shape, idx_shape, axis, 'clip') + check_output_n_grad(data_shape, idx_shape, axis, 'wrap') @with_seed() From 41be3c187ccabfe31ada3ac5ad00f50f7fb9c377 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Thu, 14 Jun 2018 17:08:12 +0000 Subject: [PATCH 2/3] general take backward on gpu --- src/operator/mshadow_op.h | 10 +++ src/operator/tensor/indexing_op-inl.cuh | 103 +++++++++++++++++++++- src/operator/tensor/indexing_op.h | 109 ++++++++++++++++++------ 3 files changed, 195 insertions(+), 27 deletions(-) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 81a55c4a0137..12f1a99e6dc1 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -601,6 +601,16 @@ struct clip : public mxnet_op::tunable { return x; } } + template + MSHADOW_XINLINE static DType Map(DType x, DType upper_bound, DType lower_bound) { + DType ret = x; + if (x > upper_bound) { + return upper_bound; + } else if (x < lower_bound) { + return lower_bound; + } + return x; + } }; /***** gamma ******/ diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh index 67dc2bbc334c..7d3479d0ee6c 100644 --- a/src/operator/tensor/indexing_op-inl.cuh +++ b/src/operator/tensor/indexing_op-inl.cuh @@ -294,7 +294,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, (NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream); size_t exclusivesum_bytes = 0; cub::DeviceScan::ExclusiveSum - (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream); + (NULL, exclusivesum_bytes, NULL, NUsrc_indices_bytesLL, sorted.size(0), stream); size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes); // Check that we have enough storage @@ -321,6 +321,107 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, num_runs_ptr, dst.size(0)); } +template +struct TakeGradGeneralKernel { + template + MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd, + const IType* src_indptr, const IType* original_idx, + mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides, + const int in_ndims, const int out_ndims, const int idx_ndims, const int axis) { + const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1]; + const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1]; + const int in_mid_index = in_rest_index / in_stride[axis]; + const int in_tail_index = (axis == in_ndims - 1) ? + 0 : (in_rest_index % in_stride[axis]); + for (int i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) { + const int out_mid_index = original_idx[i]; + int target = in_tail_index + out_mid_index * out_stride[axis + idx_ndims - 1]; + target += (axis == 0) ? 
0 : in_head_index * out_strides[axis - 1]; + arr_grad[tid] += ograd[target]; + } + } +} + +template +inline void TakeOpBackwardImpl(mshadow::Stream* s, + const OpContext& ctx, + const TBlob& arr, + const TBlob& idx, + const TBlob& ograd, + const int axis) { + using namespace mxnet_op; + using namespace mshadow; + CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; + const TShape& arrshape = arr.shape_; + const TShape& idxshape = idx.shape_; + const TShape& oshape = ograd.shape_; + // get size of temporary storage for sort + char* temp_storage_ptr = nullptr; + size_t scan_temp_storage_bytes = 0; + IType* src_indptr_bytes = nullptr; + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, + scan_temp_storage_bytes, + src_indptr_bytes, + src_indptr_bytes, + arrshape[axis] + 1, + mshadow::Stream::GetStream(s)); + size_t sort_temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); + size_t temp_storage_bytes = max(scan_temp_storage_bytes, sort_temp_storage_bytes); + size_t original_idx_bytes = idxshape.Size() * sizeof(IType); + size_t src_indptr_bytes = (arrshape[actual_axis] + 1) * sizeof(IType); + size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); + IType* sorted_idx_ptr = reinterpret_cast(workspace.dptr_); + IType* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); + src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); + char* temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes; + // Reset indptr to zero + Kernel::Launch(s, arrshape[actual_axis] + 1, src_indptr_ptr); + // Fill original_idx + Kernel::Launch( + s, idxshape.Size(), 1, IType(0), IType(1), kWriteTo, original_idx_ptr); + // Fill sorted_idx_ptr with unsorted copy of idx + Kernel, gpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); + if (clip) { + Kernel, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr, + sorted_idx_ptr, IType(0), IType(arrshape[axis])); + } else { + Kernel, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr, + sorted_idx_ptr, IType(arrshape[axis])); + } + Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); + Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); + int num_bits = ilog2(static_cast(idxshape.Size()) - 1); + Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); + SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits); + Kernel::Launch( + s, idxshape.Size(), src_indptr_ptr, idx.dptr(), idxshape.Size()); + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, + temp_storage_bytes, + src_indptr_bytes, + src_indptr_bytes, + arrshape[actual_axis] + 1, + mshadow::Stream::GetStream(s)); + + Shape<10> in_strides; + int stride = 1; + for (int i = arrshape.ndim() - 1; i > 0; stride *= arrshape[i], --i) { + in_strides[i] = stride; + } + Shape<10> out_strides; + stride = 1; + for (int i = oshape.ndim() - 1; i > 0; stride *= oshape[i], --i) { + out_strides[i] = stride; + } + MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { + Kernel::Launch( + s, arrshape.Size(), arr.dptr(), ograd.dptr(), src_indptr_ptr, original_idx_ptr, + in_strides, out_strides, arrshape.ndim(), oshape.ndim(), idxshape.ndim(), actual_axis); + }); +} + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_ diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 0b028b975bbf..4102d9895fba 100644 --- 
a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -876,6 +876,45 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, }); } +template +inline void TakeOpBackwardImpl(mshadow::Stream* s, + const OpContext& ctx, + const TBlob& arr, + const TBlob& idx, + const TBlob& ograd, + const int axis) { + return; + // CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; + // const TShape& arrshape = arr.shape_; + // const TShape& idxshape = idx.shape_; + // const TShape& oshape = ograd.shape_; + // // get size of temporary storage for sort + // size_t temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); + // size_t original_idx_bytes = idxshape.Size() * sizeof(IType); + // size_t src_indices_bytes = arrshape[actual_axis] * sizeof(IType); + // size_t workspace_bytes = src_indices_bytes + 2 * original_idx_bytes + temp_storage_bytes; + // Tensor workspace = + // ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); + // IType* sorted_idx_ptr = reinterpret_cast(workspace.dptr_); + // IType* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); + // IType* src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); + // char* temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indices_bytes; + // // Reset indptr to zero + // mxnet_op::Kernel::Launch(s, arrshape[actual_axis], src_indptr_ptr); + // // Fill original_idx + // mxnet_op::Kernel::Launch( + // s, idxshape.Size(), 1, IType(0), IType(1), kWriteTo, original_idx_ptr); + // // Fill sorted_idx_ptr with unsorted copy of idx + // mxnet_op::Kernel, xpu>::Launch( + // s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); + // Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); + // Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); + // int num_bits = ilog2(static_cast(idxshape.Size()) - 1); + // SortByKey(idx.dptr, original_idx, true, &temp_storage, 0, num_bits); + // Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); + // Tensor src_indptr(src_indptr_ptr, Shape1(arrshape[actual_axis]), s); +} + template void TakeOpBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -889,20 +928,24 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs, CHECK_EQ(req[take_::kIdx], kNullOp) << "take layer doesn't support gradient into index"; - // inputs are specified in the .cc file, which are the gradients from - // the upper layer and the input index - // outputs are the gradients of inputs in the feed-forward pass - const TShape& idxshape = inputs[1].shape_; - const TShape& arrshape = outputs[0].shape_; - const TShape& oshape = inputs[0].shape_; - - int idxndim = idxshape.ndim(); + const TakeParam& param = nnvm::get(attrs.parsed); // grad_out is the gradient of the outputs in the feed-forward // grad_in is the gradient of the inputs in the feed-forward Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type + // inputs are specified in the .cc file, which are the gradients from + // the upper layer and the input index + // outputs are the gradients of inputs in the feed-forward pass + const TShape& idxshape = inputs[1].shape_; + const TShape& arrshape = outputs[0].shape_; + const TShape& oshape = inputs[0].shape_; + + const int actual_axis = param.axis + ((param.axis < 0) ? 
oshape.ndim() : 0); + + int idxndim = idxshape.ndim(); Tensor idx = inputs[1].get_with_shape( Shape1(idxshape.ProdShape(0, idxndim)), s); Tensor grad_out = inputs[0].get_with_shape( @@ -910,27 +953,41 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs, Tensor grad_in = outputs[0].get_with_shape( Shape2(arrshape[0], arrshape.ProdShape(1, arrshape.ndim())), s); - if (req[take_::kArr] == kWriteTo || req[take_::kArr] == kAddTo) { - if (req[take_::kArr] == kWriteTo) { - grad_in = scalar(0.0f); - } - // shape_out_prod ~= the number of elements loaded in AddTakeGrad - // shape_in_prod ~= the number of elements stored in AddTakeGrad - // When the number of elements processed is low, use AddTakeGrad. - // The approximate cut-off value 16384 was found experimentally on Titan X Pascal - uint64_t shape_in_prod = - static_cast(grad_in.shape_[0])* - static_cast(grad_in.shape_[1]); - uint64_t shape_out_prod = - static_cast(grad_out.shape_[0])* - static_cast(grad_out.shape_[1]); - if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) { - AddTakeGrad(grad_in, idx, grad_out); + // re-using the previous code for axis = 0 case + if (actual_axis == 0) { + if (req[take_::kArr] == kWriteTo || req[take_::kArr] == kAddTo) { + if (req[take_::kArr] == kWriteTo) { + grad_in = scalar(0.0f); + } + // shape_out_prod ~= the number of elements loaded in AddTakeGrad + // shape_in_prod ~= the number of elements stored in AddTakeGrad + // When the number of elements processed is low, use AddTakeGrad. + // The approximate cut-off value 16384 was found experimentally on Titan X Pascal + uint64_t shape_in_prod = + static_cast(grad_in.shape_[0])* + static_cast(grad_in.shape_[1]); + uint64_t shape_out_prod = + static_cast(grad_out.shape_[0])* + static_cast(grad_out.shape_[1]); + if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) { + AddTakeGrad(grad_in, idx, grad_out); + } else { + AddTakeGradLargeBatchCaller(ctx, grad_in, idx, grad_out); + } } else { - AddTakeGradLargeBatchCaller(ctx, grad_in, idx, grad_out); + LOG(FATAL) << "wrong req"; } + // for all other cases } else { - LOG(FATAL) << "wrong req"; + const TBlob& idx = inputs[1]; + const TBlob& arr = outputs[0]; + const TBlob& ograd = inputs[0]; + + if (param.mode == take_::kClip) { + TakeOpBackwardImpl(s, ctx, arr, idx, ograd, actual_axis); + } else { + TakeOpBackwardImpl(s, ctx, arr, idx, ograd, actual_axis); + } } }); }); From 5268cf24a173c1122fa6b7720c2e404c8bf928cf Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Mon, 18 Jun 2018 17:31:12 +0000 Subject: [PATCH 3/3] backward of enhanced take op --- src/operator/mshadow_op.h | 3 +- src/operator/tensor/indexing_op-inl.cuh | 107 +------- src/operator/tensor/indexing_op.cc | 27 +- src/operator/tensor/indexing_op.h | 347 ++++++++++++++++++------ tests/python/unittest/test_operator.py | 68 +++-- 5 files changed, 331 insertions(+), 221 deletions(-) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 12f1a99e6dc1..7a2032df7580 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -602,8 +602,7 @@ struct clip : public mxnet_op::tunable { } } template - MSHADOW_XINLINE static DType Map(DType x, DType upper_bound, DType lower_bound) { - DType ret = x; + MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) { if (x > upper_bound) { return upper_bound; } else if (x < lower_bound) { diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh index 7d3479d0ee6c..b2f514e20cd9 100644 --- 
a/src/operator/tensor/indexing_op-inl.cuh +++ b/src/operator/tensor/indexing_op-inl.cuh @@ -28,6 +28,8 @@ #include #include #include "../mxnet_op.h" +#include "../mshadow_op.h" +#include "./util/tensor_util-inl.cuh" #if CUDA_VERSION >= 9000 #define FULLMASK 0xFFFFFFFF @@ -272,7 +274,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, const mshadow::Tensor& sorted, const mshadow::Tensor& index, const mshadow::Tensor &src, - mshadow::Tensor* workspace) { + mshadow::Tensor* workspace = NULL) { CHECK_EQ(dst.CheckContiguous(), true); CHECK_EQ(sorted.CheckContiguous(), true); CHECK_EQ(index.CheckContiguous(), true); @@ -294,7 +296,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, (NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream); size_t exclusivesum_bytes = 0; cub::DeviceScan::ExclusiveSum - (NULL, exclusivesum_bytes, NULL, NUsrc_indices_bytesLL, sorted.size(0), stream); + (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream); size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes); // Check that we have enough storage @@ -321,107 +323,6 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, num_runs_ptr, dst.size(0)); } -template -struct TakeGradGeneralKernel { - template - MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd, - const IType* src_indptr, const IType* original_idx, - mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides, - const int in_ndims, const int out_ndims, const int idx_ndims, const int axis) { - const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1]; - const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1]; - const int in_mid_index = in_rest_index / in_stride[axis]; - const int in_tail_index = (axis == in_ndims - 1) ? - 0 : (in_rest_index % in_stride[axis]); - for (int i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) { - const int out_mid_index = original_idx[i]; - int target = in_tail_index + out_mid_index * out_stride[axis + idx_ndims - 1]; - target += (axis == 0) ? 
0 : in_head_index * out_strides[axis - 1]; - arr_grad[tid] += ograd[target]; - } - } -} - -template -inline void TakeOpBackwardImpl(mshadow::Stream* s, - const OpContext& ctx, - const TBlob& arr, - const TBlob& idx, - const TBlob& ograd, - const int axis) { - using namespace mxnet_op; - using namespace mshadow; - CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; - const TShape& arrshape = arr.shape_; - const TShape& idxshape = idx.shape_; - const TShape& oshape = ograd.shape_; - // get size of temporary storage for sort - char* temp_storage_ptr = nullptr; - size_t scan_temp_storage_bytes = 0; - IType* src_indptr_bytes = nullptr; - cub::DeviceScan::ExclusiveSum(temp_storage_ptr, - scan_temp_storage_bytes, - src_indptr_bytes, - src_indptr_bytes, - arrshape[axis] + 1, - mshadow::Stream::GetStream(s)); - size_t sort_temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); - size_t temp_storage_bytes = max(scan_temp_storage_bytes, sort_temp_storage_bytes); - size_t original_idx_bytes = idxshape.Size() * sizeof(IType); - size_t src_indptr_bytes = (arrshape[actual_axis] + 1) * sizeof(IType); - size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; - Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); - IType* sorted_idx_ptr = reinterpret_cast(workspace.dptr_); - IType* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); - src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); - char* temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes; - // Reset indptr to zero - Kernel::Launch(s, arrshape[actual_axis] + 1, src_indptr_ptr); - // Fill original_idx - Kernel::Launch( - s, idxshape.Size(), 1, IType(0), IType(1), kWriteTo, original_idx_ptr); - // Fill sorted_idx_ptr with unsorted copy of idx - Kernel, gpu>::Launch( - s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); - if (clip) { - Kernel, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr, - sorted_idx_ptr, IType(0), IType(arrshape[axis])); - } else { - Kernel, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr, - sorted_idx_ptr, IType(arrshape[axis])); - } - Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); - Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); - int num_bits = ilog2(static_cast(idxshape.Size()) - 1); - Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); - SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits); - Kernel::Launch( - s, idxshape.Size(), src_indptr_ptr, idx.dptr(), idxshape.Size()); - cub::DeviceScan::ExclusiveSum(temp_storage_ptr, - temp_storage_bytes, - src_indptr_bytes, - src_indptr_bytes, - arrshape[actual_axis] + 1, - mshadow::Stream::GetStream(s)); - - Shape<10> in_strides; - int stride = 1; - for (int i = arrshape.ndim() - 1; i > 0; stride *= arrshape[i], --i) { - in_strides[i] = stride; - } - Shape<10> out_strides; - stride = 1; - for (int i = oshape.ndim() - 1; i > 0; stride *= oshape[i], --i) { - out_strides[i] = stride; - } - MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { - Kernel::Launch( - s, arrshape.Size(), arr.dptr(), ograd.dptr(), src_indptr_ptr, original_idx_ptr, - in_strides, out_strides, arrshape.ndim(), oshape.ndim(), idxshape.ndim(), actual_axis); - }); -} - } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_ diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 8c2416e89ab6..0f96e2cc2f72 100644 --- 
a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -367,32 +367,42 @@ NNVM_REGISTER_OP(take) This function slices the input array along a particular axis with the provided indices. -Given an input array with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, the output -will have shape ``(i0, i1, d1, d2)``, computed by:: - - output[i,j,:,:] = input[indices[i,j],:,:] - -.. note:: - - `axis`- Only slicing along axis 0 is supported for now. - - `mode`- Only `clip` mode is supported for now. +Given data tensor of rank r >= 1, and indices tensor of rank q, gather entries of the axis +dimension of data (by default outer-most one as axis=0) indexed by indices, and concatenates them +in an output tensor of rank q + (r - 1). Examples:: x = [4. 5. 6.] // Trivial case, take the second element along the first axis. + take(x, [1]) = [ 5. ] + // The other trivial case, axis=-1, take the third element along the first axis + + take(x, [3], axis=-1, mode='clip') = [ 6. ] + x = [[ 1., 2.], [ 3., 4.], [ 5., 6.]] // In this case we will get rows 0 and 1, then 1 and 2. Along axis 0 + take(x, [[0,1],[1,2]]) = [[[ 1., 2.], [ 3., 4.]], [[ 3., 4.], [ 5., 6.]]] + // In this case we will get rows 0 and 1, then 1 and 2 (calculated by wrapping around). + // Along axis 1 + + take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1., 2.], + [ 3., 4.]], + + [[ 3., 4.], + [ 5., 6.]]] + )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -420,6 +430,7 @@ Examples:: NNVM_REGISTER_OP(_backward_take) .set_num_inputs(2) .set_num_outputs(2) +.set_attr_parser(ParamParser) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 4102d9895fba..9a360b9aac64 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -44,6 +44,9 @@ #include "./sort_op.h" #include "./init_op.h" #include "../../engine/openmp.h" +#ifdef __CUDACC__ +#include "./indexing_op-inl.cuh" +#endif namespace mxnet { namespace op { @@ -135,22 +138,6 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor dst, dst[sorted[y]] += src[index[y]]; } } -/*! - * \brief CPU/GPU: Gradient accumulate of embedding matrix. - dst[sorted[i]] += src[index[i]] - Called when the batchsize of src is larger than the featuredim - * \param dst destination - * \param sorted the sorted indices - * \param index original index of the sorted indices - * \param src source output - * \param workspace (optional) temporary storage - */ -template -inline void AddTakeGradLargeBatch(mshadow::Tensor dst, - const mshadow::Tensor& sorted, - const mshadow::Tensor& index, - const mshadow::Tensor &src, - mshadow::Tensor* workspace = NULL); template inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, @@ -322,8 +309,13 @@ struct Take { MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, const int M, const int K) { int j = static_cast(idx[i/M]); - if (j <= 0) j = 0; - else if (j >= K) j = K - 1; + if (clip) { + if (j <= 0) j = 0; + else if (j >= K) j = K - 1; + } else { + j = j % K; + j += (j < 0) ? K : 0; + } out_data[i] = in_data[j * M + i % M]; } @@ -353,10 +345,11 @@ struct Take { 0 : (out_rest_index % in_stride[axis]); int idx_index = static_cast(idx[out_mid_index]); if (clip) { - idx_index = (idx_index < -axis_dim) ? 0 : idx_index; + idx_index = (idx_index < 0) ? 
0 : idx_index; idx_index = (idx_index > axis_dim - 1) ? (axis_dim - 1) : idx_index; } idx_index %= axis_dim; + idx_index += (idx_index < 0) ? axis_dim : 0; const int in_tail_index = out_tail_index; const int in_head_index = out_head_index; int in_src_index = in_tail_index + idx_index * in_stride[axis]; @@ -774,11 +767,11 @@ struct TakeParam: public dmlc::Parameter { .add_enum("wrap", take_::kWrap) .add_enum("clip", take_::kClip) .set_default(take_::kClip) - .describe("Specify how out-of-bound indices bahave." + .describe("Specify how out-of-bound indices bahave. Default is \"clip\"." " \"clip\" means clip to the range. So, if all indices mentioned are too large," " they are replaced by the index that addresses the last element along an axis. " " \"wrap\" means to wrap around. " - " \"raise\" means to raise an error. "); + " \"raise\" means to raise an error, not supported yet."); } }; @@ -798,12 +791,12 @@ inline bool TakeOpShape(const nnvm::NodeAttrs& attrs, out_attrs->clear(); - const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); + const index_t actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); TShape oshape(idxshape.ndim() + arrshape.ndim() - 1); - for (int i = 0; i < (int)idxshape.ndim(); ++i) { + for (index_t i = 0; i < idxshape.ndim(); ++i) { oshape[i + actual_axis] = idxshape[i]; } - for (int i = 0; i < (int)arrshape.ndim(); i++) { + for (index_t i = 0; i < arrshape.ndim(); i++) { if (i < actual_axis) { oshape[i] = arrshape[i]; } else if (i > actual_axis) { @@ -847,73 +840,253 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type - mshadow::Shape<10> in_strides; - int stride = 1; - for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { - in_strides[i] = stride; - } - mshadow::Shape<10> out_strides; - stride = 1; - for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { - out_strides[i] = stride; - } - if (param.mode == take_::kClip) { - Kernel, xpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - in_strides, out_strides, arrshape.ndim(), oshape.ndim(), - idxshape.ndim(), arrshape[actual_axis], actual_axis); - } else if (param.mode == take_::kWrap) { - Kernel, xpu>::Launch(s, oshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - in_strides, out_strides, arrshape.ndim(), oshape.ndim(), - idxshape.ndim(), arrshape[actual_axis], actual_axis); + if (actual_axis == 0) { + if (param.mode == take_::kClip) { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + oshape.Size()/idxshape.Size(), arrshape[0]); + } else { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + oshape.Size()/idxshape.Size(), arrshape[0]); + } + } else { + mshadow::Shape<10> in_strides; + int stride = 1; + for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { + in_strides[i] = stride; + } + mshadow::Shape<10> out_strides; + stride = 1; + for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { + out_strides[i] = stride; + } + if (param.mode == take_::kClip) { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + 
in_strides, out_strides, arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], actual_axis); + } else if (param.mode == take_::kWrap) { + Kernel, xpu>::Launch(s, oshape.Size(), + outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + in_strides, out_strides, arrshape.ndim(), oshape.ndim(), + idxshape.ndim(), arrshape[actual_axis], actual_axis); + } } }); }); } +struct TakeGradGeneralKernel { + /*! + * \brief Map function for general case of take grad + * \param tid global thread id + * \param arr_grad ptr to in_grad + * \param ograd ptr to out_grad + * \param src_indptr ptr to indptr to src indices + * \param original_idx ptr to original indices of the inputs + * \param in_strides strides of inputs + * \param out_strides strides of outputs + * \param in_ndims # of dims of input tensor + * \param out_ndims # of dims of output tensor + * \param idx_ndims # of dims of indices tensor + * \param axis_dim dim size of the axis dimension + * \param axis axis id + */ + template + MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd, + const IType* src_indptr, const IType* original_idx, + mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides, + const int in_ndims, const int out_ndims, const int idx_ndims, + const int axis) { + const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1]; + const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1]; + const int in_mid_index = in_rest_index / in_strides[axis]; + const int in_tail_index = (axis == in_ndims - 1) ? + 0 : (in_rest_index % in_strides[axis]); + for (IType i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) { + const int out_mid_index = original_idx[i]; + int target = in_tail_index + out_mid_index * in_strides[axis]; + target += (axis == 0) ? 
0 : in_head_index * out_strides[axis - 1]; + arr_grad[tid] += ograd[target]; + } + } +}; + +template +void TakeOpBackwardImpl(mshadow::Stream* s, + const OpContext& ctx, + const TBlob& arr, + const TBlob& idx, + const TBlob& ograd, + const int axis) { + using namespace mxnet_op; + using namespace mshadow; + CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; + const TShape& arrshape = arr.shape_; + const TShape& idxshape = idx.shape_; + const TShape& oshape = ograd.shape_; + MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, { + // get size of temporary storage for sort + int* src_indptr_ptr = nullptr; + size_t temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); + size_t original_idx_bytes = idxshape.Size() * sizeof(int); + size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int); + size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); + int* sorted_idx_ptr = reinterpret_cast(workspace.dptr_); + int* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); + src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); + Tensor temp_storage( + workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes, Shape1(temp_storage_bytes), s); + // Reset indptr to zero + Kernel::Launch(s, arrshape[axis] + 1, src_indptr_ptr); + // Fill original_idx + Kernel::Launch(s, idxshape.Size(), 1, 0, 1, kWriteTo, original_idx_ptr); + // Fill sorted_idx_ptr with unsorted copy of idx + Kernel::Launch( + s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); + if (clip) { + Kernel, cpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, + 0, static_cast(arrshape[axis] - 1)); + } else { + Kernel, cpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast(arrshape[axis])); + } + Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); + int num_bits = ilog2(static_cast(idxshape.Size()) - 1); + Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); + SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits); + for (size_t i = 0; i < idxshape.Size(); ++i) { + src_indptr_ptr[sorted_idx_ptr[i] + 1] += 1; + } + for (int i = 0; i < arrshape[axis]; ++i) { + src_indptr_ptr[i + 1] += src_indptr_ptr[i]; + } + Shape<10> in_strides; + int stride = 1; + for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { + in_strides[i] = stride; + } + Shape<10> out_strides; + stride = 1; + for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { + out_strides[i] = stride; + } + MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { + Kernel::Launch( + s, arrshape.Size(), arr.dptr(), ograd.dptr(), src_indptr_ptr, + original_idx_ptr, in_strides, out_strides, + arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis); + }); + }); +} + +#ifdef __CUDACC__ template -inline void TakeOpBackwardImpl(mshadow::Stream* s, - const OpContext& ctx, - const TBlob& arr, - const TBlob& idx, - const TBlob& ograd, - const int axis) { - return; - // CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; - // const TShape& arrshape = arr.shape_; - // const TShape& idxshape = idx.shape_; - // const TShape& oshape = ograd.shape_; - // // get size of temporary storage for sort - // size_t temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); - // size_t original_idx_bytes = idxshape.Size() * sizeof(IType); - // size_t src_indices_bytes = arrshape[actual_axis] * 
sizeof(IType); - // size_t workspace_bytes = src_indices_bytes + 2 * original_idx_bytes + temp_storage_bytes; - // Tensor workspace = - // ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); - // IType* sorted_idx_ptr = reinterpret_cast(workspace.dptr_); - // IType* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); - // IType* src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); - // char* temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indices_bytes; - // // Reset indptr to zero - // mxnet_op::Kernel::Launch(s, arrshape[actual_axis], src_indptr_ptr); - // // Fill original_idx - // mxnet_op::Kernel::Launch( - // s, idxshape.Size(), 1, IType(0), IType(1), kWriteTo, original_idx_ptr); - // // Fill sorted_idx_ptr with unsorted copy of idx - // mxnet_op::Kernel, xpu>::Launch( - // s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); - // Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); - // Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); - // int num_bits = ilog2(static_cast(idxshape.Size()) - 1); - // SortByKey(idx.dptr, original_idx, true, &temp_storage, 0, num_bits); - // Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); - // Tensor src_indptr(src_indptr_ptr, Shape1(arrshape[actual_axis]), s); +void TakeOpBackwardImpl(mshadow::Stream* s, + const OpContext& ctx, + const TBlob& arr, + const TBlob& idx, + const TBlob& ograd, + const int axis) { + using namespace mxnet_op; + using namespace mshadow; + CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation"; + const TShape& arrshape = arr.shape_; + const TShape& idxshape = idx.shape_; + const TShape& oshape = ograd.shape_; + MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, { + // get size of temporary storage for sort + char* temp_storage_ptr = nullptr; + size_t scan_temp_storage_bytes = 0; + int* src_indptr_ptr = nullptr; + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, + scan_temp_storage_bytes, + src_indptr_ptr, + src_indptr_ptr, + arrshape[axis] + 1, + mshadow::Stream::GetStream(s)); + size_t sort_temp_storage_bytes = SortByKeyWorkspaceSize(idxshape.Size()); + size_t histo_temp_storage_bytes = 0; + int* sorted_idx_ptr = nullptr; + cub::DeviceHistogram::HistogramEven(temp_storage_ptr, + histo_temp_storage_bytes, + sorted_idx_ptr, + src_indptr_ptr, + static_cast(arrshape[axis] + 1), + 0, + static_cast(arrshape[axis] + 1), + static_cast(idxshape.Size()), + mshadow::Stream::GetStream(s)); + size_t temp_storage_bytes = max(scan_temp_storage_bytes, sort_temp_storage_bytes); + temp_storage_bytes = max(temp_storage_bytes, histo_temp_storage_bytes); + size_t original_idx_bytes = idxshape.Size() * sizeof(int); + size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(int); + size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes; + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); + sorted_idx_ptr = reinterpret_cast(workspace.dptr_); + int* original_idx_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); + src_indptr_ptr = reinterpret_cast(workspace.dptr_ + 2 * original_idx_bytes); + temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes; + // Reset indptr to zero + Kernel::Launch(s, arrshape[axis] + 1, src_indptr_ptr); + // Fill original_idx + Kernel::Launch( + s, idxshape.Size(), 1, static_cast(0), static_cast(1), + kWriteTo, original_idx_ptr); + // Fill sorted_idx_ptr with unsorted copy of idx + 
Kernel::Launch( + s, idxshape.Size(), sorted_idx_ptr, idx.dptr()); + if (clip) { + Kernel, gpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, 0, static_cast(arrshape[axis])); + } else { + Kernel, gpu>::Launch( + s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast(arrshape[axis])); + } + Tensor original_idx(original_idx_ptr, Shape1(idxshape.Size()), s); + Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); + int num_bits = ilog2(static_cast(idxshape.Size()) - 1); + Tensor sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s); + SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits); + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, + temp_storage_bytes, + src_indptr_ptr, + src_indptr_ptr, + arrshape[axis] + 1, + mshadow::Stream::GetStream(s)); + + Shape<10> in_strides; + int stride = 1; + for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { + in_strides[i] = stride; + } + Shape<10> out_strides; + stride = 1; + for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { + out_strides[i] = stride; + } + MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, { + Kernel::Launch( + s, arrshape.Size(), arr.dptr(), ograd.dptr(), + src_indptr_ptr, original_idx_ptr, in_strides, out_strides, + arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis); + }); + }); } +#endif template void TakeOpBackward(const nnvm::NodeAttrs& attrs, @@ -943,7 +1116,7 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs, const TShape& arrshape = outputs[0].shape_; const TShape& oshape = inputs[0].shape_; - const int actual_axis = param.axis + ((param.axis < 0) ? oshape.ndim() : 0); + const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); int idxndim = idxshape.ndim(); Tensor idx = inputs[1].get_with_shape( @@ -1445,7 +1618,5 @@ void ScatterSetNDForward(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet -#ifdef __CUDACC__ -#include "./indexing_op-inl.cuh" -#endif + #endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_H_ diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index d7721f6aa2b0..734235331a24 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3652,6 +3652,32 @@ def test_blockgrad(): @with_seed() def test_take(): + def grad_helper(grad_in, axis, idx): + if axis == 0: + if axis == len(grad_in.shape) - 1: + grad_in[idx] += 1.0 + else: + grad_in[idx, :] += 1.0 + elif axis == 1: + if axis == len(grad_in.shape) - 1: + grad_in[:, idx] += 1.0 + else: + grad_in[:, idx, :] += 1.0 + elif axis == 2: + if axis == len(grad_in.shape) - 1: + grad_in[:, :, idx] += 1.0 + else: + grad_in[:, :, idx, :] += 1.0 + elif axis == 3: + if axis == len(grad_in.shape) - 1: + grad_in[:, :, :, idx] += 1.0 + else: + grad_in[:, :, :, idx, :] += 1.0 + elif axis == 4: + grad_in[:, :, :, :, idx] += 1.0 + else: + raise ValueError("axis %d is not supported..." 
% axis) + def check_output_n_grad(data_shape, idx_shape, axis, mode): data = mx.sym.Variable('a') idx = mx.sym.Variable('indices') @@ -3661,32 +3687,34 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode): indices=idx_shape, axis=axis, mode=mode) data_real = np.random.normal(size=data_shape).astype('float32') idx_real = np.random.randint(low=0, high=data_shape[axis], size=idx_shape) - # grad_out = np.ones(idx_shape + data_shape[1:], dtype='float32') - # grad_in = np.zeros(data_shape, dtype='float32') + if axis < 0: + axis += len(data_shape) + + grad_out = np.ones((data_shape[0:axis] if axis > 0 else ()) + idx_shape + (data_shape[axis+1:] if axis < len(data_shape) - 1 else ()), dtype='float32') + grad_in = np.zeros(data_shape, dtype='float32') exe.arg_dict['a'][:] = mx.nd.array(data_real) exe.arg_dict['indices'][:] = mx.nd.array(idx_real) exe.forward(is_train=True) assert_almost_equal(exe.outputs[0].asnumpy(), np.take(data_real, idx_real, axis=axis, mode=mode)) - # for i in np.nditer(idx_real): - # grad_in[i] += 1.0 - # - # exe.backward([mx.nd.array(grad_out)]) - # assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in) - - - for data_ndim in range(1, 5): - for idx_ndim in range(1, 4): - for axis in range(-data_ndim, data_ndim): - data_shape = () - for _ in range(data_ndim): - data_shape += (np.random.randint(low=1, high=5), ) - idx_shape = () - for _ in range(idx_ndim): - idx_shape += (np.random.randint(low=1, high=5), ) - check_output_n_grad(data_shape, idx_shape, axis, 'clip') - check_output_n_grad(data_shape, idx_shape, axis, 'wrap') + for i in np.nditer(idx_real): + grad_helper(grad_in, axis, i) + + exe.backward([mx.nd.array(grad_out)]) + assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in) + + for mode in ['clip', 'wrap']: + for data_ndim in range(1, 5): + for idx_ndim in range(1, 4): + for axis in range(-data_ndim, data_ndim): + data_shape = () + for _ in range(data_ndim): + data_shape += (np.random.randint(low=1, high=5), ) + idx_shape = () + for _ in range(idx_ndim): + idx_shape += (np.random.randint(low=1, high=5), ) + check_output_n_grad(data_shape, idx_shape, axis, mode) @with_seed()
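
Notes (illustrative only, not part of the patch):

For reference, the forward behaviour this series implements can be written as a minimal NumPy sketch. The helper name take_reference is made up here; it mirrors the np.take(data, idx, axis=axis, mode=mode) oracle used in test_take above: 'clip' clamps out-of-range indices to [0, axis_dim - 1], 'wrap' reduces them modulo the axis length, and the axis dimension of the data is replaced by the shape of the indices, giving an output of rank q + (r - 1).

import numpy as np

def take_reference(data, indices, axis=0, mode='clip'):
    # Hypothetical reference helper, not the operator's actual implementation.
    data = np.asarray(data)
    indices = np.asarray(indices)
    if axis < 0:                       # axis may be in [-r, r-1], as TakeOpShape allows
        axis += data.ndim
    k = data.shape[axis]
    if mode == 'clip':
        idx = np.clip(indices, 0, k - 1)
    elif mode == 'wrap':
        idx = indices % k              # Python's % maps negative indices into [0, k)
    else:
        raise ValueError("mode='raise' is rejected by TakeOpShape for now")
    # The axis dimension of `data` is replaced by indices.shape in the output.
    return np.take(data, idx, axis=axis)

x = np.array([[1., 2.], [3., 4.], [5., 6.]])
take_reference(x, [[0, 1], [1, 2]])         # rows 0,1 then rows 1,2, as in the take docstring
take_reference([4., 5., 6.], [3], axis=-1)  # clip: index 3 is clamped to 2, giving [6.]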
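
The strided Take kernel locates the source element for each flat output index with row-major strides: the index is split into a head part (dimensions before the axis), a middle part (which entry of the indices tensor), and a tail part (dimensions after the axis); the middle part is then replaced by the clipped or wrapped index value. Below is a sketch of that arithmetic in Python, assuming the strides are built exactly as in TakeOpForward (the last dimension has stride 1); take_flat_index_math is an illustrative name.

import numpy as np

def take_flat_index_math(data, indices, axis, mode='clip'):
    # Re-implements the index arithmetic of the strided Take kernel, element by element.
    data = np.asarray(data)
    idx_arr = np.asarray(indices)
    idx = idx_arr.ravel()
    r = data.ndim
    if axis < 0:
        axis += r
    axis_dim = data.shape[axis]
    oshape = data.shape[:axis] + idx_arr.shape + data.shape[axis + 1:]
    # Row-major element strides, matching the stride loops in TakeOpForward.
    in_stride = [int(np.prod(data.shape[d + 1:])) for d in range(r)]
    out_stride = [int(np.prod(oshape[d + 1:])) for d in range(len(oshape))]
    flat_in = data.ravel()
    out = np.empty(int(np.prod(oshape)), dtype=data.dtype)
    for i in range(out.size):
        head = 0 if axis == 0 else i // out_stride[axis - 1]   # flat index over dims before axis
        rest = i if axis == 0 else i % out_stride[axis - 1]
        mid = rest // in_stride[axis]                          # which entry of `indices`
        tail = 0 if axis == r - 1 else rest % in_stride[axis]  # flat index over dims after axis
        j = int(idx[mid])
        j = min(max(j, 0), axis_dim - 1) if mode == 'clip' else j % axis_dim
        src = tail + j * in_stride[axis]
        src += 0 if axis == 0 else head * in_stride[axis - 1]
        out[i] = flat_in[src]
    return out.reshape(oshape)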
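
For the backward pass with a non-zero axis, TakeGradGeneralKernel together with the sorted-index and indptr bookkeeping in TakeOpBackwardImpl amounts to a scatter-add of the output gradient back into the positions along the axis that were gathered, with repeated indices accumulating. A NumPy sketch of that contract follows (the gradient w.r.t. indices is null, as the operator declares); take_backward_reference is an illustrative name.

import numpy as np

def take_backward_reference(ograd, indices, data_shape, axis, mode='clip'):
    # Gradient w.r.t. the data input of take(); the indices input gets no gradient.
    ograd = np.asarray(ograd)
    indices = np.asarray(indices)
    if axis < 0:
        axis += len(data_shape)
    k = data_shape[axis]
    idx = np.clip(indices, 0, k - 1) if mode == 'clip' else indices % k
    grad_in = np.zeros(data_shape, dtype=ograd.dtype)
    # Move `axis` to the front of grad_in and the indices block to the front of ograd,
    # then scatter-add; np.add.at accumulates duplicate indices, which is what the
    # kernel's per-position indptr/original_idx loop achieves without atomic adds.
    q = idx.ndim
    g_in = np.moveaxis(grad_in, axis, 0)                            # a view into grad_in
    g_out = np.moveaxis(ograd, list(range(axis, axis + q)), list(range(q)))
    np.add.at(g_in, idx, g_out)
    return grad_in

With an all-ones ograd this reduces to counting how often each position along the axis was taken, which is what grad_helper in the updated test builds slice by slice.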