This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

general take backward on gpu
Hao Jin committed Jun 18, 2018
1 parent 52cb1d1 commit 39a5c67
Showing 3 changed files with 195 additions and 27 deletions.
10 changes: 10 additions & 0 deletions src/operator/mshadow_op.h
@@ -601,6 +601,16 @@ struct clip : public mxnet_op::tunable {
return x;
}
}
template<typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) {
  if (x > upper_bound) {
    return upper_bound;
  } else if (x < lower_bound) {
    return lower_bound;
  }
  return x;
}
};

/***** gamma ******/
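As a side note, here is a tiny standalone C++ snippet (not MXNet code; `clip_map` is a hypothetical free-function copy of the overload above) showing the two-sided clamp it performs. The GPU take backward below relies on it to force raw indices into the valid range of the taken axis:

```cpp
#include <cstdio>

// Same logic as the two-bound clip::Map above, written as a free function.
template <typename DType>
DType clip_map(DType x, DType lower_bound, DType upper_bound) {
  if (x > upper_bound) return upper_bound;
  if (x < lower_bound) return lower_bound;
  return x;
}

int main() {
  // Clamp raw take indices into the valid range [0, dim - 1] for dim = 4.
  std::printf("%d %d %d\n", clip_map(-2, 0, 3), clip_map(1, 0, 3), clip_map(7, 0, 3));
  // Prints: 0 1 3
  return 0;
}
```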
103 changes: 102 additions & 1 deletion src/operator/tensor/indexing_op-inl.cuh
@@ -293,7 +293,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
(NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream);
size_t exclusivesum_bytes = 0;
cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
(NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream);
size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes);

// Check that we have enough storage
@@ -320,6 +320,107 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
num_runs_ptr, dst.size(0));
}

struct TakeGradGeneralKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int tid, DType* arr_grad, const DType* ograd,
const IType* src_indptr, const IType* original_idx,
mshadow::Shape<10> in_strides, mshadow::Shape<10> out_strides,
const int in_ndims, const int out_ndims, const int idx_ndims, const int axis) {
const int in_head_index = (axis == 0) ? 0 : tid / in_strides[axis - 1];
const int in_rest_index = (axis == 0) ? tid : tid % in_strides[axis - 1];
const int in_mid_index = in_rest_index / in_strides[axis];
const int in_tail_index = (axis == in_ndims - 1) ?
0 : (in_rest_index % in_strides[axis]);
for (int i = src_indptr[in_mid_index]; i < src_indptr[in_mid_index + 1]; ++i) {
const int out_mid_index = original_idx[i];
int target = in_tail_index + out_mid_index * out_strides[axis + idx_ndims - 1];
target += (axis == 0) ? 0 : in_head_index * out_strides[axis - 1];
arr_grad[tid] += ograd[target];
}
}
};

template<bool clip = true>
inline void TakeOpBackwardImpl(mshadow::Stream<gpu>* s,
const OpContext& ctx,
const TBlob& arr,
const TBlob& idx,
const TBlob& ograd,
const int axis) {
using namespace mxnet_op;
using namespace mshadow;
CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation";
const TShape& arrshape = arr.shape_;
const TShape& idxshape = idx.shape_;
const TShape& oshape = ograd.shape_;
MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {  // index dtype
// get size of temporary storage for sort and exclusive scan
char* temp_storage_ptr = nullptr;
size_t scan_temp_storage_bytes = 0;
IType* src_indptr_ptr = nullptr;
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
scan_temp_storage_bytes,
src_indptr_ptr,
src_indptr_ptr,
arrshape[axis] + 1,
mshadow::Stream<gpu>::GetStream(s));
size_t sort_temp_storage_bytes = SortByKeyWorkspaceSize<IType, IType, gpu>(idxshape.Size());
size_t temp_storage_bytes = std::max(scan_temp_storage_bytes, sort_temp_storage_bytes);
size_t original_idx_bytes = idxshape.Size() * sizeof(IType);
size_t src_indptr_bytes = (arrshape[axis] + 1) * sizeof(IType);
size_t workspace_bytes = src_indptr_bytes + 2 * original_idx_bytes + temp_storage_bytes;
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_bytes), s);
IType* sorted_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_);
IType* original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes);
src_indptr_ptr = reinterpret_cast<IType*>(workspace.dptr_ + 2 * original_idx_bytes);
temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes;
// Reset indptr to zero
Kernel<set_zero, gpu>::Launch(s, arrshape[axis] + 1, src_indptr_ptr);
// Fill original_idx
Kernel<range_fwd, gpu>::Launch(
s, idxshape.Size(), 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
// Fill sorted_idx_ptr with unsorted copy of idx
Kernel<op_with_req<mshadow_op::identity, kWriteTo>, gpu>::Launch(
s, idxshape.Size(), sorted_idx_ptr, idx.dptr<IType>());
if (clip) {
Kernel<op_with_req<mshadow_op::clip, kWriteTo>, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr,
sorted_idx_ptr, IType(0), IType(arrshape[axis] - 1));
} else {
Kernel<op_with_req<mshadow_op::mod, kWriteTo>, gpu>::Launch(s, idxshape.Size(), sorted_idx_ptr,
sorted_idx_ptr, IType(arrshape[axis]));
}
}
Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(idxshape.Size()), s);
Tensor<gpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);
int num_bits = ilog2(static_cast<unsigned int>(idxshape.Size()) - 1);
Tensor<gpu, 1, IType> sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s);
SortByKey(sorted_idx, original_idx, true, &temp_storage, 0, num_bits);
Kernel<HistogramKernel, gpu>::Launch(
s, idxshape.Size(), src_indptr_ptr, idx.dptr<IType>(), idxshape.Size());
cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
temp_storage_bytes,
src_indptr_ptr,
src_indptr_ptr,
arrshape[axis] + 1,
mshadow::Stream<gpu>::GetStream(s));

Shape<10> in_strides;
int stride = 1;
for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
in_strides[i] = stride;
}
Shape<10> out_strides;
stride = 1;
for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
out_strides[i] = stride;
}
MSHADOW_TYPE_SWITCH(arr.type_flag_, DType, {
Kernel<TakeGradGeneralKernel, gpu>::Launch(
s, arrshape.Size(), arr.dptr<DType>(), ograd.dptr<DType>(), src_indptr_ptr, original_idx_ptr,
in_strides, out_strides, arrshape.ndim(), oshape.ndim(), idxshape.ndim(), axis);
});
});  // MSHADOW_TYPE_SWITCH over idx.type_flag_
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_
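To make the `src_indptr` / `original_idx` bookkeeping above concrete, here is a minimal single-threaded C++ sketch of the same idea (standalone toy code, not MXNet; the 2-D array, `axis = 1`, and names such as `cursor` are illustrative, and the trailing dimensions after the axis are dropped for brevity). It builds the CSR-style `src_indptr` with a histogram plus an exclusive scan, groups the original positions of `idx` by index value, and then accumulates gradients element by element the way `TakeGradGeneralKernel` does:

```cpp
#include <cstdio>
#include <vector>

int main() {
  // arr: shape (2, 3), idx: shape (2,), axis = 1  =>  out/ograd: shape (2, 2)
  const int arr_dims[2] = {2, 3};
  const std::vector<int> idx = {2, 0};
  const std::vector<float> ograd = {0.1f, 0.2f,   // ograd[0][0], ograd[0][1]
                                    0.3f, 0.4f};  // ograd[1][0], ograd[1][1]
  const int axis_dim = arr_dims[1];               // length of the taken axis

  // 1) Histogram of index values, then exclusive scan -> src_indptr (CSR-style).
  std::vector<int> src_indptr(axis_dim + 1, 0);
  for (int v : idx) ++src_indptr[v + 1];
  for (int j = 0; j < axis_dim; ++j) src_indptr[j + 1] += src_indptr[j];

  // 2) original_idx[p] lists the positions in idx, grouped by index value
  //    (the GPU code gets this ordering from SortByKey on a clipped copy of idx).
  std::vector<int> original_idx(idx.size());
  std::vector<int> cursor = src_indptr;
  for (int p = 0; p < static_cast<int>(idx.size()); ++p)
    original_idx[cursor[idx[p]]++] = p;

  // 3) For every element of arr_grad, accumulate the matching ograd entries:
  //    d arr[i][j] = sum over positions p with idx[p] == j of ograd[i][p].
  std::vector<float> arr_grad(arr_dims[0] * arr_dims[1], 0.0f);
  const int in_stride = arr_dims[1];                     // stride of dim 0 in arr
  const int out_stride = static_cast<int>(idx.size());   // stride of dim 0 in ograd
  for (int tid = 0; tid < static_cast<int>(arr_grad.size()); ++tid) {
    const int head = tid / in_stride;  // index over dims before the axis
    const int mid = tid % in_stride;   // index along the axis
    for (int p = src_indptr[mid]; p < src_indptr[mid + 1]; ++p)
      arr_grad[tid] += ograd[head * out_stride + original_idx[p]];
  }

  for (int i = 0; i < arr_dims[0]; ++i)
    std::printf("arr_grad[%d] = {%.1f, %.1f, %.1f}\n",
                i, arr_grad[3 * i], arr_grad[3 * i + 1], arr_grad[3 * i + 2]);
  // Expected: row 0 -> {0.2, 0.0, 0.1}, row 1 -> {0.4, 0.0, 0.3}
  return 0;
}
```

Grouping positions by index value up front is what lets each `arr_grad` element be owned by a single GPU thread that only reads from `ograd`, instead of scattering into it with atomic adds.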
109 changes: 83 additions & 26 deletions src/operator/tensor/indexing_op.h
@@ -876,6 +876,45 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
});
}

template<bool clip = true>
inline void TakeOpBackwardImpl(mshadow::Stream<cpu>* s,
const OpContext& ctx,
const TBlob& arr,
const TBlob& idx,
const TBlob& ograd,
const int axis) {
return;
// CHECK(axis != 0) << "axis == 0 case should be dispatched to the legacy implementation";
// const TShape& arrshape = arr.shape_;
// const TShape& idxshape = idx.shape_;
// const TShape& oshape = ograd.shape_;
// // get size of temporary storage for sort
// size_t temp_storage_bytes = SortByKeyWorkspaceSize<IType, IType, cpu>(idxshape.Size());
// size_t original_idx_bytes = idxshape.Size() * sizeof(IType);
// size_t src_indices_bytes = arrshape[actual_axis] * sizeof(IType);
// size_t workspace_bytes = src_indices_bytes + 2 * original_idx_bytes + temp_storage_bytes;
// Tensor<xpu, 1, char> workspace =
// ctx.requested[0].get_space_typed<cpu, 1, char>(Shape1(workspace_bytes), s);
// IType* sorted_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_);
// IType* original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_ + original_idx_bytes);
// IType* src_indptr_ptr = reinterpret_cast<IType*>(workspace.dptr_ + 2 * original_idx_bytes);
// char* temp_storage_ptr = workspace.dptr_ + 2 * original_idx_bytes + src_indices_bytes;
// // Reset indptr to zero
// mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, arrshape[actual_axis], src_indptr_ptr);
// // Fill original_idx
// mxnet_op::Kernel<range_fwd, xpu>::Launch(
// s, idxshape.Size(), 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
// // Fill sorted_idx_ptr with unsorted copy of idx
// mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kWriteTo>, xpu>::Launch(
// s, idxshape.Size(), sorted_idx_ptr, idx.dptr<IType>());
// Tensor<xpu, 1, IType> original_idx(original_idx_ptr, Shape1(idxshape.Size()), s);
// Tensor<xpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);
// int num_bits = ilog2(static_cast<unsigned int>(idxshape.Size()) - 1);
// SortByKey(idx.dptr<IType>, original_idx, true, &temp_storage, 0, num_bits);
// Tensor<xpu, 1, IType> sorted_idx(sorted_idx_ptr, Shape1(idxshape.Size()), s);
// Tensor<xpu, 1, IType> src_indptr(src_indptr_ptr, Shape1(arrshape[actual_axis]), s);
}

template<typename xpu>
void TakeOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -889,48 +889,66 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
CHECK_EQ(req[take_::kIdx], kNullOp)
<< "take layer doesn't support gradient into index";

// inputs are specified in the .cc file, which are the gradients from
// the upper layer and the input index
// outputs are the gradients of inputs in the feed-forward pass
const TShape& idxshape = inputs[1].shape_;
const TShape& arrshape = outputs[0].shape_;
const TShape& oshape = inputs[0].shape_;

int idxndim = idxshape.ndim();
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);

// grad_out is the gradient of the outputs in the feed-forward
// grad_in is the gradient of the inputs in the feed-forward
Stream<xpu> *s = ctx.get_stream<xpu>();

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { // output data type
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type
// inputs are specified in the .cc file, which are the gradients from
// the upper layer and the input index
// outputs are the gradients of inputs in the feed-forward pass
const TShape& idxshape = inputs[1].shape_;
const TShape& arrshape = outputs[0].shape_;
const TShape& oshape = inputs[0].shape_;

const int actual_axis = param.axis + ((param.axis < 0) ? oshape.ndim() : 0);

int idxndim = idxshape.ndim();
Tensor<xpu, 1, IType> idx = inputs[1].get_with_shape<xpu, 1, IType>(
Shape1(idxshape.ProdShape(0, idxndim)), s);
Tensor<xpu, 2, DType> grad_out = inputs[0].get_with_shape<xpu, 2, DType>(
Shape2(oshape.ProdShape(0, idxndim), oshape.ProdShape(idxndim, oshape.ndim())), s);
Tensor<xpu, 2, DType> grad_in = outputs[0].get_with_shape<xpu, 2, DType>(
Shape2(arrshape[0], arrshape.ProdShape(1, arrshape.ndim())), s);

if (req[take_::kArr] == kWriteTo || req[take_::kArr] == kAddTo) {
if (req[take_::kArr] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
// shape_out_prod ~= the number of elements loaded in AddTakeGrad
// shape_in_prod ~= the number of elements stored in AddTakeGrad
// When the number of elements processed is low, use AddTakeGrad.
// The approximate cut-off value 16384 was found experimentally on Titan X Pascal
uint64_t shape_in_prod =
static_cast<uint64_t>(grad_in.shape_[0])*
static_cast<uint64_t>(grad_in.shape_[1]);
uint64_t shape_out_prod =
static_cast<uint64_t>(grad_out.shape_[0])*
static_cast<uint64_t>(grad_out.shape_[1]);
if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) {
AddTakeGrad(grad_in, idx, grad_out);
// re-using the previous code for axis = 0 case
if (actual_axis == 0) {
if (req[take_::kArr] == kWriteTo || req[take_::kArr] == kAddTo) {
if (req[take_::kArr] == kWriteTo) {
grad_in = scalar<DType>(0.0f);
}
// shape_out_prod ~= the number of elements loaded in AddTakeGrad
// shape_in_prod ~= the number of elements stored in AddTakeGrad
// When the number of elements processed is low, use AddTakeGrad.
// The approximate cut-off value 16384 was found experimentally on Titan X Pascal
uint64_t shape_in_prod =
static_cast<uint64_t>(grad_in.shape_[0])*
static_cast<uint64_t>(grad_in.shape_[1]);
uint64_t shape_out_prod =
static_cast<uint64_t>(grad_out.shape_[0])*
static_cast<uint64_t>(grad_out.shape_[1]);
if (shape_out_prod < (uint64_t)16384 && shape_in_prod < (uint64_t)16384) {
AddTakeGrad(grad_in, idx, grad_out);
} else {
AddTakeGradLargeBatchCaller(ctx, grad_in, idx, grad_out);
}
} else {
AddTakeGradLargeBatchCaller(ctx, grad_in, idx, grad_out);
LOG(FATAL) << "wrong req";
}
// for all other cases
} else {
LOG(FATAL) << "wrong req";
const TBlob& idx = inputs[1];
const TBlob& arr = outputs[0];
const TBlob& ograd = inputs[0];

if (param.mode == take_::kClip) {
TakeOpBackwardImpl<true>(s, ctx, arr, idx, ograd, actual_axis);
} else {
TakeOpBackwardImpl<false>(s, ctx, arr, idx, ograd, actual_axis);
}
}
});
});
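For reference, both branches above compute the same quantity. Written out for an index array applied along `axis` (notation ours, not from the commit), with `i` ranging over the dimensions before the axis, `k` over the dimensions after it, and `m` over the flattened index array:

```latex
\text{out}[i, m, k] = \text{arr}[i, \text{idx}[m], k]
\quad\Longrightarrow\quad
\frac{\partial L}{\partial\,\text{arr}[i, j, k]}
  = \sum_{m \,:\, \text{idx}[m] = j} \frac{\partial L}{\partial\,\text{out}[i, m, k]}.
```

The `actual_axis == 0` branch evaluates this sum through `AddTakeGrad` / `AddTakeGradLargeBatchCaller`; the general branch hands it to `TakeOpBackwardImpl`, which precomputes, for every `j`, the set of positions `m` with `idx[m] == j`.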
