lerp support 0 Tensor #49667

Merged
11 commits merged on Jan 12, 2023
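This PR extends paddle.lerp so that x, y, and the weight can all be 0-D (scalar) tensors, in both the forward and backward passes. A minimal dygraph sketch of the behavior enabled here, modeled on the test cases added at the bottom of this diff:

import paddle

# All three inputs are 0-D tensors; the result is also 0-D.
x = paddle.rand([])
y = paddle.rand([])
w = paddle.rand([])

out = paddle.lerp(x, y, w)  # linear interpolation: x + w * (y - x)
print(out.shape)  # []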
4 changes: 1 addition & 3 deletions paddle/phi/infermeta/ternary.cc
@@ -598,9 +598,7 @@ void LerpInferMeta(const MetaTensor& x,
auto w_dims = weight.dims();
DDim out_dims;
out_dims = funcs::GetOutputDims(x_dims, y_dims);
if (w_dims.size() > 1 || w_dims[0] != 1) {
out_dims = funcs::GetOutputDims(out_dims, w_dims);
}
out_dims = funcs::GetOutputDims(out_dims, w_dims);
out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
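Removing the guard on w_dims means the weight's shape now always participates in the output-shape broadcast, which also covers a 0-D weight. A hedged sketch of the resulting shapes, mirroring the 0D+ND and ND+0D cases tested below:

import paddle

# The output shape is the broadcast of the x, y, and weight shapes.
x0 = paddle.rand([])        # 0-D
xn = paddle.rand([64, 64])  # 2-D
w = paddle.rand([])         # 0-D weight

print(paddle.lerp(x0, xn, w).shape)  # [64, 64]  (0D + ND)
print(paddle.lerp(xn, x0, w).shape)  # [64, 64]  (ND + 0D)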
23 changes: 17 additions & 6 deletions paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -77,9 +77,15 @@ __global__ void LerpGradScalarKernelImpl(const T* weight,
bool XYNeedReduce(const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
auto x_dims =
x.dims().size() ? x.dims() : make_ddim(std::vector<int64_t>(1, 1));
auto y_dims =
y.dims().size() ? y.dims() : make_ddim(std::vector<int64_t>(1, 1));

auto out_dims = out.dims();
if (out_dims.size() == 0) {
return false;
}
int x_rank = x_dims.size();
int y_rank = y_dims.size();
int out_rank = out_dims.size();
@@ -166,10 +172,10 @@ void LerpGradKernel(const Context& ctx,
const int rank = out.dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
0,
phi::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"greater than or equal to 1, but the value received is %d.",
"greater than or equal to 0, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
@@ -231,9 +237,12 @@ void LerpGradKernel(const Context& ctx,
x_grad_data,
y_grad_data);

auto zero_dim = make_ddim(std::vector<int64_t>(1, 1));
if (x_grad) {
std::vector<int> reduce_axis_x =
funcs::GetReduceDim(x_grad->dims(), b_xgrad.dims(), -1);
funcs::GetReduceDim(x_grad->dims().size() ? x_grad->dims() : zero_dim,
b_xgrad.dims(),
-1);
if (!reduce_axis_x.empty()) {
phi::funcs::
ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
@@ -245,7 +254,9 @@

if (y_grad) {
std::vector<int> reduce_axis_y =
funcs::GetReduceDim(y_grad->dims(), b_ygrad.dims(), -1);
funcs::GetReduceDim(y_grad->dims().size() ? y_grad->dims() : zero_dim,
b_ygrad.dims(),
-1);
if (!reduce_axis_y.empty()) {
phi::funcs::
ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
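On the GPU path, a 0-D input is treated as shape [1] when the reduce axes are computed, and the broadcast gradient is then summed back down to the input's shape. A rough NumPy sketch of that reduce-to-input-shape step; reduce_grad_to_shape is a hypothetical helper for illustration, not part of phi:

import numpy as np

def reduce_grad_to_shape(grad, shape):
    # Treat a 0-D target as shape (1,), as the kernel above does.
    target = shape if len(shape) else (1,)
    ndiff = grad.ndim - len(target)
    # Sum over leading broadcast axes and over axes where the target dim is 1.
    axes = tuple(range(ndiff)) + tuple(
        i + ndiff
        for i, d in enumerate(target)
        if d == 1 and grad.shape[i + ndiff] != 1
    )
    return np.reshape(grad.sum(axis=axes), shape)

g = np.ones((64, 64), dtype=np.float32)
print(reduce_grad_to_shape(g, ()).shape)        # () for a 0-D x or y
print(reduce_grad_to_shape(g, (64, 64)).shape)  # (64, 64), no reduction needed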
55 changes: 45 additions & 10 deletions paddle/phi/kernels/impl/lerp_grad_kernel_impl.h
@@ -33,33 +33,36 @@ static void LerpGradFunction(const Context& ctx,
auto* dx = x_grad;
auto* dy = y_grad;

auto dout_dims = dout.dims();
auto& out_dims = out.dims();
DDim dx_dims;
DDim dy_dims;

auto w_dims = phi::funcs::ExtendDims2Rank(w.dims(), D);
auto g_dims = phi::funcs::ExtendDims2Rank(out_grad.dims(), D);
Eigen::DSizes<int, D> dx_bcast_dims;
Eigen::DSizes<int, D> dy_bcast_dims;
Eigen::DSizes<int, D> w_bcast_dims;
Eigen::DSizes<int, D> g_bcast_dims;

if (dx) {
dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D);
phi::funcs::GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
phi::funcs::GetBroadcastDims<D>(dx_dims, out_dims, &dx_bcast_dims);
}
if (dy) {
dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D);
phi::funcs::GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
phi::funcs::GetBroadcastDims<D>(dy_dims, out_dims, &dy_bcast_dims);
}
phi::funcs::GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);
phi::funcs::GetBroadcastDims<D>(w_dims, out_dims, &w_bcast_dims);
phi::funcs::GetBroadcastDims<D>(g_dims, out_dims, &g_bcast_dims);

auto eigen_w = phi::EigenTensor<T, D>::From(w, w_dims);
auto eigen_dout = phi::EigenTensor<T, D>::From(dout);
auto eigen_dout = phi::EigenTensor<T, D>::From(dout, g_dims);

Eigen::DSizes<int, D * 2> dx_reshape_dims;
Eigen::DSizes<int, D * 2> dy_reshape_dims;
Eigen::DSizes<int, D> reduce_dims;

for (int i = 0; i < dout_dims.size(); ++i) {
for (int i = 0; i < out_dims.size(); ++i) {
if (dx) {
dx_reshape_dims[2 * i] = dx_bcast_dims[i];
dx_reshape_dims[2 * i + 1] = dx_dims[i];
@@ -76,21 +79,49 @@
if (dx) {
ctx.template Alloc<T>(dx);
auto eigen_dx = phi::EigenTensor<T, D>::From(*dx, dx_dims);
auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) *
eigen_dout.broadcast(g_bcast_dims);
eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dx.dimensions());
}
if (dy) {
ctx.template Alloc<T>(dy);
auto eigen_dy = phi::EigenTensor<T, D>::From(*dy, dy_dims);
auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
auto eigen_expr =
eigen_w.broadcast(w_bcast_dims) * eigen_dout.broadcast(g_bcast_dims);
eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
.sum(reduce_dims)
.reshape(eigen_dy.dimensions());
}
}

template <typename Context, typename T>
static void LerpGradFunctionZero(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
auto dim = make_ddim(std::vector<int64_t>(1, 1));
auto eigen_w = phi::EigenTensor<T, 1>::From(weight, dim);
auto eigen_dout = phi::EigenTensor<T, 1>::From(out_grad, dim);

auto& place = *ctx.eigen_device();
if (x_grad) {
ctx.template Alloc<T>(x_grad);
auto eigen_dx = phi::EigenTensor<T, 1>::From(*x_grad, dim);
eigen_dx.device(place) = (1 - eigen_w) * eigen_dout;
}
if (y_grad) {
ctx.template Alloc<T>(y_grad);
auto eigen_dy = phi::EigenTensor<T, 1>::From(*y_grad, dim);
eigen_dy.device(place) = eigen_w * eigen_dout;
}
}

template <typename T, typename Context>
void LerpGradKernel(const Context& ctx,
const DenseTensor& x,
@@ -103,10 +134,10 @@ void LerpGradKernel(const Context& ctx,
int rank = out.dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
0,
phi::errors::InvalidArgument(
"The number of dimensions for LerpGradOp must be "
"greater than or equal to 1, but the value received is %d.",
"greater than or equal to 0, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
@@ -116,6 +147,10 @@
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 0:
LerpGradFunctionZero<Context, T>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
break;
case 1:
LerpGradFunction<Context, T, 1>(
ctx, x, y, weight, out, out_grad, x_grad, y_grad);
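For rank 0, LerpGradFunctionZero falls back to the plain scalar derivatives d out / d x = 1 - w and d out / d y = w. A small dygraph check of those analytic gradients (not part of this PR; paddle.allclose is assumed for the comparison):

import paddle

x = paddle.rand([])
y = paddle.rand([])
w = paddle.rand([])
x.stop_gradient = False
y.stop_gradient = False

out = paddle.lerp(x, y, w)
out.backward()

# Expected analytic gradients for 0-D inputs.
assert paddle.allclose(x.grad, 1 - w)
assert paddle.allclose(y.grad, w)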
26 changes: 23 additions & 3 deletions paddle/phi/kernels/impl/lerp_kernel_impl.h
@@ -27,7 +27,6 @@ static void LerpFunction(const Context& ctx,
const DenseTensor& weight,
DenseTensor* out) {
ctx.template Alloc<T>(out);

const auto& out_dims = out->dims();
auto x_dims = phi::funcs::ExtendDims2Rank(x.dims(), D);
auto y_dims = phi::funcs::ExtendDims2Rank(y.dims(), D);
@@ -51,6 +50,24 @@
(eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
}

template <typename Context, typename T>
static void LerpFunctionZero(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& weight,
DenseTensor* out) {
ctx.template Alloc<T>(out);

auto dim = make_ddim(std::vector<int64_t>(1, 1));
auto eigen_x = phi::EigenTensor<T, 1>::From(x, dim);
auto eigen_y = phi::EigenTensor<T, 1>::From(y, dim);
auto eigen_w = phi::EigenTensor<T, 1>::From(weight, dim);
auto eigen_out = phi::EigenTensor<T, 1>::From(*out, dim);

auto& place = *ctx.eigen_device();
eigen_out.device(place) = eigen_x + eigen_w * (eigen_y - eigen_x);
}

template <typename T, typename Context>
void LerpKernel(const Context& ctx,
const DenseTensor& x,
@@ -60,10 +77,10 @@ void LerpKernel(const Context& ctx,
int rank = out->dims().size();
PADDLE_ENFORCE_GE(
rank,
1,
0,
phi::errors::InvalidArgument(
"The number of dimensions for LerpOp must be "
"greater than or equal to 1, but the value received is %d.",
"greater than or equal to 0, but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank,
@@ -73,6 +90,9 @@
"less than or equal to 6, but the value received is %d.",
rank));
switch (rank) {
case 0:
LerpFunctionZero<Context, T>(ctx, x, y, weight, out);
break;
case 1:
LerpFunction<Context, T, 1>(ctx, x, y, weight, out);
break;
72 changes: 72 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -966,6 +966,49 @@ def test_argsort(self):
self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 0)

def test_lerp(self):
Review comment (Contributor): Does this API currently support broadcasting? If it does, there need to be three cases: 0D+0D, 0D+ND, and ND+0D.
Reply (Author): Done, thanks.

# 0D + 0D
x0 = paddle.rand([])
y0 = paddle.rand([])
w0 = paddle.rand([])
x0.stop_gradient = False
y0.stop_gradient = False

out0 = paddle.lerp(x0, y0, w0)
out0.backward()

self.assertEqual(out0.shape, [])
self.assertEqual(x0.grad.shape, [])
self.assertEqual(y0.grad.shape, [])

# 0D + ND
x1 = paddle.rand([])
y1 = paddle.rand([64, 64])
w1 = paddle.rand([])
x1.stop_gradient = False
y1.stop_gradient = False

out1 = paddle.lerp(x1, y1, w1)
out1.backward()

self.assertEqual(out1.shape, [64, 64])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(y1.grad.shape, [64, 64])

# ND + 0D
x2 = paddle.rand([64, 64])
y2 = paddle.rand([])
w2 = paddle.rand([])
x2.stop_gradient = False
y2.stop_gradient = False

out2 = paddle.lerp(x2, y2, w2)
out2.backward()

self.assertEqual(out2.shape, [64, 64])
self.assertEqual(x2.grad.shape, [64, 64])
self.assertEqual(y2.grad.shape, [])

def test_repeat_interleave(self):
places = ['cpu']
if paddle.is_compiled_with_cuda():
@@ -1408,6 +1451,35 @@ def test_argsort(self):
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())

@prog_scope()
def test_lerp(self):
shapes = [
[(), (), (), ()],
[(), (64, 64), (), (64, 64)],
[(64, 64), (), (), (64, 64)],
]
for shape in shapes:
x = paddle.rand(shape[0])
y = paddle.rand(shape[1])
w = paddle.rand(shape[2])

x.stop_gradient = False
y.stop_gradient = False
out = paddle.lerp(x, y, w)
paddle.static.append_backward(out.sum())

prog = paddle.static.default_main_program()
block = prog.global_block()
x_grad = block.var(fluid.framework.grad_var_name(x.name))
y_grad = block.var(fluid.framework.grad_var_name(y.name))
out_grad = block.var(fluid.framework.grad_var_name(out.name))

res = self.exe.run(prog, fetch_list=[out, out_grad, y_grad, x_grad])
self.assertEqual(res[0].shape, shape[3])
self.assertEqual(res[1].shape, shape[3])
self.assertEqual(res[2].shape, shape[1])
self.assertEqual(res[3].shape, shape[0])

@prog_scope()
def test_repeat_interleave(self):
x = paddle.full([], 1.0, 'float32')