From 91ec1bdcaa4c6cf7b84632895be39d12764e52a6 Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Thu, 28 Aug 2025 16:59:26 +0000 Subject: [PATCH 1/3] [WIP] Unclear bug for static op grad. --- .../kernels/funcs/gather_scatter_functor.cu | 789 ++++++++---------- 1 file changed, 338 insertions(+), 451 deletions(-) diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu index f73f8005e90d6c..929854353708d3 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -88,6 +88,27 @@ struct DivMod { } }; +template +static T ExcludeSelfInitialValue(const std::string& reduce_op) { + if (reduce_op == "add") { + return static_cast(0); + } else if (reduce_op == "mul") { + return static_cast(1); + } else if (reduce_op == "max") { + return std::numeric_limits::lowest(); + } else if (reduce_op == "min") { + return std::numeric_limits::max(); + } else if (reduce_op == "mean") { + return static_cast(0); + } else { + PADDLE_ENFORCE_EQ( + 0, + 1, + common::errors::InvalidArgument( + "Unsupported or unnecessary (assign) reduce op: '%s'", reduce_op)); + } +} + // compute two offsets for self tensor and src tensor // if compute_self is true, other wise only src_offset is useful // TODO(heqianyue): remove force inline? @@ -131,6 +152,51 @@ __device__ __forceinline__ void ComputeOffset( if constexpr (compute_self) *input_offset = _input_offset; } +#define COMPUTE_OFFSET_SINGLE_OUTPUT( \ + var_name, smem_offset, id_var_name, copy_size) \ + extern __shared__ int64_t smem_shape_strides[]; \ + int64_t id_var_name = threadIdx.x + blockIdx.x * blockDim.x; \ + if (threadIdx.x < (copy_size * ndim)) { \ + *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); \ + } \ + __syncthreads(); \ + if (id_var_name >= numel) return; \ + int64_t var_name = 0; \ + index_t index = index_data[id_var_name]; \ + const int64_t* stride_info = smem_shape_strides + smem_offset * ndim; \ + ComputeOffset(smem_shape_strides, \ + stride_info, \ + nullptr, \ + &var_name, \ + nullptr, \ + id_var_name, \ + ndim, \ + dim, \ + index); + +#define COMPUTE_OFFSET_DOUBLE_OUTPUT( \ + var_name1, var_name2, id_var_name, offset1, offset2) \ + extern __shared__ int64_t smem_shape_strides[]; \ + int64_t id_var_name = threadIdx.x + blockIdx.x * blockDim.x; \ + if (threadIdx.x < (3 * ndim)) { \ + *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); \ + } \ + __syncthreads(); \ + if (id_var_name >= numel) return; \ + index_t index = index_data[id_var_name]; \ + const int64_t* grad_strides = smem_shape_strides + offset1 * ndim; \ + const int64_t* self_strides = smem_shape_strides + offset2 * ndim; \ + int64_t var_name1 = 0, var_name2 = 0; \ + ComputeOffset(smem_shape_strides, \ + grad_strides, \ + self_strides, \ + &var_name1, \ + &var_name2, \ + id_var_name, \ + ndim, \ + dim, \ + index); + /** * The assign / add / mul / min / max kernels can actually be unified * @@ -150,8 +216,7 @@ __device__ __forceinline__ void ComputeOffset( template + bool is_scatter_like = true> __global__ void GatherScatterGPUKernel( tensor_t* __restrict__ self_data, const index_t* __restrict__ index_data, @@ -163,7 +228,7 @@ __global__ void GatherScatterGPUKernel( int dim, int ndim, const func_t& reduce_op, - int* __restrict__ aux_buffer = nullptr) { + int* __restrict__ atomic_cnt_buffer = nullptr) { extern __shared__ int64_t smem_shape_strides[]; // no more than 27 int64_t, won't affect occupancy @@ -223,46 +288,59 @@ __global__ void 
GatherScatterGPUKernel(
                                           ndim,
                                           dim,
                                           index);
-  if constexpr (include_self) {
-    // unordered-writes branch has the same behavior as torch's. Strangely,
-    // the old impl performs ordered access for assign (maybe it is because
-    // there was no atomic primitives for assign), and for other ops,
-    // unordered atomic access is used
-    reduce_op(static_cast(self_data + replace_index_self),
-              static_cast(src_data + replace_index_src));
-  } else {
-    bool is_op_done = false;
-    phi::CudaAtomicMin(aux_buffer + replace_index_self, tid);
-    __syncthreads();
-    if (tid == aux_buffer[replace_index_self]) {
-      self_data[replace_index_self] = src_data[replace_index_src];
-      is_op_done = true;
-    }
-    __syncthreads();
-    if (!is_op_done)
-      reduce_op(static_cast(self_data + replace_index_self),
-                static_cast(src_data + replace_index_src));
+
+  reduce_op(static_cast(self_data + replace_index_self),
+            static_cast(src_data + replace_index_src));
+  if (atomic_cnt_buffer) {
+    phi::CudaAtomicAdd(atomic_cnt_buffer + replace_index_self, 1);
   }
 }
 
-template
-__global__ void ScatterMeanGPUKernel(
-    tensor_t* __restrict__ self_data,
+// TODO(heqianyue): to fully match the behavior of PyTorch, we should implement
+// an integer div (floor) in this kernel, instead of default trunc (to zero) div
+template
+__global__ void CastDivKernel(tensor_t* __restrict__ self_data,
+                              int* __restrict__ atomic_cnt_buffer,
+                              int64_t numel) {
+  // mean kernel has only one purpose after refactoring: div by count
+  // to fuse the kernel into other kernels (like scatter add), we might need
+  // semaphores to notify when all blocks are done adding. For now, we choose
+  // this simpler implementation
+
+  int64_t tid = threadIdx.x + static_cast(blockIdx.x) * blockDim.x;
+  if (tid >= numel) return;
+  self_data[tid] /= static_cast(atomic_cnt_buffer[tid]);
+}
+
+/**
+ * Faster pass for scattering a scalar value.
+ *
+ * For future optimization:
+ * TODO(heqianyue): if, for example, the `values` for put_along_axis (and other
+ * APIs that use scatter kernels) is a scalar, for broadcast=True mode, the
+ * scalar will be made a tensor and broadcast to specific shape, which is
+ * wasteful, if actual memory allocation does happen under the hood. We can
+ * create a special fast pass based on this kernel, to scatter a single scalar
+ * faster, with less memory consumption, since the current kernel eliminates the
+ * need for `broadcast_to` and aux_tensor, which might cut the overhead of the
+ * kernel by more than half.
+ *
+ * To upgrade the scalar scatter, one needs to add func_t and reduce_op in the
+ * kernel, but be aware that, to be backward-compatible with the behaviors in
+ * the old versions, extra atomic primitives might be needed to make sure the
+ * correct ordering of stores.
+ */ +template +__global__ void ScatterAssignScalarValue( + tensor_t* __restrict__ input_data, const index_t* __restrict__ index_data, const int64_t* __restrict__ shape_strides, - const tensor_t* __restrict__ src_data, int64_t self_select_dim_size, - int64_t src_select_dim_size, + tensor_t value_to_scatter, int64_t numel, int dim, int ndim, - const func_t& reduce_op, - bool include_self = true, - int* __restrict__ aux_buffer = nullptr, - int* __restrict__ atomic_cnt_buffer = nullptr) { + int* aux_buffer = nullptr) { extern __shared__ int64_t smem_shape_strides[]; // no more than 27 int64_t, won't affect occupancy @@ -271,74 +349,30 @@ __global__ void ScatterMeanGPUKernel( *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); } __syncthreads(); - // we need threads to complete memory write to smem, even if current thread is - // out of bound if (tid >= numel) return; index_t index = index_data[tid]; + if (index < 0) index += static_cast(self_select_dim_size); - const int64_t* src_strides = smem_shape_strides + ndim; - const int64_t* input_strides = nullptr; + // some kernels might store input_strides differently! Be careful when dealing + // with this. + const int64_t* input_strides = smem_shape_strides + 2 * ndim; // index matrix has different shape with self matrix or src matrix. - int64_t replace_index_self = 0, replace_index_src = 0; - if constexpr (is_scatter_like) { - input_strides = smem_shape_strides + - ndim * 2; // gather pass actually does not need this - // scatter - PADDLE_ENFORCE( - index >= -self_select_dim_size && index < self_select_dim_size, - "The index is out of bounds, " - "please check whether the index and " - "input's shape meet the requirements. It should " - "be greater or equal to [%d] and less than [%d], but received [%ld]", - -self_select_dim_size, - self_select_dim_size, - (int64_t)index); - if (index < 0) { - index += self_select_dim_size; - } - } else { - // gather - PADDLE_ENFORCE( - index >= -src_select_dim_size && index < src_select_dim_size, - "The index is out of bounds, " - "please check whether the index and " - "input's shape meet the requirements. 
It should " - "be greater or equal to [%d] and less than [%d], but received [%d]", - -src_select_dim_size, - src_select_dim_size, - (int32_t)index); - if (index < 0) { - index += src_select_dim_size; - } - replace_index_self = tid; - } - ComputeOffset(smem_shape_strides, - src_strides, - input_strides, - &replace_index_src, - &replace_index_self, - tid, - ndim, - dim, - index); - if (!include_self) { - self_data[replace_index_self] = 0; - __syncthreads(); - } - - reduce_op(static_cast(self_data + replace_index_self), - static_cast(src_data + replace_index_src)); - - // So this is the culprit - phi::CudaAtomicMax(aux_buffer + replace_index_self, tid); - phi::CudaAtomicAdd(atomic_cnt_buffer + replace_index_self, 1); - __syncthreads(); + int64_t replace_index_self = 0; + ComputeOffset(smem_shape_strides, + input_strides, + nullptr, + &replace_index_self, + nullptr, + tid, + ndim, + dim, + index); - if (tid == aux_buffer[replace_index_self]) { - self_data[replace_index_self] = - self_data[replace_index_self] / - static_cast(atomic_cnt_buffer[replace_index_self]); + input_data[replace_index_self] = value_to_scatter; + if (aux_buffer) { + // fused: used in mean pass, aux_buffer has the same shape as input + aux_buffer[replace_index_self] = 0; } } @@ -441,6 +475,7 @@ struct gpu_gather_scatter_functor { if (index.numel() == 0) { return; } + auto* self_data = self.data(); const auto* index_data = index.data(); const auto* src_data = src.data(); @@ -451,29 +486,13 @@ struct gpu_gather_scatter_functor { auto index_dims = index.dims(); auto src_dims = src.dims(); if (self_size == 0 || src_size == 0 || index_size == 0) return; - int64_t select_dim_size = index_dims[dim]; - // index matrix has different shape with self matrix or src matrix. + // index matrix might have different shape with self matrix or src matrix. int64_t self_select_dim_size = self_dims[dim]; int64_t src_select_dim_size = src_dims[dim]; - int64_t inner_dim_size = 1; - int64_t outer_dim_size = 1; - for (int64_t i = 0; i < dim; ++i) { - inner_dim_size *= index_dims[i]; - } - for (int i = dim + 1; i < index_dims.size(); i++) { - outer_dim_size *= index_dims[i]; - } constexpr int block = 512; - int64_t n = inner_dim_size * select_dim_size * outer_dim_size; - int64_t grid = (n + block - 1) / block; + int64_t grid = (index_size + block - 1) / block; auto stream = reinterpret_cast(dev_ctx).stream(); - DenseTensor shared_mem_tensor; - if (method_name == "scatter_assign_gpu") { - shared_mem_tensor.Resize({self_size}); - auto* winners = dev_ctx.Alloc(&shared_mem_tensor); - phi::funcs::set_constant(dev_ctx, &shared_mem_tensor, 0); - } int64_t ndim = index.dims().size(); @@ -500,7 +519,7 @@ struct gpu_gather_scatter_functor { const size_t shared_mem_bytes = sizeof(int64_t) * shape_stride_dev.numel(); DenseTensor aux_tensor; - if (method_name == "scatter_assign_gpu") { + if (method_name == "assign") { aux_tensor.Resize({self_size}); dev_ctx.Alloc(&aux_tensor); phi::funcs::set_constant(dev_ctx, &aux_tensor, 0); @@ -526,73 +545,55 @@ struct gpu_gather_scatter_functor { index_size, dim, ndim); - } else if (method_name == "scatter_mean_gpu") { - // TODO(heqianyue): the original impl is too wasteful, this can be - // optimized - DenseTensor atomic_cnt_tensor; - aux_tensor.Resize({self_size}); + return; + } + + // completely eliminate the need for aux_buffer! For most cases we can have + // up to 50% memory reduction! 
+ DenseTensor atomic_cnt_tensor; + int* atomic_cnt_buffer = nullptr; + if (method_name == "mean") { atomic_cnt_tensor.Resize({self_size}); - dev_ctx.Alloc(&aux_tensor); dev_ctx.Alloc(&atomic_cnt_tensor); + phi::funcs::set_constant(dev_ctx, &atomic_cnt_tensor, 1); + atomic_cnt_buffer = atomic_cnt_tensor.data(); + } + if (!include_self) { + tensor_t init_val = ExcludeSelfInitialValue(method_name); + // exclude self requires us to overwrite the positions that will have + // values scattered, we cannot fuse the kernels all in one in a simple + // way, since when shape is large, atomic primitives will only be synced + // intra-block-ly, resulting in incorrect results, should inter-block + // atomic reduce occur. + ScatterAssignScalarValue<<>>( + self_data, + index_data, + shape_strides, + self_select_dim_size, + init_val, + index_size, + dim, + ndim, + atomic_cnt_buffer); + } - // threadidx must start with 0, otherwise atomicMax will be faulty - phi::funcs::set_constant(dev_ctx, &aux_tensor, 0); - phi::funcs::set_constant( - dev_ctx, &atomic_cnt_tensor, include_self ? 1 : 0); - - int* aux_buffer = aux_tensor.data(); - int* atomic_cnt_buffer = atomic_cnt_tensor.data(); - ScatterMeanGPUKernel - <<>>(self_data, - index_data, - shape_strides, - src_data, - self_select_dim_size, - src_select_dim_size, - index_size, - dim, - ndim, - reduce_op, - include_self, - aux_buffer, - atomic_cnt_buffer); - } else { - if (include_self) { - GatherScatterGPUKernel - <<>>(self_data, - index_data, - shape_strides, - src_data, - self_select_dim_size, - src_select_dim_size, - index_size, - dim, - ndim, - reduce_op, - nullptr); - } else { - aux_tensor.Resize({self_size}); - dev_ctx.Alloc(&aux_tensor); - phi::funcs::set_constant(dev_ctx, &aux_tensor, index_size + 1); - - int* aux_buffer = aux_tensor.data(); - GatherScatterGPUKernel - <<>>(self_data, - index_data, - shape_strides, - src_data, - self_select_dim_size, - src_select_dim_size, - index_size, - dim, - ndim, - reduce_op, - aux_buffer); - } + GatherScatterGPUKernel + <<>>(self_data, + index_data, + shape_strides, + src_data, + self_select_dim_size, + src_select_dim_size, + index_size, + dim, + ndim, + reduce_op, + atomic_cnt_buffer); + if (method_name == "mean") { + constexpr int _block = 512; + int64_t grid = (self_size + _block - 1) / _block; + CastDivKernel<<>>( + self_data, atomic_cnt_buffer, self_size); } } }; // struct gpu_gather_scatter_functor @@ -606,14 +607,8 @@ void gpu_gather_kernel(phi::DenseTensor self, const phi::DeviceContext& dev_ctx) { gpu_gather_scatter_functor()(result, - dim, - index, - self, - "gather_out_gpu", - tensor_assign, - include_self, - dev_ctx); + /*is_scatter_like=*/false>()( + result, dim, index, self, "assign", tensor_assign, include_self, dev_ctx); return; } @@ -626,14 +621,8 @@ void gpu_scatter_assign_kernel(phi::DenseTensor self, const phi::DeviceContext& dev_ctx) { gpu_gather_scatter_functor()(self, - dim, - index, - src, - "scatter_assign_gpu", - tensor_assign, - include_self, - dev_ctx); + /*is_scatter_like=*/true>()( + self, dim, index, src, "assign", tensor_assign, include_self, dev_ctx); } template @@ -645,14 +634,8 @@ void gpu_scatter_add_kernel(phi::DenseTensor self, const phi::DeviceContext& dev_ctx) { gpu_gather_scatter_functor()(self, - dim, - index, - src, - "scatter_add_gpu", - reduce_add, - include_self, - dev_ctx); + /*is_scatter_like=*/true>()( + self, dim, index, src, "add", reduce_add, include_self, dev_ctx); } template @@ -664,14 +647,8 @@ void gpu_scatter_mul_kernel(phi::DenseTensor self, const 
phi::DeviceContext& dev_ctx) { gpu_gather_scatter_functor()(self, - dim, - index, - src, - "scatter_mul_gpu", - reduce_mul, - include_self, - dev_ctx); + /*is_scatter_like=*/true>()( + self, dim, index, src, "mul", reduce_mul, include_self, dev_ctx); } template @@ -683,14 +660,8 @@ void gpu_scatter_mean_kernel(phi::DenseTensor self, const phi::DeviceContext& dev_ctx) { gpu_gather_scatter_functor()(self, - dim, - index, - src, - "scatter_mean_gpu", - reduce_add, - include_self, - dev_ctx); + /*is_scatter_like=*/true>()( + self, dim, index, src, "mean", reduce_add, include_self, dev_ctx); } template @@ -702,14 +673,8 @@ void gpu_scatter_max_kernel(phi::DenseTensor self, const phi::DeviceContext& dev_ctx) { gpu_gather_scatter_functor()(self, - dim, - index, - src, - "scatter_max_gpu", - reduce_max, - include_self, - dev_ctx); + /*is_scatter_like=*/true>()( + self, dim, index, src, "max", reduce_max, include_self, dev_ctx); } template @@ -721,14 +686,8 @@ void gpu_scatter_min_kernel(phi::DenseTensor self, const phi::DeviceContext& dev_ctx) { gpu_gather_scatter_functor()(self, - dim, - index, - src, - "scatter_min_gpu", - reduce_min, - include_self, - dev_ctx); + /*is_scatter_like=*/true>()( + self, dim, index, src, "min", reduce_min, include_self, dev_ctx); } template @@ -827,6 +786,62 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self, index_size); } +namespace { +enum GradDispatchTag { + MulInputGrad = 0x0, + MinMaxInputGrad, + MeanInputGrad, + ValueGrad, + MeanValueGrad, + MinMaxValueGrad, +}; +} // anonymous namespace + +template +__global__ void ScatterGradPrePassKernel( + tensor_t* __restrict__ grad_data, + const index_t* __restrict__ index_data, + const tensor_t* __restrict__ out_data, + const tensor_t* __restrict__ value_data, + const tensor_t* __restrict__ x_data, + const int64_t* __restrict__ shape_strides, + int dim, + int ndim, + int64_t numel, + int64_t grad_numel, + int* __restrict__ aux_buffer, + bool include_self = true) { + if constexpr (dispatch == GradDispatchTag::MulInputGrad) { + COMPUTE_OFFSET_SINGLE_OUTPUT(replace_index, 1, tid, 2) + atomicMax(aux_buffer + replace_index, tid); + } else if constexpr (dispatch == GradDispatchTag::MinMaxInputGrad) { + // This is a special case, src is stored in shape_strides + 2 * dim but used + // as the 2nd param for compute offset + COMPUTE_OFFSET_DOUBLE_OUTPUT(replace_index_value, replace_index, tid, 2, 1) + if (value_data[replace_index_value] == out_data[replace_index]) + phi::CudaAtomicAdd(aux_buffer + replace_index, 1); + } else if constexpr (dispatch == GradDispatchTag::MeanInputGrad) { + COMPUTE_OFFSET_SINGLE_OUTPUT(replace_index, 1, tid, 2) + atomicMax(aux_buffer + replace_index, tid); + phi::CudaAtomicAdd(aux_buffer + grad_numel + replace_index, 1); + } else if constexpr (dispatch == GradDispatchTag::ValueGrad) { + COMPUTE_OFFSET_SINGLE_OUTPUT(replace_index_self, 2, tid, 3) + atomicMax(aux_buffer + replace_index_self, tid); + } else if constexpr (dispatch == GradDispatchTag::MeanValueGrad) { + COMPUTE_OFFSET_SINGLE_OUTPUT(replace_index_self, 2, tid, 3) + phi::CudaAtomicAdd(aux_buffer + replace_index_self, 1); + } else if constexpr (dispatch == GradDispatchTag::MinMaxValueGrad) { + COMPUTE_OFFSET_DOUBLE_OUTPUT( + replace_index_grad, replace_index_self, tid, 1, 2) + grad_data[replace_index_grad] = 0; + if (include_self && + x_data[replace_index_self] == out_data[replace_index_self]) + phi::CudaAtomicAdd(aux_buffer + replace_index_self, 1); + if (value_data[replace_index_grad] == out_data[replace_index_self]) + 
phi::CudaAtomicAdd(aux_buffer + replace_index_self, 1); + } +} + template __global__ void ScatterMulInputGradGPUKernel( tensor_t* __restrict__ grad_data, @@ -838,31 +853,7 @@ __global__ void ScatterMulInputGradGPUKernel( int ndim, int64_t numel, int* __restrict__ aux_buffer) { - extern __shared__ int64_t smem_shape_strides[]; - int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (threadIdx.x < (2 * ndim)) { - *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); - } - __syncthreads(); - if (tid >= numel) return; - - int64_t replace_index = 0; - index_t index = index_data[tid]; - // the second `ndim` elements are not used in this kernel - const int64_t* grad_strides = smem_shape_strides + ndim; - - ComputeOffset(smem_shape_strides, - grad_strides, - nullptr, - &replace_index, - nullptr, - tid, - ndim, - dim, - index); - atomicMax(aux_buffer + replace_index, tid); - __syncthreads(); + COMPUTE_OFFSET_SINGLE_OUTPUT(replace_index, 1, tid, 2) if (tid == aux_buffer[replace_index]) { grad_data[replace_index] = grad_data[replace_index] * out_data[replace_index] / x_data[replace_index]; @@ -875,42 +866,13 @@ __global__ void ScatterMinMaxInputGradGPUKernel( const index_t* __restrict__ index_data, const tensor_t* __restrict__ out_data, const tensor_t* __restrict__ x_data, - const tensor_t* __restrict__ value_data, const tensor_t* __restrict__ self_data, const int64_t* __restrict__ shape_strides, int dim, int ndim, int64_t numel, int* __restrict__ aux_buffer) { - extern __shared__ int64_t smem_shape_strides[]; - int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (threadIdx.x < (3 * ndim)) { - *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); - } - __syncthreads(); - if (tid >= numel) return; - - index_t index = index_data[tid]; - const int64_t* grad_strides = smem_shape_strides + ndim; - const int64_t* src_strides = smem_shape_strides + 2 * ndim; - - int64_t replace_index = 0, replace_index_value = 0; - // the ordering of src_strides and grad_strides in the following function - // param is correct - ComputeOffset(smem_shape_strides, - src_strides, - grad_strides, - &replace_index_value, - &replace_index, - tid, - ndim, - dim, - index); - - if (value_data[replace_index_value] == out_data[replace_index]) - phi::CudaAtomicAdd(aux_buffer + replace_index, 1); - __syncthreads(); + COMPUTE_OFFSET_SINGLE_OUTPUT(replace_index, 2, tid, 3) if (out_data[replace_index] != x_data[replace_index]) { grad_data[replace_index] = 0; } else { @@ -988,6 +950,19 @@ void gpu_scatter_mul_min_max_input_grad_kernel( if (reduce == "mul" || reduce == "multiply") { phi::funcs::set_constant(dev_ctx, &aux_tensor, 0); shared_mem_bytes *= 2; // 1 stride, 1 shape + + ScatterGradPrePassKernel + <<>>(grad_data, + index_data, + out_data, + value_data, + x_data, + shape_strides, + dim, + ndim, + index.numel(), + grad.numel(), + aux_buffer); ScatterMulInputGradGPUKernel <<>>(grad_data, index_data, @@ -1001,12 +976,24 @@ void gpu_scatter_mul_min_max_input_grad_kernel( } else if (reduce == "amin" || reduce == "amax") { phi::funcs::set_constant(dev_ctx, &aux_tensor, 1); shared_mem_bytes *= 3; // two strides, 1 shape + + ScatterGradPrePassKernel + <<>>(grad_data, + index_data, + out_data, + value_data, + x_data, + shape_strides, + dim, + ndim, + index.numel(), + grad.numel(), + aux_buffer); ScatterMinMaxInputGradGPUKernel <<>>(grad_data, index_data, out_data, x_data, - value_data, self_data, shape_strides, dim, @@ -1026,32 +1013,7 @@ __global__ void ScatterMeanInputGradGPUKernel( int64_t 
numel, int64_t grad_numel, int* __restrict__ aux_buffer) { - extern __shared__ int64_t smem_shape_strides[]; - int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (threadIdx.x < (2 * ndim)) { - *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); - } - __syncthreads(); - if (tid >= numel) return; - - index_t index = index_data[tid]; - const int64_t* grad_strides = smem_shape_strides + ndim; - - int64_t replace_index = 0; - ComputeOffset(smem_shape_strides, - grad_strides, - nullptr, - &replace_index, - nullptr, - tid, - ndim, - dim, - index); - - atomicMax(aux_buffer + replace_index, tid); - phi::CudaAtomicAdd(aux_buffer + grad_numel + replace_index, 1); - __syncthreads(); + COMPUTE_OFFSET_SINGLE_OUTPUT(replace_index, 1, tid, 2) if (tid == aux_buffer[replace_index]) { grad_data[replace_index] = grad_data[replace_index] / @@ -1120,6 +1082,18 @@ void gpu_scatter_mean_input_grad_kernel(phi::DenseTensor self, const int64_t* shape_strides = shape_stride_dev.data(); size_t shared_mem_bytes = sizeof(int64_t) * ndim * 2; + ScatterGradPrePassKernel + <<>>(grad_data, + index_data, + nullptr, + nullptr, + nullptr, + shape_strides, + dim, + ndim, + index.numel(), + grad_size, + aux_buffer); ScatterMeanInputGradGPUKernel <<>>(grad_data, index_data, @@ -1141,33 +1115,8 @@ __global__ void ScatterValueGradGPUKernel( int ndim, int64_t numel, int* __restrict__ aux_buffer) { - extern __shared__ int64_t smem_shape_strides[]; - int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (threadIdx.x < (3 * ndim)) { - *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); - } - __syncthreads(); - if (tid >= numel) return; - - index_t index = index_data[tid]; - const int64_t* grad_strides = smem_shape_strides + ndim; - const int64_t* self_strides = smem_shape_strides + 2 * ndim; - - int64_t replace_index_self = 0, replace_index_grad = 0; - ComputeOffset(smem_shape_strides, - grad_strides, - self_strides, - &replace_index_grad, - &replace_index_self, - tid, - ndim, - dim, - index); - - atomicMax(aux_buffer + replace_index_self, tid); - __syncthreads(); - + COMPUTE_OFFSET_DOUBLE_OUTPUT( + replace_index_grad, replace_index_self, tid, 1, 2) if (tid == aux_buffer[replace_index_self]) { grad_data[replace_index_grad] = self_data[replace_index_self]; } @@ -1230,6 +1179,18 @@ void gpu_scatter_value_grad_kernel(phi::DenseTensor self, const int64_t* shape_strides = shape_stride_dev.data(); size_t shared_mem_bytes = sizeof(int64_t) * ndim * 3; + ScatterGradPrePassKernel + <<>>(grad_data, + index_data, + nullptr, + nullptr, + nullptr, + shape_strides, + dim, + ndim, + index.numel(), + grad.numel(), + aux_buffer); ScatterValueGradGPUKernel <<>>(grad_data, self_data, @@ -1251,33 +1212,8 @@ __global__ void ScatterMeanValueGradGPUKernel( int ndim, int64_t numel, int* __restrict__ aux_buffer) { - extern __shared__ int64_t smem_shape_strides[]; - int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (threadIdx.x < (3 * ndim)) { - *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); - } - __syncthreads(); - if (tid >= numel) return; - - index_t index = index_data[tid]; - const int64_t* grad_strides = smem_shape_strides + ndim; - const int64_t* self_strides = smem_shape_strides + 2 * ndim; - - int64_t replace_index_self = 0, replace_index_grad = 0; - ComputeOffset(smem_shape_strides, - grad_strides, - self_strides, - &replace_index_grad, - &replace_index_self, - tid, - ndim, - dim, - index); - - phi::CudaAtomicAdd(aux_buffer + replace_index_self, 1); - __syncthreads(); - 
+ COMPUTE_OFFSET_DOUBLE_OUTPUT( + replace_index_grad, replace_index_self, tid, 1, 2) grad_data[replace_index_grad] = self_data[replace_index_self] / static_cast(aux_buffer[replace_index_self]); @@ -1292,29 +1228,8 @@ __global__ void ScatterAddValueGradGPUKernel( int dim, int ndim, int64_t numel) { - extern __shared__ int64_t smem_shape_strides[]; - int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (threadIdx.x < (3 * ndim)) { - *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); - } - __syncthreads(); - if (tid >= numel) return; - - index_t index = index_data[tid]; - const int64_t* grad_strides = smem_shape_strides + ndim; - const int64_t* self_strides = smem_shape_strides + 2 * ndim; - - int64_t replace_index_self = 0, replace_index_grad = 0; - ComputeOffset(smem_shape_strides, - grad_strides, - self_strides, - &replace_index_grad, - &replace_index_self, - tid, - ndim, - dim, - index); + COMPUTE_OFFSET_DOUBLE_OUTPUT( + replace_index_grad, replace_index_self, tid, 1, 2) grad_data[replace_index_grad] = self_data[replace_index_self]; } @@ -1380,6 +1295,18 @@ void gpu_scatter_add_mean_value_grad_kernel( dev_ctx.Alloc(&aux_tensor); phi::funcs::set_constant(dev_ctx, &aux_tensor, include_self ? 1 : 0); int* aux_buffer = aux_tensor.data(); + ScatterGradPrePassKernel + <<>>(grad_data, + index_data, + nullptr, + nullptr, + nullptr, + shape_strides, + dim, + ndim, + index.numel(), + grad.numel(), + aux_buffer); ScatterMeanValueGradGPUKernel <<>>(grad_data, self_data, @@ -1412,29 +1339,8 @@ __global__ void ScatterMulValueGradGPUKernel( int dim, int ndim, int64_t numel) { - extern __shared__ int64_t smem_shape_strides[]; - int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (threadIdx.x < (3 * ndim)) { - *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); - } - __syncthreads(); - if (tid >= numel) return; - - index_t index = index_data[tid]; - const int64_t* grad_strides = smem_shape_strides + ndim; - const int64_t* self_strides = smem_shape_strides + 2 * ndim; - - int64_t replace_index_self = 0, replace_index_grad = 0; - ComputeOffset(smem_shape_strides, - grad_strides, - self_strides, - &replace_index_grad, - &replace_index_self, - tid, - ndim, - dim, - index); + COMPUTE_OFFSET_DOUBLE_OUTPUT( + replace_index_grad, replace_index_self, tid, 1, 2) grad_data[replace_index_grad] = self_data[replace_index_self] * (out_data[replace_index_self] / value_data[replace_index_grad]); @@ -1447,45 +1353,14 @@ __global__ void ScatterMinMaxValueGradGPUKernel( const tensor_t* __restrict__ self_data, const tensor_t* __restrict__ value_data, const tensor_t* __restrict__ out_data, - const tensor_t* __restrict__ x_data, const int64_t* __restrict__ shape_strides, int dim, int ndim, int64_t numel, bool include_self, int* __restrict__ aux_buffer) { - extern __shared__ int64_t smem_shape_strides[]; - int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; - - if (threadIdx.x < (3 * ndim)) { - *(smem_shape_strides + threadIdx.x) = *(shape_strides + threadIdx.x); - } - __syncthreads(); - if (tid >= numel) return; - - index_t index = index_data[tid]; - const int64_t* grad_strides = smem_shape_strides + ndim; - const int64_t* self_strides = smem_shape_strides + 2 * ndim; - - int64_t replace_index_self = 0, replace_index_grad = 0; - ComputeOffset(smem_shape_strides, - grad_strides, - self_strides, - &replace_index_grad, - &replace_index_self, - tid, - ndim, - dim, - index); - - if (include_self && - x_data[replace_index_self] == out_data[replace_index_self]) - 
phi::CudaAtomicAdd(aux_buffer + replace_index_self, 1); - __syncthreads(); - grad_data[replace_index_grad] = 0; - if (value_data[replace_index_grad] == out_data[replace_index_self]) - phi::CudaAtomicAdd(aux_buffer + replace_index_self, 1); - __syncthreads(); + COMPUTE_OFFSET_DOUBLE_OUTPUT( + replace_index_grad, replace_index_self, tid, 1, 2) if (value_data[replace_index_grad] == out_data[replace_index_self]) grad_data[replace_index_grad] = self_data[replace_index_self] / @@ -1569,13 +1444,25 @@ void gpu_scatter_mul_min_max_value_grad_kernel( phi::funcs::set_constant(dev_ctx, &aux_tensor, 0); int* aux_buffer = aux_tensor.data(); + ScatterGradPrePassKernel + <<>>(grad_data, + index_data, + out_data, + value_data, + x_data, + shape_strides, + dim, + ndim, + index.numel(), + grad.numel(), + aux_buffer, + include_self); ScatterMinMaxValueGradGPUKernel <<>>(grad_data, index_data, self_data, value_data, out_data, - x_data, shape_strides, dim, ndim, From ab4a56e3228d94e160f88880854d9be3078f9a00 Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Fri, 29 Aug 2025 01:05:19 +0000 Subject: [PATCH 2/3] [PHI] scatter/gather two stage kernels are ready --- .../kernels/funcs/gather_scatter_functor.cu | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu index 929854353708d3..0b4bcd176aa686 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -77,17 +77,6 @@ __global__ void CudaMemsetAsync(int* dest, int value, size_t size) { dest[tid] = value; } -struct DivMod { - template - static __device__ __forceinline__ void divmod(T dividend, - T divisor, - T* __restrict__ quotient, - T* __restrict__ remainder) { - *quotient = dividend / divisor; - *remainder = dividend % divisor; - } -}; - template static T ExcludeSelfInitialValue(const std::string& reduce_op) { if (reduce_op == "add") { @@ -109,6 +98,17 @@ static T ExcludeSelfInitialValue(const std::string& reduce_op) { } } +struct DivMod { + template + static __device__ __forceinline__ void divmod(T dividend, + T divisor, + T* __restrict__ quotient, + T* __restrict__ remainder) { + *quotient = dividend / divisor; + *remainder = dividend % divisor; + } +}; + // compute two offsets for self tensor and src tensor // if compute_self is true, other wise only src_offset is useful // TODO(heqianyue): remove force inline? 
@@ -608,7 +608,7 @@ void gpu_gather_kernel(phi::DenseTensor self,
   gpu_gather_scatter_functor<tensor_t,
                              index_t,
                              /*is_scatter_like=*/false>()(
-      result, dim, index, self, "assign", tensor_assign, include_self, dev_ctx);
+      result, dim, index, self, "gather", tensor_assign, include_self, dev_ctx);
   return;
 }
 
@@ -976,7 +976,6 @@ void gpu_scatter_mul_min_max_input_grad_kernel(
   } else if (reduce == "amin" || reduce == "amax") {
     phi::funcs::set_constant(dev_ctx, &aux_tensor, 1);
     shared_mem_bytes *= 3;  // two strides, 1 shape
-
     ScatterGradPrePassKernel
         <<<grid, block, shared_mem_bytes, stream>>>(grad_data,
                                                     index_data,
                                                     out_data,
                                                     value_data,
                                                     x_data,
                                                     shape_strides,
                                                     dim,
                                                     ndim,
                                                     index.numel(),
                                                     grad.numel(),
                                                     aux_buffer);

From 2714dc12f1677c6a6a78d909331d502b0697accd Mon Sep 17 00:00:00 2001
From: Enigmatisms
Date: Fri, 29 Aug 2025 04:13:35 +0000
Subject: [PATCH 3/3] [PHI] Fixed reduce = min/max input grad bug

---
 paddle/phi/kernels/funcs/gather_scatter_functor.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu
index 0b4bcd176aa686..8442bdf652a44d 100644
--- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu
+++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu
@@ -872,7 +872,7 @@ __global__ void ScatterMinMaxInputGradGPUKernel(
     int ndim,
     int64_t numel,
     int* __restrict__ aux_buffer) {
-  COMPUTE_OFFSET_SINGLE_OUTPUT(replace_index, 2, tid, 3)
+  COMPUTE_OFFSET_SINGLE_OUTPUT(replace_index, 1, tid, 2)
   if (out_data[replace_index] != x_data[replace_index]) {
     grad_data[replace_index] = 0;
   } else {
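
The sketch below is not part of the patches above; it is a minimal, standalone illustration of the two-stage scatter-mean pattern this refactor moves to: one kernel accumulates sums and per-slot counts with atomics (the role of GatherScatterGPUKernel together with atomic_cnt_buffer), and a second kernel divides each slot by its count (the role of CastDivKernel). The kernel names (ScatterAddCountKernel, DivByCountKernel), the flat 1-D indexing, and the count initialisation to 1 (mirroring set_constant(dev_ctx, &atomic_cnt_tensor, 1) for the include_self case) are assumptions made for the example and do not match the real Paddle kernel signatures.

// Two-stage scatter-mean sketch (illustrative only, see the note above).
#include <cstdio>

#include <cuda_runtime.h>

// Stage 1: scatter-add the source values and count how many values land in
// each output slot (analogous to GatherScatterGPUKernel + atomic_cnt_buffer).
__global__ void ScatterAddCountKernel(float* out,
                                      int* count,
                                      const long long* index,
                                      const float* src,
                                      long long n) {
  long long tid = blockIdx.x * static_cast<long long>(blockDim.x) + threadIdx.x;
  if (tid >= n) return;
  long long slot = index[tid];
  atomicAdd(out + slot, src[tid]);  // accumulate the scattered values
  atomicAdd(count + slot, 1);       // remember how many values landed here
}

// Stage 2: divide by the count (analogous to CastDivKernel). Counts start at
// 1 so the original self element is part of the mean, and untouched slots are
// divided by 1, i.e. left unchanged.
__global__ void DivByCountKernel(float* out, const int* count, long long n) {
  long long tid = blockIdx.x * static_cast<long long>(blockDim.x) + threadIdx.x;
  if (tid >= n) return;
  out[tid] /= static_cast<float>(count[tid]);
}

int main() {
  const long long n_src = 8, n_out = 4;
  float h_src[n_src] = {1, 2, 3, 4, 5, 6, 7, 8};
  long long h_idx[n_src] = {0, 0, 1, 1, 2, 2, 3, 3};  // two values per slot
  float h_out[n_out] = {0, 0, 0, 0};                  // "self" values
  int h_cnt[n_out] = {1, 1, 1, 1};                    // include_self => start at 1

  float *d_src, *d_out;
  long long* d_idx;
  int* d_cnt;
  cudaMalloc(&d_src, sizeof(h_src));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMalloc(&d_idx, sizeof(h_idx));
  cudaMalloc(&d_cnt, sizeof(h_cnt));
  cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, h_out, sizeof(h_out), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx, h_idx, sizeof(h_idx), cudaMemcpyHostToDevice);
  cudaMemcpy(d_cnt, h_cnt, sizeof(h_cnt), cudaMemcpyHostToDevice);

  ScatterAddCountKernel<<<1, 256>>>(d_out, d_cnt, d_idx, d_src, n_src);
  DivByCountKernel<<<1, 256>>>(d_out, d_cnt, n_out);

  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n_out; ++i) {
    // slot 0 holds (0 + 1 + 2) / 3 = 1.0, slot 1 holds (0 + 3 + 4) / 3, ...
    printf("out[%d] = %.3f\n", i, h_out[i]);
  }
  cudaFree(d_src);
  cudaFree(d_out);
  cudaFree(d_idx);
  cudaFree(d_cnt);
  return 0;
}

The exclude-self path in the patch plays the same role as initialising h_out to the reduction identity and h_cnt to 0 before stage 1, which is what ScatterAssignScalarValue does on the device before the main scatter kernel runs.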