Skip to content

Commit

Permalink
auto format by CI
Browse files Browse the repository at this point in the history
  • Loading branch information
oneflow-ci-bot committed Dec 31, 2021
1 parent 3500ee0 commit 32b676a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
23 changes: 12 additions & 11 deletions oneflow/user/kernels/cumsum_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ __global__ void CumsumForwardGpu(const T* in_ptr, T* out_ptr, int64_t cs_up_spac
}
template<typename T>
__global__ void CumsumForwardGpuUpSpaceIs1(const T* in_ptr, T* out_ptr, int64_t cs_space,
int64_t cs_down_space) {
int64_t cs_down_space) {
CUDA_1D_KERNEL_LOOP(i, cs_down_space) {
auto* in_ptr_base = in_ptr + i;
auto* out_ptr_base = out_ptr + i;
Expand All @@ -67,7 +67,7 @@ __global__ void CumsumForwardGpuUpSpaceIs1(const T* in_ptr, T* out_ptr, int64_t
}
template<typename T>
__global__ void CumsumForwardGpuDownSpaceIs1(const T* in_ptr, T* out_ptr, int64_t cs_up_space,
int64_t cs_space) {
int64_t cs_space) {
CUDA_1D_KERNEL_LOOP(i, cs_up_space) {
auto* in_ptr_base = in_ptr + i * cs_space;
auto* out_ptr_base = out_ptr + i * cs_space;
Expand All @@ -89,17 +89,18 @@ __global__ void CumsumForwardGpuDownSpaceIs1(const T* in_ptr, T* out_ptr, int64_
// ... ... ...
// dmn, ..., d1n, d0n
// Scales each element of the incoming gradient by (cs_space - cs_space_id),
// where cs_space_id is the element's position along the cumsum dimension.
// NOTE(review): this scaling equals the true cumsum gradient only for the
// gradient pattern this op produces (see the layout comment above) — confirm
// against the op's backward definition if reusing elsewhere.
//
// Flat-index decomposition assumed:
//   i = up_id * (cs_space * cs_down_space) + cs_space_id * cs_down_space + down_id
//
// Grid-stride loop; any 1-D launch configuration is valid.
template<typename T>
__global__ void CumsumBackwardGpu(const T* in_ptr, T* out_ptr, int64_t cs_space,
                                  int64_t cs_down_space, int64_t elem_cnt) {
  // Hoisted loop invariant: element count of one "up" slice.
  const int64_t slice_size = cs_space * cs_down_space;
  // 64-bit indexing: blockIdx.x * blockDim.x overflows 32 bits on large
  // grids, which would silently corrupt indices when elem_cnt > 2^32.
  const int64_t step = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < elem_cnt;
       i += step) {
    // Position along the cumsum dimension, recovered from the flat index.
    const int64_t cs_space_id = (i % slice_size) / cs_down_space;
    out_ptr[i] = (cs_space - cs_space_id) * in_ptr[i];
  }
}
template<typename T>
__global__ void CumsumBackwardGpu_DownSpaceIs1(const T* in_ptr, T* out_ptr, int64_t cs_up_space, int64_t cs_space,
int64_t elem_cnt) {
__global__ void CumsumBackwardGpu_DownSpaceIs1(const T* in_ptr, T* out_ptr, int64_t cs_up_space,
int64_t cs_space, int64_t elem_cnt) {
for (auto i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < elem_cnt;
i += step) {
auto cs_space_id = i - (i / cs_space) * cs_space;
Expand Down Expand Up @@ -142,8 +143,8 @@ class GpuCumsumKernel final : public user_op::OpKernel {
RUN_CUDA_KERNEL((CumsumForwardGpuDownSpaceIs1<T>), ctx->stream(), thread_num, in_ptr, out_ptr,
cs_up_space, cs_space);
} else {
RUN_CUDA_KERNEL((CumsumForwardGpu<T>), ctx->stream(), thread_num, in_ptr, out_ptr, cs_up_space,
cs_space, cs_down_space);
RUN_CUDA_KERNEL((CumsumForwardGpu<T>), ctx->stream(), thread_num, in_ptr, out_ptr,
cs_up_space, cs_space, cs_down_space);
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
Expand Down Expand Up @@ -184,10 +185,10 @@ class GpuCumsumGradKernel final : public user_op::OpKernel {
auto thread_num = elem_cnt;

if (cs_down_space == 1) {
RUN_CUDA_KERNEL((CumsumBackwardGpu_DownSpaceIs1<T>), ctx->stream(), thread_num, in_ptr, out_ptr,
cs_up_space, cs_space, elem_cnt);
RUN_CUDA_KERNEL((CumsumBackwardGpu_DownSpaceIs1<T>), ctx->stream(), thread_num, in_ptr,
out_ptr, cs_up_space, cs_space, elem_cnt);
} else {
RUN_CUDA_KERNEL((CumsumBackwardGpu<T>), ctx->stream(), thread_num, in_ptr, out_ptr, cs_space,
RUN_CUDA_KERNEL((CumsumBackwardGpu<T>), ctx->stream(), thread_num, in_ptr, out_ptr, cs_space,
cs_down_space, elem_cnt);
}
}
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/test/modules/test_cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from oneflow.test_utils.automated_test_util import *


@flow.unittest.skip_unless_1n1d()
class TestCumsum(flow.unittest.TestCase):
@autotest(n=30, check_graph=True)
Expand Down

0 comments on commit 32b676a

Please sign in to comment.