[Zero-Dim] support input 0D Tensor for reduce_sum/reduce_mean (#47219)
zhwesky2010 authored Oct 31, 2022
1 parent 81b93eb commit c8fc337
Showing 12 changed files with 160 additions and 116 deletions.
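As a quick illustration of what this change enables -- a minimal sketch based on the unit tests added below, assuming a post-commit build of Paddle:

import paddle

x = paddle.rand([])          # a 0D tensor: shape [], holding a single value
x.stop_gradient = False

out = paddle.sum(x)          # reducing a 0D tensor now works
out.backward()

print(out.shape)             # [] -- the result stays 0D
print(x.grad.shape)          # [] -- and so does its gradient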
9 changes: 5 additions & 4 deletions paddle/phi/infermeta/unary.cc
@@ -2685,7 +2685,7 @@ DDim ReduceInferDim(const MetaTensor& x,

bool full_dim = true;
std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
for (int64_t i = 0; i < x.dims().size(); ++i) {
for (int64_t i = 0; i < x_rank; ++i) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
@@ -2695,23 +2695,23 @@ DDim ReduceInferDim(const MetaTensor& x,

std::vector<int64_t> out_dim_vector;
if (keep_dim) {
for (int64_t i = 0; i < x.dims().size(); ++i) {
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
out_dim_vector.push_back(1);
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}
} else {
for (int64_t i = 0; i < x.dims().size(); ++i) {
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
continue;
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}

if (out_dim_vector.size() == 0) {
if (x_rank > 0 && out_dim_vector.size() == 0) {
out_dim_vector.push_back(1);
}
}
@@ -3013,6 +3013,7 @@ void SetValueInferMeta(const MetaTensor& x, MetaTensor* out) {
phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.",
in_dims.size()));
out->set_dims(in_dims);
}

void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) {
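The x_rank > 0 guard added above is what keeps a 0D input 0D: previously a fully reduced output was always padded to shape [1]. A rough Python sketch of the resulting shape rule for keep_dim=False (helper name and simplifications are mine, not Paddle code):

def reduce_out_shape(x_shape, axes, keep_dim=False):
    # Mirrors the ReduceInferDim logic shown above (simplified).
    axes = set(axes) if axes else set(range(len(x_shape)))
    if keep_dim:
        return [1 if i in axes else d for i, d in enumerate(x_shape)]
    out = [d for i, d in enumerate(x_shape) if i not in axes]
    if len(x_shape) > 0 and not out:
        out = [1]   # rank > 0, fully reduced: the old [1] convention is kept
    return out      # rank 0 input: stays []

print(reduce_out_shape([2, 3], [0, 1]))   # [1]
print(reduce_out_shape([], []))           # []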
8 changes: 4 additions & 4 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -44,7 +44,7 @@ struct DimensionsTransform {
int64_t in_idx = 0;
if (in_dim.size() < dim_size) {
DimVector tmp_dim(dim_size, 1);
do {
for (; in_idx < in_dim.size();) {
if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
tmp_dim[axis] = in_dim[in_idx];
in_idx++;
@@ -59,11 +59,11 @@
out_dims[axis],
in_dim[in_idx]));
}
} while (in_idx < in_dim.size());
}
in_dim.resize(dim_size);
std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
} else {
do {
for (; in_idx < dim_size;) {
if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
in_idx++;
} else {
@@ -76,7 +76,7 @@
out_dims[in_idx],
in_dim[in_idx]));
}
} while (in_idx < dim_size);
}
}
std::reverse(in_dim.begin(), in_dim.end());
}
8 changes: 8 additions & 0 deletions paddle/phi/kernels/funcs/reduce_function.h
@@ -1063,6 +1063,14 @@ void ReduceKernel(const KPDevice& dev_ctx,
dev_ctx.Alloc<Ty>(y);

auto x_dim = phi::vectorize<int>(x.dims());

if (x_dim.size() == 0) {
std::vector<const DenseTensor*> inputs = {&x};
std::vector<DenseTensor*> outputs = {y};
funcs::ElementwiseKernel<Ty>(dev_ctx, inputs, &outputs, transform);
return;
}

auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
config.Run(dev_ctx);
int numel = x.numel();
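The early return added above handles rank-0 inputs by launching an elementwise kernel instead of the reduction machinery: reducing a scalar over all (zero) of its axes just applies the transform to its single element. The NumPy analogue of that identity (illustration only, not Paddle code):

import numpy as np

x = np.array(3.5)            # rank-0 array, shape ()
assert x.ndim == 0
# Reducing a 0-d array over all of its (zero) axes leaves the value as is,
# so no real reduction pass is needed.
assert np.sum(x) == x
assert np.mean(x) == x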
31 changes: 21 additions & 10 deletions paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu
@@ -16,8 +16,8 @@

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/reduce_grad.h"

namespace phi {

@@ -29,23 +29,34 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = x.dims().size();
if (dims.size() == 0) {
reduce_all = true;
}
std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);

auto update_dims = vectorize(x.dims());
int reduce_num = 1;
for (auto i : reduce_dims) {
reduce_num *= (x.dims())[i];
update_dims[i] = 1;
}

// make new tensor
DenseTensor new_out_grad(out_grad.dtype());
new_out_grad.ShareDataWith(out_grad);
new_out_grad.Resize(phi::make_ddim(update_dims));

// call BroadcastKernel
dev_ctx.Alloc(x_grad, x.dtype());
std::vector<const DenseTensor*> inputs = {&new_out_grad};
std::vector<DenseTensor*> outputs = {x_grad};

using MPType = typename kps::details::MPTypeTrait<T>::Type;
ReduceGradKernel<T, T, Context, kps::DivideFunctor<T, MPType>>(
dev_ctx,
x,
out_grad,
dims.GetData(),
keep_dim,
reduce_all,
x_grad,
kps::DivideFunctor<T, MPType>(reduce_num));
funcs::BroadcastKernel<phi::ElementwiseType::kUnary, T, T>(
dev_ctx, inputs, &outputs, 0, kps::DivideFunctor<T, MPType>(reduce_num));
}

} // namespace phi
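The rewritten kernel above no longer goes through ReduceGradKernel; it reshapes out_grad so that every reduced axis has extent 1, then broadcasts it back over x's shape while dividing by the number of reduced elements. A small NumPy sketch of the same math (function name and shapes are mine, for illustration only):

import numpy as np

def mean_grad(x_shape, out_grad, reduce_dims):
    # Number of elements averaged over, and a broadcast-compatible shape
    # with the reduced axes set to 1 (the "update_dims" above).
    reduce_num = int(np.prod([x_shape[i] for i in reduce_dims]))
    update_dims = list(x_shape)
    for i in reduce_dims:
        update_dims[i] = 1
    # Broadcast the upstream gradient to x's shape and divide by reduce_num.
    return np.broadcast_to(out_grad.reshape(update_dims), x_shape) / reduce_num

x_grad = mean_grad((2, 3), np.ones(()), [0, 1])
print(x_grad)   # every entry is 1/6, i.e. d mean(x) / d x_ij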
36 changes: 13 additions & 23 deletions paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu
@@ -29,42 +29,32 @@ void ReduceSumGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
using MPType = typename kps::details::MPTypeTrait<T>::Type;
auto out_dtype = x.dtype();
auto* in_x = &x;
auto* d_out = &out_grad;
auto* d_x = x_grad;

// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = in_x->dims().size();
// get reduce_dim for reduce_mean_grad
int dim_size = x.dims().size();
if (dims.size() == 0) {
reduce_all = true;
}
std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all);

auto update_dims = vectorize(d_x->dims());
int reduce_num = 1;
auto update_dims = vectorize(x.dims());
for (auto i : reduce_dims) {
reduce_num *= (in_x->dims())[i];
update_dims[i] = 1;
}

// make new tensor
DenseTensor new_d_out(d_out->dtype());
new_d_out.ShareDataWith(*d_out);
new_d_out.Resize(phi::make_ddim(update_dims));
DenseTensor new_out_grad(out_grad.dtype());
new_out_grad.ShareDataWith(out_grad);
new_out_grad.Resize(phi::make_ddim(update_dims));

dev_ctx.Alloc(d_x, x.dtype());
auto pt_out_dtype = x.dtype();
auto pt_d_out = new_d_out;
auto pt_d_x = *d_x;
std::vector<const DenseTensor*> inputs = {&pt_d_out};
std::vector<DenseTensor*> outputs = {&pt_d_x};
// call ReduceGrad
dev_ctx.Alloc(x_grad, x.dtype());
using MPType = typename kps::details::MPTypeTrait<T>::Type;
phi::ReduceGrad<T, kps::IdentityFunctor<T, MPType>>(
dev_ctx,
&pt_d_out,
&pt_d_x,
pt_out_dtype,
&new_out_grad,
x_grad,
x.dtype(),
kps::IdentityFunctor<T, MPType>());
}

3 changes: 3 additions & 0 deletions paddle/phi/kernels/reduce_mean_kernel.cc
@@ -26,6 +26,9 @@ void MeanKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
if (dims.size() == 0) {
reduce_all = true;
}
MeanRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}

3 changes: 0 additions & 3 deletions python/paddle/fluid/layers/nn.py
@@ -5096,9 +5096,6 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_sum(y, dim=[0, 1]) # [16, 20]

"""
if dim is not None and not isinstance(dim, list):
dim = [dim]

reduce_all, dim = _get_reduce_dim(dim, input)

if in_dygraph_mode():
15 changes: 15 additions & 0 deletions python/paddle/fluid/tests/unittests/test_mean_op.py
@@ -58,6 +58,21 @@ def test_checkout_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)


class TestMeanOp_ZeroDim(OpTest):
def setUp(self):
self.op_type = "mean"
self.python_api = paddle.mean
self.dtype = np.float64
self.inputs = {'X': np.random.random([]).astype(self.dtype)}
self.outputs = {'Out': np.mean(self.inputs["X"])}

def test_check_output(self):
self.check_output(check_eager=True)

def test_checkout_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)


class TestMeanOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
15 changes: 15 additions & 0 deletions python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -37,6 +37,21 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)


class TestSumOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.sum
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random([]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=None)}
self.attrs = {'dim': [], 'reduce_all': True}

def test_check_output(self):
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)


class TestSumOp_fp16(OpTest):
def setUp(self):
self.python_api = paddle.sum
51 changes: 51 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_shape.py
@@ -17,6 +17,7 @@
import numpy as np
import unittest


unary_api_list = [
paddle.nn.functional.elu,
paddle.nn.functional.gelu,
@@ -159,5 +160,55 @@ def test_static_unary(self):
paddle.disable_static()


reduce_api_list = [
paddle.sum,
paddle.mean,
paddle.nansum,
paddle.nanmean,
]


class TestReduceAPI(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
for api in reduce_api_list:
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
out.backward()

self.assertEqual(x.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])

paddle.enable_static()

def test_static(self):
paddle.enable_static()
for api in reduce_api_list:
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
x = paddle.rand([])

x.stop_gradient = False
out = api(x, None)
fluid.backward.append_backward(out)

# Test compile shape, grad is always [1]
self.assertEqual(x.shape, ())
self.assertEqual(out.shape, ())

exe = fluid.Executor()
result = exe.run(main_prog, fetch_list=[x, out])

# Test runtime shape
self.assertEqual(result[0].shape, ())
self.assertEqual(result[1].shape, ())

paddle.disable_static()


if __name__ == "__main__":
unittest.main()