optimize prelu alpha grad #7600

Merged: 9 commits, merged on Feb 26, 2022
Changes from 8 commits
8 changes: 7 additions & 1 deletion oneflow/core/functional/impl/activation_functor.cpp
@@ -91,7 +91,13 @@ class PReluGradFunctor {
}
Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& dy, const std::shared_ptr<Tensor>& x,
const std::shared_ptr<Tensor>& alpha) const {
return OpInterpUtil::Dispatch<one::TensorTuple>(*op_, {dy, x, alpha});
MutableAttrMap attrs;
if (alpha->requires_grad()) {
JUST(attrs.SetAttr<bool>("alpha_requires_grad", true));
} else {
JUST(attrs.SetAttr<bool>("alpha_requires_grad", false));
}
return OpInterpUtil::Dispatch<one::TensorTuple>(*op_, {dy, x, alpha}, attrs);
}

private:
3 changes: 3 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -4633,6 +4633,9 @@ def OneFlow_PreluGradOp : OneFlow_BaseOp<"prelu_grad", [NoSideEffect, DeclareOpI
OneFlow_Tensor:$dx,
OneFlow_Tensor:$alpha_diff
);
let attrs = (ins
DefaultValuedAttr<BoolAttr, "false">:$alpha_requires_grad
Contributor:
Does this mean alpha_requires_grad defaults to false? I don't think it should have a default value at all.

Contributor (author):
Hmm, defaulting to false feels more intuitive to me?

Contributor:
There doesn't seem to be a strong reason to default to false; wouldn't it be better to require callers to pass it explicitly?

Collaborator:
> There doesn't seem to be a strong reason to default to false; wouldn't it be better to require callers to pass it explicitly?

Agreed. And if there is a default at all, it should be true: if the flag ought to be false but is left unset, a default of true only costs performance, whereas the opposite mistake would silently break correctness.

Contributor (author):
> There doesn't seem to be a strong reason to default to false; wouldn't it be better to require callers to pass it explicitly?

OK, I'll change the default to true. (In practice the attribute is passed explicitly anyway: whenever alpha's requires_grad is True or False, the caller sets it explicitly.)

);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
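To make the trade-off from the review thread above concrete, here is a minimal sketch of why true is the safer fallback for alpha_requires_grad. The names are illustrative only; this is not OneFlow code and not part of this diff.

// Hypothetical illustration of the default-value trade-off discussed above.
#include <cstdio>

struct GradAttrs {
  // Fallback used only when a producer forgets to set the attribute explicitly.
  bool alpha_requires_grad = true;  // true: forgetting it merely wastes work
};

void PReluBackward(const GradAttrs& attrs, bool alpha_actually_needs_grad) {
  if (attrs.alpha_requires_grad) {
    std::printf("compute dx and alpha_diff\n");  // always correct, possibly redundant
  } else if (alpha_actually_needs_grad) {
    // Reachable only with a 'false' default plus a missing explicit attribute:
    // alpha's gradient would silently stay zero.
    std::printf("BUG: alpha_diff skipped, gradient silently wrong\n");
  } else {
    std::printf("compute dx only\n");
  }
}

int main() {
  PReluBackward(GradAttrs{}, /*alpha_actually_needs_grad=*/true);  // safe with the true default
  return 0;
}

In this PR the attribute is in any case set explicitly at both call sites (the PReluGradFunctor above and the prelu_grad registration in prelu_op.cpp below), so the default only matters for op configurations written by hand.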
207 changes: 139 additions & 68 deletions oneflow/user/kernels/prelu_kernel.cu
@@ -47,7 +47,7 @@ struct PreluForwardSingleAlphaPtrFunctor {
const T* alpha_ptr;
};

template<typename T, typename IndexType, int pack_size, bool tail>
template<typename T, typename IndexType, int pack_size, bool tail, bool alpha_requires_grad>
__global__ void PReluBackwardSingleAlphaGpu(const IndexType elem_cnt, const int64_t n_tail,
const T* x, const T* alpha, const T* dy, T* dx,
T* alpha_diff, const T* tail_x, const T* tail_dy,
@@ -70,32 +70,44 @@ __global__ void PReluBackwardSingleAlphaGpu(const IndexType elem_cnt, const int6
dy_vec.storage = *dy_load;

LoadPack dx_vec;
LoadPack dalpha_vec;

T zero_val = static_cast<T>(0.0);
T alpha_diff_i = 0;
if (alpha_requires_grad) {
LoadPack dalpha_vec;
#pragma unroll
for (int i = 0; i < pack_size; i++) {
if (x_vec.elem[i] > zero_val) {
dx_vec.elem[i] = dy_vec.elem[i];
dalpha_vec.elem[i] = zero_val;
} else {
dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;
dalpha_vec.elem[i] = dy_vec.elem[i] * x_vec.elem[i];
for (int i = 0; i < pack_size; i++) {
if (x_vec.elem[i] > zero_val) {
dx_vec.elem[i] = dy_vec.elem[i];
dalpha_vec.elem[i] = zero_val;
} else {
dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;
dalpha_vec.elem[i] = dy_vec.elem[i] * x_vec.elem[i];
}
}
*(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;
*(reinterpret_cast<LoadType*>(alpha_diff + linear_index)) = dalpha_vec.storage;
} else {
#pragma unroll
for (int i = 0; i < pack_size; i++) {
if (x_vec.elem[i] > zero_val) {
dx_vec.elem[i] = dy_vec.elem[i];
} else {
dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;
}
}
*(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;
}
*(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;
*(reinterpret_cast<LoadType*>(alpha_diff + linear_index)) = dalpha_vec.storage;
}

if (tail && global_thread_id < n_tail) {
const T tail_dy_val = tail_dy[global_thread_id];
if (tail_x[global_thread_id] > zero_val) {
tail_dx[global_thread_id] = tail_dy_val;
tail_alpha_diff[global_thread_id] = zero_val;
if (alpha_requires_grad) { tail_alpha_diff[global_thread_id] = zero_val; }
} else {
tail_dx[global_thread_id] = alpha_val * tail_dy_val;
tail_alpha_diff[global_thread_id] = tail_x[global_thread_id] * tail_dy_val;
if (alpha_requires_grad) {
tail_alpha_diff[global_thread_id] = tail_x[global_thread_id] * tail_dy_val;
}
}
}
}
@@ -141,7 +153,7 @@ __global__ void PReluForwardMultiAlphaGpu(const IndexType elem_cnt, const IndexT
}
}

template<typename T>
template<typename T, bool alpha_requires_grad>
__global__ void BroadcastPReluMultiAlphaNaiveBackwardGpu(const int32_t elem_cnt,
const int32_t alpha_size,
const int32_t inner_size, const T* x,
Expand All @@ -154,15 +166,15 @@ __global__ void BroadcastPReluMultiAlphaNaiveBackwardGpu(const int32_t elem_cnt,
int32_t alpha_i = (i / inner_size) % alpha_size;
if (x_i > zero_val) {
dx[i] = dy_i;
alpha_diff[i] = zero_val;
if (alpha_requires_grad) { alpha_diff[i] = zero_val; }
} else {
dx[i] = dy_i * alpha[alpha_i];
alpha_diff[i] = dy_i * x_i;
if (alpha_requires_grad) { alpha_diff[i] = dy_i * x_i; }
}
}
}

template<typename T, typename IndexType, int pack_size>
template<typename T, typename IndexType, int pack_size, bool alpha_requires_grad>
__global__ void PReluBackwardMultiAlphaGpu(const IndexType elem_cnt, const IndexType alpha_size,
const IndexType inner_size, const T* x, const T* alpha,
const T* dy, T* dx, T* alpha_diff) {
@@ -184,23 +196,33 @@ __global__ void PReluBackwardMultiAlphaGpu(const IndexType elem_cnt, const Index
dy_vec.storage = *dy_load;

LoadPack dx_vec;
LoadPack dalpha_vec;

T zero_val = static_cast<T>(0.0);
T alpha_val = alpha[alpha_idx];
T alpha_diff_i = 0;
if (alpha_requires_grad) {
LoadPack dalpha_vec;
T zero_val = static_cast<T>(0.0);
#pragma unroll
for (int i = 0; i < pack_size; i++) {
if (x_vec.elem[i] > zero_val) {
dx_vec.elem[i] = dy_vec.elem[i];
dalpha_vec.elem[i] = zero_val;
} else {
dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;
dalpha_vec.elem[i] = dy_vec.elem[i] * x_vec.elem[i];
for (int i = 0; i < pack_size; i++) {
if (x_vec.elem[i] > zero_val) {
dx_vec.elem[i] = dy_vec.elem[i];
dalpha_vec.elem[i] = zero_val;
} else {
dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;
dalpha_vec.elem[i] = dy_vec.elem[i] * x_vec.elem[i];
}
}
*(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;
*(reinterpret_cast<LoadType*>(alpha_diff + linear_index)) = dalpha_vec.storage;
} else {
#pragma unroll
for (int i = 0; i < pack_size; i++) {
if (x_vec.elem[i] > zero_val) {
dx_vec.elem[i] = dy_vec.elem[i];
} else {
dx_vec.elem[i] = dy_vec.elem[i] * alpha_val;
}
}
*(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;
}
*(reinterpret_cast<LoadType*>(dx + linear_index)) = dx_vec.storage;
*(reinterpret_cast<LoadType*>(alpha_diff + linear_index)) = dalpha_vec.storage;
}
}

@@ -250,47 +272,77 @@ void DispatchPreluForwardIndex(ep::Stream* stream, const int64_t elem_cnt, const
template<typename T, typename IndexType, int32_t pack_size>
void DispatchPreluBackwardPackSize(ep::Stream* stream, const int64_t elem_cnt,
const int64_t alpha_size, const int64_t inner_size, const T* x,
const T* alpha, const T* dy, T* dx, T* alpha_diff) {
const T* alpha, const T* dy, T* dx, T* alpha_diff,
const bool alpha_requires_grad) {
const int64_t pack_num = elem_cnt / pack_size;
int grid_size;
cudaError_t err = cuda::elementwise::GetNumBlocks(pack_num, &grid_size);

if (pack_size >= 8 && inner_size % 8 == 0) {
PReluBackwardMultiAlphaGpu<T, IndexType, 8>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
if (alpha_requires_grad) {
PReluBackwardMultiAlphaGpu<T, IndexType, 8, true>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
} else {
PReluBackwardMultiAlphaGpu<T, IndexType, 8, false>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
}
} else if (pack_size >= 4 && inner_size % 4 == 0) {
PReluBackwardMultiAlphaGpu<T, IndexType, 4>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
if (alpha_requires_grad) {
PReluBackwardMultiAlphaGpu<T, IndexType, 4, true>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
} else {
PReluBackwardMultiAlphaGpu<T, IndexType, 4, false>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
}
} else if (pack_size >= 2 && inner_size % 2 == 0) {
PReluBackwardMultiAlphaGpu<T, IndexType, 2>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
if (alpha_requires_grad) {
PReluBackwardMultiAlphaGpu<T, IndexType, 2, true>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
} else {
PReluBackwardMultiAlphaGpu<T, IndexType, 2, false>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
}

} else {
BroadcastPReluMultiAlphaNaiveBackwardGpu<T>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
if (alpha_requires_grad) {
BroadcastPReluMultiAlphaNaiveBackwardGpu<T, true>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
} else {
BroadcastPReluMultiAlphaNaiveBackwardGpu<T, false>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, alpha_size, inner_size, x, alpha, dy, dx, alpha_diff);
}
}
}

template<typename T>
void DispatchPreluBackwardIndex(ep::Stream* stream, const int64_t elem_cnt,
const int64_t alpha_size, const int64_t inner_size, const T* x,
const T* alpha, const T* dy, T* dx, T* alpha_diff) {
const T* alpha, const T* dy, T* dx, T* alpha_diff,
const bool alpha_requires_grad) {
constexpr int pack_size = cuda::elementwise::PackSize<T>();
if (elem_cnt < GetMaxVal<int32_t>()) {
DispatchPreluBackwardPackSize<T, int32_t, pack_size>(stream, elem_cnt, alpha_size, inner_size,
x, alpha, dy, dx, alpha_diff);
x, alpha, dy, dx, alpha_diff,
alpha_requires_grad);
} else {
DispatchPreluBackwardPackSize<T, int64_t, pack_size>(stream, elem_cnt, alpha_size, inner_size,
x, alpha, dy, dx, alpha_diff);
x, alpha, dy, dx, alpha_diff,
alpha_requires_grad);
}
}

template<typename T, typename IndexType>
void DispatchPreluBackwardSingleAlphaTail(ep::Stream* stream, const IndexType elem_cnt, const T* x,
const T* alpha, const T* dy, T* dx, T* alpha_diff) {
const T* alpha, const T* dy, T* dx, T* alpha_diff,
const bool alpha_requires_grad) {
constexpr int pack_size = cuda::elementwise::PackSize<T>();
const int64_t pack_num = elem_cnt / pack_size;
int grid_size;
@@ -299,29 +351,45 @@ void DispatchPreluBackwardSingleAlphaTail(ep::Stream* stream, const IndexType el
const int64_t n_tail = elem_cnt - tail_offset;
const bool tail = n_tail > 0 ? true : false;
if (tail) {
PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, true>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,
dx + tail_offset, alpha_diff + tail_offset);
if (alpha_requires_grad) {
PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, true, true>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,
dx + tail_offset, alpha_diff + tail_offset);
} else {
PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, true, false>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,
dx + tail_offset, alpha_diff + tail_offset);
}
} else {
PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, false>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,
dx + tail_offset, alpha_diff + tail_offset);
if (alpha_requires_grad) {
PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, false, true>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,
dx + tail_offset, alpha_diff + tail_offset);
} else {
PReluBackwardSingleAlphaGpu<T, IndexType, pack_size, false, false>
<<<grid_size, kBlockSize, 0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, n_tail, x, alpha, dy, dx, alpha_diff, x + tail_offset, dy + tail_offset,
dx + tail_offset, alpha_diff + tail_offset);
}
}
}

template<typename T>
void DispatchPreluBackwardSingleAlphaIndex(ep::Stream* stream, const int64_t elem_cnt, const T* x,
const T* alpha, const T* dy, T* dx, T* alpha_diff) {
const T* alpha, const T* dy, T* dx, T* alpha_diff,
const bool alpha_requires_grad) {
if (elem_cnt < GetMaxVal<int32_t>()) {
DispatchPreluBackwardSingleAlphaTail<T, int32_t>(stream, elem_cnt, x, alpha, dy, dx,
alpha_diff);
DispatchPreluBackwardSingleAlphaTail<T, int32_t>(stream, elem_cnt, x, alpha, dy, dx, alpha_diff,
alpha_requires_grad);
} else {
DispatchPreluBackwardSingleAlphaTail<T, int64_t>(stream, elem_cnt, x, alpha, dy, dx,
alpha_diff);
DispatchPreluBackwardSingleAlphaTail<T, int64_t>(stream, elem_cnt, x, alpha, dy, dx, alpha_diff,
alpha_requires_grad);
}
}

} // namespace

template<typename T>
@@ -380,6 +448,7 @@ class GpuPReluGradKernel final : public user_op::OpKernel {
user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
user_op::Tensor* alpha_diff = ctx->Tensor4ArgNameAndIndex("alpha_diff", 0);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const bool alpha_requires_grad = ctx->Attr<bool>("alpha_requires_grad");
const int32_t elem_cnt = x->shape().elem_cnt();
T* broadcasted_alpha_diff = tmp_buffer->mut_dptr<T>();
T* reduce_sum_tmp_buf = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>()
@@ -394,16 +463,18 @@
if (alpha_size == 1) {
DispatchPreluBackwardSingleAlphaIndex<T>(ctx->stream(), elem_cnt, x->dptr<T>(),
alpha->dptr<T>(), dy->dptr<T>(), dx->mut_dptr<T>(),
broadcasted_alpha_diff);
broadcasted_alpha_diff, alpha_requires_grad);
} else {
DispatchPreluBackwardIndex<T>(ctx->stream(), elem_cnt, alpha_size, inner_size, x->dptr<T>(),
alpha->dptr<T>(), dy->dptr<T>(), dx->mut_dptr<T>(),
broadcasted_alpha_diff);
broadcasted_alpha_diff, alpha_requires_grad);
}
if (alpha_requires_grad) {
NdarrayUtil<DeviceType::kCUDA, T>::ReduceSum(
ctx->stream(), XpuVarNdarray<T>(left_extended_shape, alpha_diff->mut_dptr<T>()),
XpuVarNdarray<const T>(x->shape(), broadcasted_alpha_diff),
XpuVarNdarray<T>(x->shape(), reduce_sum_tmp_buf));
}
NdarrayUtil<DeviceType::kCUDA, T>::ReduceSum(
ctx->stream(), XpuVarNdarray<T>(left_extended_shape, alpha_diff->mut_dptr<T>()),
XpuVarNdarray<const T>(x->shape(), broadcasted_alpha_diff),
XpuVarNdarray<T>(x->shape(), reduce_sum_tmp_buf));
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
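The repeated if/else ladders in the dispatch functions above all implement one idea: branch on the runtime alpha_requires_grad flag once at launch time and bake it into the kernel as a compile-time template parameter, so the per-element loop never tests it and the alpha_diff stores drop out entirely when alpha is frozen. Below is a minimal standalone sketch of that pattern; the names are illustrative and this is not OneFlow's API.

// Standalone sketch: dispatching a runtime bool to a compile-time template
// parameter, as the prelu backward kernels above do with alpha_requires_grad.
#include <cstdio>
#include <cuda_runtime.h>

template<bool WriteAux>
__global__ void ScaleKernel(const float* in, float* out, float* aux, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = 2.0f * in[i];
    // Resolved at compile time: when WriteAux is false, this store and the
    // associated memory traffic disappear from the generated kernel.
    if (WriteAux) { aux[i] = in[i]; }
  }
}

void LaunchScale(const float* in, float* out, float* aux, int n, bool write_aux) {
  const int block = 256;
  const int grid = (n + block - 1) / block;
  if (write_aux) {
    ScaleKernel<true><<<grid, block>>>(in, out, aux, n);
  } else {
    ScaleKernel<false><<<grid, block>>>(in, out, aux, n);
  }
}

int main() {
  const int n = 1024;
  float *in = nullptr, *out = nullptr, *aux = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&in), n * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&out), n * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&aux), n * sizeof(float));
  LaunchScale(in, out, aux, n, /*write_aux=*/false);  // aux stores compiled out
  cudaDeviceSynchronize();
  std::printf("done\n");
  cudaFree(in);
  cudaFree(out);
  cudaFree(aux);
  return 0;
}

The price is one extra kernel instantiation per flag value, which is why each pack-size case in DispatchPreluBackwardPackSize and each tail case in DispatchPreluBackwardSingleAlphaTail now appears twice.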
16 changes: 9 additions & 7 deletions oneflow/user/ops/prelu_op.cpp
@@ -114,13 +114,15 @@ REGISTER_USER_OP_GRAD("prelu").SetGenBackwardOpConfFn([](const user_op::UserOpWr
user_op::AddOpFn AddOp) -> Maybe<void> {
if (op.NeedGenGradTensor4OpInput("x", 0) || op.NeedGenGradTensor4OpInput("alpha", 0)) {
user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
user_op::UserOpConfWrapper grad_op = builder.Op("prelu_grad")
.Input("x", op.input("x", 0))
.Input("dy", op.GetGradTensorWithOpOutput("y", 0))
.Input("alpha", op.input("alpha", 0))
.Output("dx")
.Output("alpha_diff")
.Build();
user_op::UserOpConfWrapper grad_op =
builder.Op("prelu_grad")
.Input("x", op.input("x", 0))
.Input("dy", op.GetGradTensorWithOpOutput("y", 0))
.Input("alpha", op.input("alpha", 0))
.Output("dx")
.Output("alpha_diff")
.Attr("alpha_requires_grad", op.NeedGenGradTensor4OpInput("alpha", 0))
.Build();
AddOp(grad_op);

if (op.NeedGenGradTensor4OpInput("x", 0)) {