[Zero-Dim] Support 0D for paddle.diagflat (#48735)
* [Zero-Dim] Support 0D for paddle.diagflat
Courtesy-Xs authored Dec 7, 2022
1 parent 6542027 commit 1a3d259
Showing 8 changed files with 94 additions and 16 deletions.
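In user-facing terms, paddle.diagflat now accepts 0-D (scalar) tensors and treats them as length-1 vectors. A minimal sketch of the resulting behavior, assuming paddle.full([], 3.0) builds a 0-D tensor the same way the tests' paddle.rand([]) does:

    import paddle

    # A 0-D tensor: note the empty shape [].
    x = paddle.full([], 3.0)

    # After this commit, a 0-D input behaves like a length-1 vector, so the
    # output is a (1 + |offset|) x (1 + |offset|) square matrix.
    print(paddle.diagflat(x))      # [[3.]]               -> shape [1, 1]
    print(paddle.diagflat(x, 1))   # [[0., 3.], [0., 0.]] -> shape [2, 2]
    print(paddle.diagflat(x, -1))  # [[0., 0.], [3., 0.]] -> shape [2, 2]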
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/unary.cc
@@ -563,8 +563,8 @@ void DiagInferMeta(const MetaTensor& x,
                    MetaTensor* out) {
   auto x_dims = x.dims();
 
-  if (x_dims.size() == 1UL) {
-    int64_t size_ = x_dims[0] + std::abs(offset);
+  if (x_dims.size() <= 1) {
+    int64_t size_ = (x_dims.size() == 1UL ? x_dims[0] : 1) + std::abs(offset);
     out->set_dims({size_, size_});
     out->set_dtype(x.dtype());
   } else if (x_dims.size() == 2UL) {
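To read the hunk above: the infer-meta rule now treats a 0-D input as contributing a diagonal of length 1. A minimal Python sketch of that shape rule (illustration only, not Paddle code; the helper name is mine):

    def infer_diag_shape(x_shape, offset):
        # 0-D contributes a diagonal of length 1; 1-D of length x_shape[0].
        assert len(x_shape) <= 1
        n = x_shape[0] if len(x_shape) == 1 else 1
        size = n + abs(offset)
        return [size, size]

    assert infer_diag_shape([], 0) == [1, 1]    # 0-D, main diagonal
    assert infer_diag_shape([], 1) == [2, 2]    # 0-D, first superdiagonal
    assert infer_diag_shape([5], -2) == [7, 7]  # 1-D path is unchanged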
6 changes: 3 additions & 3 deletions paddle/phi/kernels/cpu/diag_grad_kernel.cc
@@ -32,9 +32,9 @@ void DiagGradKernel(const Context& dev_ctx,
   auto dx_dims = x_grad->dims();
   auto dout_dims = out_grad.dims();
 
-  if (dx_dims.size() == 1) {
-    auto dx_length = dx_dims[0];
-    int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
+  if (dx_dims.size() <= 1) {
+    auto dx_length = (dx_dims.size() == 1 ? dx_dims[0] : int64_t(1));
+    int dx_stride = 1;
 
     auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
     auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
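The grad-kernel change says: with at most one input element, the length defaults to 1 and the stride is trivially 1. Conceptually, the backward pass reads back the single element the forward pass wrote on the offset diagonal. A NumPy sketch of that rule (my reconstruction, not the kernel itself):

    import numpy as np

    def diag_grad_0d(dout, offset):
        # Forward placed the scalar at out[0, offset] (offset >= 0) or
        # out[-offset, 0] (offset < 0); the gradient is that one element.
        row, col = (0, offset) if offset >= 0 else (-offset, 0)
        return np.asarray(dout[row, col])  # 0-D, matching the input's shape

    assert diag_grad_0d(np.ones((2, 2)), 1).shape == ()  # gradient stays 0-D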
6 changes: 3 additions & 3 deletions paddle/phi/kernels/cpu/diag_kernel.cc
@@ -33,12 +33,12 @@ void DiagKernel(const Context& dev_ctx,
   auto out_dims = out->dims();
 
   int64_t i;
-  if (x_dims.size() == 1) {
+  if (x_dims.size() <= 1) {
     phi::funcs::SetConstant<Context, T> set_padding_value;
     set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
 
-    auto x_length = x_dims[0];
-    const int& x_stride = phi::funcs::ComputeStride(0, x_dims);
+    auto x_length = (x_dims.size() == 1UL ? x_dims[0] : int64_t(1));
+    const int& x_stride = 1;
 
     auto out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
     auto out_stride_1 = phi::funcs::ComputeStride(1, out_dims);
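A NumPy reference for what the forward kernel computes on the new 0-D path: fill the output with padding_value, then place the scalar on the offset diagonal (an illustrative sketch under the same assumptions as above, not the kernel):

    import numpy as np

    def diag_forward_0d(x_scalar, offset, padding_value=0.0):
        n = 1 + abs(offset)                   # x_length is 1 for a 0-D input
        out = np.full((n, n), padding_value)  # mirrors SetConstant(padding_value)
        row, col = (0, offset) if offset >= 0 else (-offset, 0)
        out[row, col] = x_scalar              # one element, stride 1
        return out

    print(diag_forward_0d(3.0, 1))   # [[0., 3.], [0., 0.]]
    print(diag_forward_0d(3.0, -1))  # [[0., 0.], [3., 0.]]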
6 changes: 3 additions & 3 deletions paddle/phi/kernels/gpu/diag_grad_kernel.cu
@@ -73,10 +73,10 @@ void DiagGradKernel(const Context& dev_ctx,
     return std::tuple<int64_t, int64_t>{block_size, grid_size};
   };
 
-  if (dx_dims.size() == 1) {
-    auto dx_length = dx_dims[0];
+  if (dx_dims.size() <= 1) {
+    auto dx_length = (dx_dims.size() == 1 ? dx_dims[0] : int64_t(1));
     auto size = (offset > 0) ? dx_length + offset : dx_length - offset;
-    int dx_stride = phi::funcs::ComputeStride(0, dx_dims);
+    int dx_stride = 1;
     if (size > 0) {
       auto dout_stride_0 = phi::funcs::ComputeStride(0, dout_dims);
       auto dout_stride_1 = phi::funcs::ComputeStride(1, dout_dims);
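The size arithmetic in this hunk works out the same on both branches for a 0-D input, since dx_length is 1; a one-line check:

    def diag_out_size(dx_length, offset):
        # Mirrors: size = (offset > 0) ? dx_length + offset : dx_length - offset
        return dx_length + offset if offset > 0 else dx_length - offset

    assert diag_out_size(1, 1) == 2   # 0-D input, offset  1 -> 2x2 output
    assert diag_out_size(1, -1) == 2  # 0-D input, offset -1 -> 2x2 output
    assert diag_out_size(1, 0) == 1   # 0-D input, offset  0 -> 1x1 output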
6 changes: 3 additions & 3 deletions paddle/phi/kernels/gpu/diag_kernel.cu
@@ -77,13 +77,13 @@ void DiagKernel(const Context& dev_ctx,
     return std::tuple<int64_t, int64_t>{block_size, grid_size};
   };
 
-  if (x_dims.size() == 1) {
+  if (x_dims.size() <= 1) {
     phi::funcs::SetConstant<Context, T> set_padding_value;
     set_padding_value(dev_ctx, out, static_cast<T>(padding_value));
 
-    auto x_length = x_dims[0];
+    auto x_length = (x_dims.size() == 1UL ? x_dims[0] : int64_t(1));
     auto size = (offset > 0) ? x_length + offset : x_length - offset;
-    const int& x_stride = phi::funcs::ComputeStride(0, x_dims);
+    const int& x_stride = 1;
     if (size > 0) {
       const auto& out_stride_0 = phi::funcs::ComputeStride(0, out_dims);
       const auto& out_stride_1 = phi::funcs::ComputeStride(1, out_dims);
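Both the CPU and GPU kernels address the diagonal through the out_stride_0/out_stride_1 values computed above. A sketch of the flat-index arithmetic for the one element a 0-D input contributes (my reconstruction of the indexing, assuming row-major strides for a 2-D output):

    def diagonal_flat_index(out_shape, offset, i=0):
        stride_0, stride_1 = out_shape[1], 1  # row-major 2-D strides
        row, col = (i, i + offset) if offset >= 0 else (i - offset, i)
        return row * stride_0 + col * stride_1

    assert diagonal_flat_index((2, 2), 1) == 1   # element lands at out[0, 1]
    assert diagonal_flat_index((2, 2), -1) == 2  # element lands at out[1, 0]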
49 changes: 49 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -653,6 +653,35 @@ def test_scatter_XD(self):
             self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
         self.assertEqual(out.grad.shape, [2, 3])
 
+    def test_diagflat(self):
+        x1 = paddle.rand([])
+        x2 = paddle.rand([])
+        x3 = paddle.rand([])
+
+        x1.stop_gradient = False
+        x2.stop_gradient = False
+        x3.stop_gradient = False
+
+        out1 = paddle.diagflat(x1, 1)
+        out2 = paddle.diagflat(x2, -1)
+        out3 = paddle.diagflat(x3, 0)
+
+        out1.backward()
+        out2.backward()
+        out3.backward()
+
+        self.assertEqual(out1.shape, [2, 2])
+        self.assertEqual(out2.shape, [2, 2])
+        self.assertEqual(out3.shape, [1, 1])
+
+        self.assertEqual(out1.grad.shape, [2, 2])
+        self.assertEqual(out2.grad.shape, [2, 2])
+        self.assertEqual(out3.grad.shape, [1, 1])
+
+        self.assertEqual(x1.grad.shape, [])
+        self.assertEqual(x2.grad.shape, [])
+        self.assertEqual(x3.grad.shape, [])
+
 
 class TestSundryAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -796,6 +825,26 @@ def test_scatter_XD(self):
         for i in range(3):
             self.assertEqual(res[0][1][i], 4)
 
+    @prog_scope()
+    def test_diagflat(self):
+        x1 = paddle.rand([])
+        out1 = paddle.diagflat(x1, 1)
+        paddle.static.append_backward(out1)
+
+        x2 = paddle.rand([])
+        out2 = paddle.diagflat(x2, -1)
+        paddle.static.append_backward(out2)
+
+        x3 = paddle.rand([])
+        out3 = paddle.diagflat(x3)
+        paddle.static.append_backward(out3)
+
+        prog = paddle.static.default_main_program()
+        res1, res2, res3 = self.exe.run(prog, fetch_list=[out1, out2, out3])
+        self.assertEqual(res1.shape, (2, 2))
+        self.assertEqual(res2.shape, (2, 2))
+        self.assertEqual(res3.shape, (1, 1))
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
29 changes: 29 additions & 0 deletions python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
@@ -475,6 +475,35 @@ def test_scatter_XD(self):
         for i in range(3):
             self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
 
+    def test_diagflat(self):
+        x1 = paddle.rand([])
+        x2 = paddle.rand([])
+        x3 = paddle.rand([])
+
+        x1.stop_gradient = False
+        x2.stop_gradient = False
+        x3.stop_gradient = False
+
+        out1 = paddle.diagflat(x1, 1)
+        out2 = paddle.diagflat(x2, -1)
+        out3 = paddle.diagflat(x3, 0)
+
+        out1.backward()
+        out2.backward()
+        out3.backward()
+
+        self.assertEqual(out1.shape, [2, 2])
+        self.assertEqual(out2.shape, [2, 2])
+        self.assertEqual(out3.shape, [1, 1])
+
+        self.assertEqual(out1.grad.shape, [2, 2])
+        self.assertEqual(out2.grad.shape, [2, 2])
+        self.assertEqual(out3.grad.shape, [1, 1])
+
+        self.assertEqual(x1.grad.shape, [])
+        self.assertEqual(x2.grad.shape, [])
+        self.assertEqual(x3.grad.shape, [])
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
4 changes: 2 additions & 2 deletions python/paddle/tensor/creation.py
@@ -1479,7 +1479,7 @@ def diagflat(x, offset=0, name=None):
     """
     padding_value = 0
     if in_dygraph_mode():
-        if len(x.shape) == 1:
+        if len(x.shape) <= 1:
             return _C_ops.diag(x, offset, padding_value)
         else:
             y = _C_ops.flatten(x, 0, -1)
@@ -1509,7 +1509,7 @@ def diagflat(x, offset=0, name=None):
     out1_shape = helper.create_variable_for_type_inference(x.dtype)
     out2 = helper.create_variable_for_type_inference(dtype=x.dtype)
 
-    if len(x.shape) == 1:
+    if len(x.shape) <= 1:
         helper.append_op(
             type='diag_v2',
             inputs={'X': x},
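The Python-side change simply routes 0-D inputs down the same diag_v2 path as 1-D inputs, rather than flattening first. A quick usage check of the two routes (output shapes follow the semantics above):

    import paddle

    x0 = paddle.rand([])      # 0-D: now routed directly to the diag path
    x2 = paddle.rand([2, 3])  # 2-D and up: flattened to 1-D, then diag

    print(paddle.diagflat(x0).shape)     # [1, 1]
    print(paddle.diagflat(x2).shape)     # [6, 6]
    print(paddle.diagflat(x2, 1).shape)  # [7, 7]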
