diff --git a/paddle/phi/kernels/dist_grad_kernel.cc b/paddle/phi/kernels/dist_grad_kernel.cc
index ba468ad299e4c..17c24fa905b5c 100644
--- a/paddle/phi/kernels/dist_grad_kernel.cc
+++ b/paddle/phi/kernels/dist_grad_kernel.cc
@@ -52,6 +52,10 @@ void DistGradKernel(const Context& dev_ctx,
                     float p,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
+  if ((!x_grad) && (!y_grad)) {
+    return;
+  }
+
   auto t = Subtract(dev_ctx, x, y);
   DenseTensor x_grad_tmp;
   x_grad_tmp.Resize(t.dims());
@@ -59,26 +63,32 @@ void DistGradKernel(const Context& dev_ctx,
   y_grad_tmp.Resize(t.dims());
   PNormGradKernel(
       dev_ctx, t, out, out_grad, p, -1, 1e-12, false, true, &x_grad_tmp);
-  ScaleKernel(dev_ctx, x_grad_tmp, -1.0, 0.0, false, &y_grad_tmp);
-  // do reduce, the implemetation of cpu SumKernel has bug, it changes
-  // the dims of output iternally, so we Resize x/y_grad twice.
-  auto res_x = GetReduceDims(x_grad_tmp.dims(), x.dims());
-  if (!std::get<0>(res_x).empty()) {
-    x_grad->Resize(phi::make_ddim(std::get<1>(res_x)));
-    SumKernel(
-        dev_ctx, x_grad_tmp, std::get<0>(res_x), x.dtype(), false, x_grad);
-    x_grad->Resize(x.dims());
-  } else {
-    x_grad->ShareBufferWith(x_grad_tmp);
+
+  if (x_grad) {
+    // Do the reduction. The CPU SumKernel implementation has a bug: it
+    // changes the dims of the output internally, so we Resize x/y_grad twice.
+    auto res_x = GetReduceDims(x_grad_tmp.dims(), x.dims());
+    if (!std::get<0>(res_x).empty()) {
+      x_grad->Resize(phi::make_ddim(std::get<1>(res_x)));
+      SumKernel(
+          dev_ctx, x_grad_tmp, std::get<0>(res_x), x.dtype(), false, x_grad);
+      x_grad->Resize(x.dims());
+    } else {
+      x_grad->ShareBufferWith(x_grad_tmp);
+    }
   }
-  auto res_y = GetReduceDims(y_grad_tmp.dims(), y.dims());
-  if (!std::get<0>(res_y).empty()) {
-    y_grad->Resize(phi::make_ddim(std::get<1>(res_y)));
-    SumKernel(
-        dev_ctx, y_grad_tmp, std::get<0>(res_y), y.dtype(), false, y_grad);
-    y_grad->Resize(y.dims());
-  } else {
-    y_grad->ShareBufferWith(y_grad_tmp);
+
+  if (y_grad) {
+    ScaleKernel(dev_ctx, x_grad_tmp, -1.0, 0.0, false, &y_grad_tmp);
+    auto res_y = GetReduceDims(y_grad_tmp.dims(), y.dims());
+    if (!std::get<0>(res_y).empty()) {
+      y_grad->Resize(phi::make_ddim(std::get<1>(res_y)));
+      SumKernel(
+          dev_ctx, y_grad_tmp, std::get<0>(res_y), y.dtype(), false, y_grad);
+      y_grad->Resize(y.dims());
+    } else {
+      y_grad->ShareBufferWith(y_grad_tmp);
+    }
   }
 }
 
diff --git a/python/paddle/fluid/tests/unittests/test_dist_op.py b/python/paddle/fluid/tests/unittests/test_dist_op.py
index 4ec55cb7938df..96c0de915cff2 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_op.py
@@ -192,6 +192,15 @@ def test_api(self):
         )
         np.testing.assert_allclose(dist(x_i, y_i, p), out[0], rtol=1e-05)
 
+    def test_grad_x(self):
+        paddle.disable_static()
+        a = paddle.rand([2, 2, 3, 2])
+        b = paddle.rand([1, 1, 3, 1])
+        a.stop_gradient = False
+        c = paddle.dist(a, b, 2)
+        c.backward()
+        paddle.enable_static()
+
 
 if __name__ == '__main__':
     paddle.enable_static()
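
For context, a minimal eager-mode sketch of the scenario the new guards cover, mirroring the added test_grad_x: only `a` requires a gradient, so the grad kernel is invoked with a null y_grad. The shape/None checks at the end are illustrative assumptions about standard autograd behavior, not part of the patch.

import paddle

# Eager (dynamic graph) mode; with the patch, requesting only d(dist)/da
# no longer writes to the unrequested y_grad output.
paddle.disable_static()

a = paddle.rand([2, 2, 3, 2])
b = paddle.rand([1, 1, 3, 1])    # broadcastable against `a`
a.stop_gradient = False          # gradient requested for `a` only
c = paddle.dist(a, b, 2)         # p = 2, Euclidean distance
c.backward()

assert a.grad.shape == a.shape   # broadcasted grad reduced back to `a`'s shape
assert b.grad is None            # no gradient was requested for `b`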