Skip to content

Commit

Permalink
[cherry-pick] fix the cumsum big-shape and randomness bug
Browse files Browse the repository at this point in the history
  • Loading branch information
wawltor committed Jun 23, 2022
1 parent 90ae353 commit 5b17c53
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions paddle/phi/kernels/gpu/cumsum_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,8 @@ __global__ void BlockScanKernel(T* d_out,
} temp_storage;

int bx = blockIdx.x;
int by = blockIdx.y;

BlockPrefixCallbackOp<T> prefix_op(0);
T block_aggregate = static_cast<T>(0);

// Obtain this block's segment of consecutive keys (blocked across threads)
int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
Expand All @@ -168,7 +166,7 @@ __global__ void BlockScanKernel(T* d_out,
valid_item = scan_size;
}

int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
int offset = block_offset + bx * scan_size;

T thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load)
Expand Down Expand Up @@ -260,8 +258,10 @@ void CumsumKernel(const Context& dev_ctx,
dim3 blocks(32, 8);
dim3 transpose_grids((width + tile_size - 1) / tile_size,
(height + tile_size - 1) / tile_size);
out->Resize(out_dims);
auto* tmp_data = out->data<T>();

DenseTensor tmp_tensor;
tmp_tensor.Resize(out_dims);
auto* tmp_data = dev_ctx.template Alloc<T>(&tmp_tensor);

T* next_in_data = out_data;
T* next_out_data = tmp_data;
Expand All @@ -281,6 +281,8 @@ void CumsumKernel(const Context& dev_ctx,
// Consider the size of shared memory, here block size is 128
dim3 scan_grid(outer_size, inner_size);
dim3 reverse_grid = scan_grid;
int64_t grid_size = outer_size * inner_size;

if (reverse) {
if (transpose) {
reverse_grid.x = scan_grid.y;
Expand All @@ -295,17 +297,17 @@ void CumsumKernel(const Context& dev_ctx,
}
}
if (!transpose && !reverse) {
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
BlockScanKernel<T, 128, 4><<<grid_size, 128, 0, dev_ctx.stream()>>>(
out_data, in_data, outer_size, inner_size, scan_size, exclusive);

} else {
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
next_out_data,
next_in_data,
outer_size,
inner_size,
scan_size,
exclusive);
BlockScanKernel<T, 128, 4>
<<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
next_in_data,
outer_size,
inner_size,
scan_size,
exclusive);
}
swap_ptr(next_in_data, next_out_data);
if (reverse) {
Expand Down

1 comment on commit 5b17c53

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 5b17c53 Jun 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #43777 Commit ID: 5b17c53 contains failed CI.

🔹 Failed: PR-CI-Mac-Python3-23

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-NPU-23

Unknown Failed
Unknown Failed

Please sign in to comment.