Skip to content

Commit da596f0

Browse files
fix masked_fill_grad value_grad bug (#75988) (#76002)
1 parent 7a0b75b commit da596f0

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

paddle/phi/kernels/gpu/masked_fill_grad_kernel.cu

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,24 @@ void GPUMaskedFillGrad(const phi::GPUContext& dev_ctx,
275275
config);
276276
if (value_grad) {
277277
DenseTensor zero_tensor;
278-
FullLikeKernel<T, phi::GPUContext>(
279-
dev_ctx, out_grad, Scalar(T(0.0)), out_grad.dtype(), &zero_tensor);
278+
phi::Full<T, phi::GPUContext>(
279+
dev_ctx,
280+
phi::IntArray(common::vectorize(out_grad.dims())),
281+
T(0.0),
282+
&zero_tensor);
280283
DenseTensor value_grad_tensor;
281284
value_grad_tensor.set_meta(out_grad.meta());
282285
WhereKernel<T, phi::GPUContext>(
283286
dev_ctx, mask, out_grad, zero_tensor, &value_grad_tensor);
284-
SumKernel<T, phi::GPUContext>(
285-
dev_ctx, value_grad_tensor, {1}, out_grad.dtype(), false, value_grad);
287+
std::vector<int> v_dims(value_grad_tensor.dims().size());
288+
std::iota(v_dims.begin(), v_dims.end(), 0);
289+
IntArray v_axis(v_dims);
290+
SumKernel<T, phi::GPUContext>(dev_ctx,
291+
value_grad_tensor,
292+
v_axis,
293+
value_grad->dtype(),
294+
false,
295+
value_grad);
286296
}
287297

288298
} else {

0 commit comments

Comments
 (0)