backward of enhanced take op
Hao Jin committed Jun 29, 2018
1 parent a2b196b commit fdd788f
Showing 5 changed files with 328 additions and 216 deletions.
3 changes: 1 addition & 2 deletions src/operator/mshadow_op.h
@@ -602,8 +602,7 @@ struct clip : public mxnet_op::tunable {
}
}
template<typename DType>
-  MSHADOW_XINLINE static DType Map(DType x, DType upper_bound, DType lower_bound) {
-    DType ret = x;
+  MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) {
if (x > upper_bound) {
return upper_bound;
} else if (x < lower_bound) {
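The argument swap above matters because callers pass the bounds positionally: with the old (upper_bound, lower_bound) order, a call that meant to clamp into [lower, upper] compared against the wrong limits. A minimal stand-alone sketch of the corrected ordering, with a hypothetical clip_map helper in place of the real mshadow_op::clip::

  #include <cassert>

  // Stand-in for the fixed mshadow_op::clip::Map: bounds arrive as
  // (lower_bound, upper_bound), so the clamp lands in [lower, upper].
  template <typename DType>
  DType clip_map(DType x, DType lower_bound, DType upper_bound) {
    if (x > upper_bound) return upper_bound;
    if (x < lower_bound) return lower_bound;
    return x;
  }

  int main() {
    assert(clip_map(7, 0, 2) == 2);   // above the range -> upper bound
    assert(clip_map(-3, 0, 2) == 0);  // below the range -> lower bound
    assert(clip_map(1, 0, 2) == 1);   // in range -> unchanged
    return 0;
  }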
107 changes: 4 additions & 103 deletions src/operator/tensor/indexing_op-inl.cuh
@@ -28,6 +28,8 @@
#include <cub/device/device_run_length_encode.cuh>
#include <cub/device/device_scan.cuh>
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include "./util/tensor_util-inl.cuh"

#if CUDA_VERSION >= 9000
#define FULLMASK 0xFFFFFFFF
@@ -272,7 +274,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
const mshadow::Tensor<gpu, 1, IndexType>& sorted,
const mshadow::Tensor<gpu, 1, IndexType>& index,
const mshadow::Tensor<gpu, 2, DType> &src,
-                                  mshadow::Tensor<gpu, 1, char>* workspace) {
+                                  mshadow::Tensor<gpu, 1, char>* workspace = NULL) {
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(sorted.CheckContiguous(), true);
CHECK_EQ(index.CheckContiguous(), true);
@@ -294,7 +296,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
(NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream);
size_t exclusivesum_bytes = 0;
cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
-    (NULL, exclusivesum_bytes, NULL, NULL, src_indices_bytes, stream);
+    (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream);
size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes);

// Check that we have enough storage
@@ -321,107 +323,6 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
num_runs_ptr, dst.size(0));
}

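The NULL-argument calls above are cub's standard two-phase idiom: called with a null temp-storage pointer, DeviceRunLengthEncode::Encode and DeviceScan::ExclusiveSum only report how many workspace bytes they need, and the real invocation then reuses a single buffer sized to the maximum of the two. A simplified sketch of the pattern (hypothetical exclusive_sum wrapper; error checking omitted)::

  #include <cub/cub.cuh>
  #include <cuda_runtime.h>

  // Phase 1 with a null temp pointer only writes the required size;
  // phase 2 with real storage performs the scan.
  void exclusive_sum(const int* d_in, int* d_out, int num_items,
                     cudaStream_t stream) {
    size_t temp_bytes = 0;
    cub::DeviceScan::ExclusiveSum(nullptr, temp_bytes, d_in, d_out,
                                  num_items, stream);  // size query only
    void* d_temp = nullptr;
    cudaMalloc(&d_temp, temp_bytes);
    cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_in, d_out,
                                  num_items, stream);  // actual scan
    cudaFree(d_temp);
  }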
template<bool clip = true>
struct TakeGradGeneralKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd,
const IType* src_indptr, const IType* original_idx,
mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
const int in_ndims, const int out_ndims, const int idx_ndims, const int axis) {
const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
const int in_mid_index = in_rest_index / in_strides[axis];
const int in_tail_index = (axis == in_ndims - 1) ?
0 : (in_rest_index % in_strides[axis]);
for (int i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
const int out_mid_index = original_idx[i];
int target = in_tail_index + out_mid_index * out_strides[axis + idx_ndims - 1];
target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
arr_grad[tid] += ograd[target];
}
}
};

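The head/mid/tail arithmetic in TakeGradGeneralKernel is a plain row-major index decomposition around `axis`: head is the flattened block of coordinates before the axis, mid is the coordinate on the axis itself, and tail is the flattened remainder after it. A small CPU sketch of the same bookkeeping, with a hypothetical decompose helper and a hard-coded (2, 3, 4) shape::

  #include <cassert>

  struct Coords { int head, mid, tail; };

  // strides[i] is the product of the dims after i (row-major), matching
  // the in_strides the kernel receives.
  Coords decompose(int tid, const int* strides, int axis, int ndims) {
    const int head = (axis == 0) ? 0 : tid / strides[axis - 1];
    const int rest = (axis == 0) ? tid : tid % strides[axis - 1];
    const int mid  = rest / strides[axis];
    const int tail = (axis == ndims - 1) ? 0 : rest % strides[axis];
    return {head, mid, tail};
  }

  int main() {
    const int strides[3] = {12, 4, 1};  // shape (2, 3, 4)
    Coords c = decompose(17, strides, /*axis=*/1, /*ndims=*/3);
    assert(c.head == 1 && c.mid == 1 && c.tail == 1);  // 17 = 1*12 + 1*4 + 1
    return 0;
  }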
template<bool clip = true>
inline void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
const OpContext& ctx,
const TBlob& arr,
const TBlob& idx,
const TBlob& ograd,
const int axis) {
using namespace mxnet_op;
using namespace mshadow;
CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation";
const TShape& arrshape = arr.shape_;
const TShape& idxshape = idx.shape_;
const TShape& oshape = ograd.shape_;
MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
// get size of temporary storage for sort
char* temp_storage_ptr = nullptr;
size_t scan_temp_storage_bytes = 0;
IType* src_indptr_ptr = nullptr;
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
scan_temp_storage_bytes,
src_indptr_ptr,
src_indptr_ptr,
arrshape[axis] + 1,
mshadow::Stream<gpu>::GetStream(s));
size_t sort_temp_storage_bytes = SortByKeyWorkspaceSize<IType, IType, gpu>(idxshape.Size());
size_t temp_storage_bytes = std::max(scan_temp_storage_bytes, sort_temp_storage_bytes);
size_t original_idx_bytes = idxshape.Size() * sizeof(IType);
size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(IType);
size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes;
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(workspace_bytes), s);
IType* sorted_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_);
IType* original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes);
src_indptr_ptr = reinterpret_cast<IType*>(workspace.dptr_ + 2 * original_idx_bytes);
temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes;
// Reset indptr to zero
Kernel<set_zero, gpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr);
// Fill original_idx
Kernel<range_fwd, gpu>::Launch(
s, idxshape.Size(), 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
// Fill sorted_idx_ptr with unsorted copy of idx
Kernel<op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
s, idxshape.Size(), sorted_idx_ptr, idx.dptr<IType>());
if (clip) {
Kernel<op_with_req<mshadow_op::clip, kWriteTo>, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr,
sorted_idx_ptr, IType(0), IType(arrshape[axis] - 1));
} else {
Kernel<op_with_req<mshadow_op::mod, kWriteTo>, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr,
sorted_idx_ptr, IType(arrshape[axis]));
}
Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(idxshape.Size()), s);
Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);
int num_bits = ilog2(static_cast<unsigned int>(idxshape.Size()) - 1);
Tensor<gpu, 1, IType> sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s);
SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits);
Kernel<HistogramKernel, gpu>::Launch(
s, idxshape.Size(), src_indptr_ptr, sorted_idx_ptr, idxshape.Size());
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
temp_storage_bytes,
src_indptr_ptr,
src_indptr_ptr,
arrshape[axis] + 1,
mshadow::Stream<gpu>::GetStream(s));

Shape<10> in_strides;
int stride = 1;
for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
in_strides[i] = stride;
}
Shape<10> out_strides;
stride = 1;
for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
out_strides[i] = stride;
}
MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, {
Kernel<TakeGradGeneralKernel<clip>, gpu>::Launch(
s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(), src_indptr_ptr, original_idx_ptr,
in_strides, out_strides, arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
});
});
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_
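The sort / histogram / exclusive-scan sequence in TakeOpBackwardImpl builds a CSR-style src_indptr over the normalized indices: after a stable sort, the output positions that read input slice m are exactly original_idx[src_indptr[m] .. src_indptr[m+1]), which is the range TakeGradGeneralKernel walks when accumulating gradients. A CPU sketch of that bookkeeping (hypothetical build_indptr name; the GPU code does the same with HistogramKernel plus cub::DeviceScan::ExclusiveSum)::

  #include <cassert>
  #include <vector>

  // Histogram of index values, shifted by one, then an in-place
  // exclusive prefix sum turns counts into slice offsets.
  std::vector<int> build_indptr(const std::vector<int>& indices, int n) {
    std::vector<int> indptr(n + 1, 0);
    for (int i : indices) ++indptr[i + 1];
    for (int m = 0; m < n; ++m) indptr[m + 1] += indptr[m];
    return indptr;
  }

  int main() {
    // Four lookups into an axis of length 3; slice 1 is read twice.
    std::vector<int> indptr = build_indptr({1, 0, 2, 1}, 3);
    assert(indptr == std::vector<int>({0, 1, 3, 4}));
    return 0;
  }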
22 changes: 17 additions & 5 deletions src/operator/tensor/indexing_op.cc
@@ -360,21 +360,24 @@ NNVM_REGISTER_OP(take)
This function slices the input array along a particular axis with the provided indices.
-Given an input array with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, the output
-will have shape ``(i0, i1, d1, d2)``, computed by::
+Given an input tensor with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, and axis=1,
+the output will have shape ``(d0, i0, i1, d2)``, computed by::

-  output[i,j,:,:] = input[indices[i,j],:,:]
+  output[:,i,j,:] = input[:,indices[i,j],:]
.. note::
-  - `axis`- Only slicing along axis 0 is supported for now.
-  - `mode`- Only `clip` mode is supported for now.
+  - `axis`- Could be from -r to r-1 where r is the rank of input tensor
+  - `mode`- Could be either `clip` or `wrap`.
Examples::
x = [4. 5. 6.]
// Trivial case, take the second element along the first axis.
take(x, [1]) = [ 5. ]
+  // The other trivial case, axis=-1, take the third element along the first axis
+  take(x, [3], axis=-1, mode='clip') = [ 6. ]
x = [[ 1., 2.],
[ 3., 4.],
[ 5., 6.]]
@@ -386,6 +389,14 @@ Examples::
[[ 3., 4.],
[ 5., 6.]]]
+  // In this case we will get rows 0 and 1, then 1 and 2 (calculated by wrapping around).
+  // Along axis 1
+  take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1., 2.],
+                                                       [ 2., 1.]],
+                                                      [[ 3., 4.],
+                                                       [ 4., 3.]],
+                                                      [[ 5., 6.],
+                                                       [ 6., 5.]]]
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
@@ -413,6 +424,7 @@
NNVM_REGISTER_OP(_backward_take)
.set_num_inputs(2)
.set_num_outputs(2)
+.set_attr_parser(ParamParser<TakeParam>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
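The two modes documented above reduce to two ways of normalizing an index into [0, n) for an axis of length n: `wrap` takes it modulo n (so 3 and -1 both land on valid positions of a length-3 axis), while `clip` saturates at the ends. A small sketch of both normalizations in plain C++, separate from the actual operator code::

  #include <cassert>

  int wrap_index(int i, int n) {
    int m = i % n;               // C++ '%' may be negative for negative i
    return (m < 0) ? m + n : m;
  }

  int clip_index(int i, int n) {
    if (i < 0) return 0;
    if (i >= n) return n - 1;
    return i;
  }

  int main() {
    // Matches the docs: on x = [4. 5. 6.] (axis length 3),
    // take(x, [3], mode='clip') picks index 2 and returns 6.
    assert(clip_index(3, 3) == 2);
    assert(wrap_index(3, 3) == 0);   // wraps to the first element
    assert(wrap_index(-1, 3) == 2);  // negative counts from the end
    return 0;
  }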
