diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 3ee1c0c97b1a3..4746adc892e89 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3249,7 +3249,12 @@ void ShardIndexInferMeta(const MetaTensor& in, void SizeInferMeta(const MetaTensor& input, MetaTensor* out) { out->set_dtype(DataType::INT64); - out->set_dims({1}); + if (input.dims().size() == 0) { + out->set_dims(phi::make_ddim({})); + } else { + // TODO(zhouwei): will change shape [1] to [] to support zero-dim + out->set_dims(phi::make_ddim({1})); + } } void SliceRawInferMeta(const MetaTensor& input, diff --git a/paddle/phi/kernels/impl/size_kernel_impl.h b/paddle/phi/kernels/impl/size_kernel_impl.h index f9757bc447756..4c72f02f64349 100644 --- a/paddle/phi/kernels/impl/size_kernel_impl.h +++ b/paddle/phi/kernels/impl/size_kernel_impl.h @@ -24,8 +24,8 @@ void SizeKernel(const Context& ctx, DenseTensor* out) { auto place = ctx.GetPlace(); auto out_data = ctx.template Alloc(out); - auto cpu_place = phi::CPUPlace(); - if (place == cpu_place) { + + if (place == phi::CPUPlace()) { out_data[0] = input.numel(); } else { DenseTensor cpu_tensor; diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index c32056555cce2..6a864efc42eed 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -96,7 +96,7 @@ def create_tensor(value, dtype, shape): return out def create_scalar(value, dtype): - return create_tensor(value, dtype, shape=[1]) + return create_tensor(value, dtype, shape=[]) def astype(self, dtype): """ diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index fc71321f14732..fb3979434347f 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -99,7 +99,7 @@ def create_tensor(block, value, dtype, shape): return var def create_scalar(block, value, dtype): - return create_tensor(block, value, dtype, shape=[1]) + return create_tensor(block, value, dtype, shape=[]) def create_tensor_with_batchsize(ref_var, value, dtype): assert isinstance(ref_var, Variable) @@ -417,7 +417,7 @@ def __impl__(self, other_var): out = create_new_tmp_var(current_block(self), dtype=lhs_dtype) axis = -1 - if other_var.shape[0] == -1: + if other_var.ndim > 0 and other_var.shape[0] == -1: stack = inspect.stack()[1] file_name = stack[1] line_num = stack[2] diff --git a/python/paddle/fluid/tests/unittests/test_numel_op.py b/python/paddle/fluid/tests/unittests/test_numel_op.py index 4bb359a7bd2e0..fbc43bf4b8469 100644 --- a/python/paddle/fluid/tests/unittests/test_numel_op.py +++ b/python/paddle/fluid/tests/unittests/test_numel_op.py @@ -27,6 +27,7 @@ def setUp(self): self.inputs = { 'Input': x, } + # TODO(zhouwei): will change shape [1] to [] to support zero-dim self.outputs = {'Out': np.array([np.size(x)])} def test_check_output(self): @@ -67,6 +68,7 @@ def test_numel_static(self): }, fetch_list=[out_1, out_2], ) + # TODO(zhouwei): will change shape [1] to [] to support zero-dim assert np.array_equal( res_1, np.array([np.size(input_1)]).astype("int64") ) diff --git a/python/paddle/fluid/tests/unittests/test_size_op.py b/python/paddle/fluid/tests/unittests/test_size_op.py index 6f9898ade4fb0..87fcfdf5a9646 100644 --- a/python/paddle/fluid/tests/unittests/test_size_op.py +++ b/python/paddle/fluid/tests/unittests/test_size_op.py @@ -76,6 +76,7 @@ def test_size_static(self): }, fetch_list=[out_1, out_2], ) + # TODO(zhouwei): will change shape [1] to [] to support zero-dim assert np.array_equal( res_1, np.array([np.size(input_1)]).astype("int64") ) diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py similarity index 66% rename from python/paddle/fluid/tests/unittests/test_zero_dim_shape.py rename to python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 2f3cf345f2687..26f24f5f7e952 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_shape.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -14,6 +14,7 @@ import paddle import paddle.fluid as fluid +import paddle.nn.functional as F import numpy as np import unittest @@ -67,7 +68,7 @@ ] -# Use to test zero-dim in the whole API +# Use to test zero-dim in unary API. class TestUnaryAPI(unittest.TestCase): def test_dygraph_unary(self): paddle.disable_static() @@ -176,6 +177,7 @@ def test_static_unary(self): ] +# Use to test zero-dim of reduce API class TestReduceAPI(unittest.TestCase): def test_dygraph(self): paddle.disable_static() @@ -231,32 +233,32 @@ def test_static(self): {'func': paddle.subtract, 'cls_method': '__sub__'}, {'func': paddle.multiply, 'cls_method': '__mul__'}, {'func': paddle.divide, 'cls_method': '__div__'}, - {'func': paddle.subtract, 'cls_method': '__sub__'}, - paddle.pow, + {'func': paddle.pow, 'cls_method': '__pow__'}, ] binary_api_list_without_grad = [ - {'func': paddle.add, 'cls_method': '__add__'}, - {'func': paddle.subtract, 'cls_method': '__sub__'}, - {'func': paddle.multiply, 'cls_method': '__mul__'}, - {'func': paddle.divide, 'cls_method': '__div__'}, - {'func': paddle.subtract, 'cls_method': '__sub__'}, - paddle.pow, - {'func': paddle.mod, 'cls_method': '__mod__'}, - paddle.floor_mod, - paddle.remainder, {'func': paddle.equal, 'cls_method': '__eq__'}, {'func': paddle.not_equal, 'cls_method': '__ne__'}, {'func': paddle.greater_equal, 'cls_method': '__ge__'}, {'func': paddle.greater_than, 'cls_method': '__gt__'}, {'func': paddle.less_equal, 'cls_method': '__le__'}, {'func': paddle.less_than, 'cls_method': '__lt__'}, + {'func': paddle.remainder, 'cls_method': '__mod__'}, + paddle.mod, + paddle.floor_mod, paddle.logical_and, paddle.logical_or, paddle.logical_xor, ] +binary_int_api_list_without_grad = [ + paddle.bitwise_and, + paddle.bitwise_or, + paddle.bitwise_xor, +] + +# Use to test zero-dim of binary API class TestBinaryAPI(unittest.TestCase): def test_dygraph_binary(self): paddle.disable_static() @@ -274,10 +276,7 @@ def test_dygraph_binary(self): else: out = api(x, y) - self.assertEqual(x.shape, []) - self.assertEqual(y.shape, []) self.assertEqual(out.shape, []) - if api not in binary_api_list_without_grad: out.backward() self.assertEqual(x.grad.shape, []) @@ -296,10 +295,7 @@ def test_dygraph_binary(self): else: out = api(x, y) - self.assertEqual(x.shape, [2, 3, 4]) - self.assertEqual(y.shape, []) self.assertEqual(out.shape, [2, 3, 4]) - if api not in binary_api_list_without_grad: out.backward() self.assertEqual(x.grad.shape, [2, 3, 4]) @@ -317,54 +313,190 @@ def test_dygraph_binary(self): np.testing.assert_array_equal(out_cls.numpy(), out.numpy()) else: out = api(x, y) - out.backward() - self.assertEqual(x.shape, []) - self.assertEqual(y.shape, [2, 3, 4]) self.assertEqual(out.shape, [2, 3, 4]) - if api not in binary_api_list_without_grad: out.backward() self.assertEqual(x.grad.shape, []) self.assertEqual(y.grad.shape, [2, 3, 4]) self.assertEqual(out.grad.shape, [2, 3, 4]) + # 4) x is 0D , y is scalar + x = paddle.rand([]) + y = 0.5 + x.stop_gradient = False + if isinstance(api, dict): + out = getattr(paddle.Tensor, api['cls_method'])(x, y) + self.assertEqual(out.shape, []) + + for api in binary_int_api_list_without_grad: + # 1) x/y is 0D + x = paddle.randint(-10, 10, []) + y = paddle.randint(-10, 10, []) + out = api(x, y) + self.assertEqual(out.shape, []) + + # 2) x is not 0D , y is 0D + x = paddle.randint(-10, 10, [3, 5]) + y = paddle.randint(-10, 10, []) + out = api(x, y) + self.assertEqual(out.shape, [3, 5]) + + # 3) x is 0D , y is not 0D + x = paddle.randint(-10, 10, []) + y = paddle.randint(-10, 10, [3, 5]) + out = api(x, y) + self.assertEqual(out.shape, [3, 5]) + paddle.enable_static() def test_static_unary(self): paddle.enable_static() - for api in binary_api_list: + for api in binary_api_list + binary_api_list_without_grad: main_prog = fluid.Program() with fluid.program_guard(main_prog, fluid.Program()): + # 1) x/y is 0D x = paddle.rand([]) y = paddle.rand([]) x.stop_gradient = False y.stop_gradient = False if isinstance(api, dict): out = api['func'](x, y) + out_cls = getattr( + paddle.static.Variable, api['cls_method'] + )(x, y) + self.assertEqual(out.shape, out_cls.shape) else: out = api(x, y) fluid.backward.append_backward(out) - # append_backward always set grad shape to [1] - prog = paddle.static.default_main_program() - block = prog.global_block() - # Test compile shape - self.assertEqual(x.shape, ()) - self.assertEqual(y.shape, ()) self.assertEqual(out.shape, ()) - exe = fluid.Executor() - result = exe.run(main_prog, fetch_list=[x, y, out]) - + out_np = exe.run(main_prog, fetch_list=[out])[0] # Test runtime shape - self.assertEqual(result[0].shape, ()) - self.assertEqual(result[1].shape, ()) - self.assertEqual(result[2].shape, ()) + self.assertEqual(out_np.shape, ()) + + # 2) x is 0D , y is scalar + x = paddle.rand([]) + y = 0.5 + x.stop_gradient = False + if isinstance(api, dict): + out = getattr(paddle.static.Variable, api['cls_method'])( + x, y + ) + self.assertEqual(out.shape, ()) + + for api in binary_int_api_list_without_grad: + main_prog = fluid.Program() + with fluid.program_guard(main_prog, fluid.Program()): + # 1) x/y is 0D + x = paddle.randint(-10, 10, []) + y = paddle.randint(-10, 10, []) + out = api(x, y) + self.assertEqual(out.shape, ()) + + # 2) x is not 0D , y is 0D + x = paddle.randint(-10, 10, [3, 5]) + y = paddle.randint(-10, 10, []) + out = api(x, y) + self.assertEqual(out.shape, (3, 5)) + + # 3) x is 0D , y is not 0D + x = paddle.randint(-10, 10, []) + y = paddle.randint(-10, 10, [3, 5]) + out = api(x, y) + self.assertEqual(out.shape, (3, 5)) paddle.disable_static() +# Use to test zero-dim of Sundry API, which is simple and do +# not have backward, or is not need to test backward in OpTest. +class TestSundryAPI(unittest.TestCase): + def setUp(self): + self.x = paddle.rand([]) + + def test_linear(self): + x = paddle.randn([3, 2]) + w = paddle.full(shape=[2, 4], fill_value=0.5) + b = paddle.zeros([]) + + np.testing.assert_array_equal( + F.linear(x, w, b).numpy(), F.linear(x, w).numpy() + ) + + def test_is_complex(self): + x = paddle.rand([]) + 1j * paddle.rand([]) + self.assertTrue(paddle.is_complex(x)) + + def test_is_floating_point(self): + self.assertTrue(paddle.is_floating_point(self.x)) + + def test_is_integer(self): + x = paddle.randint(0, 10, []) + self.assertTrue(paddle.is_integer(x)) + + def test_is_tensor(self): + self.assertTrue(paddle.is_tensor(self.x)) + + def test_is_empty(self): + x = paddle.rand([3, 0, 5]) + self.assertTrue(paddle.is_empty(x)) + + def test_isfinite(self): + out = paddle.isfinite(self.x) + np.testing.assert_array_equal(out.numpy(), np.array(True)) + + def test_isinf(self): + x = paddle.to_tensor(np.array(float('-inf'))) + out = paddle.isinf(x) + np.testing.assert_array_equal(out.numpy(), np.array(True)) + + def test_isnan(self): + x = paddle.to_tensor(np.array(float('nan'))) + out = paddle.isnan(x) + np.testing.assert_array_equal(out.numpy(), np.array(True)) + + def test_isclose(self): + out = paddle.isclose(self.x, self.x) + np.testing.assert_array_equal(out.numpy(), np.array(True)) + + def test_clone(self): + out = paddle.clone(self.x) + np.testing.assert_array_equal(out.numpy(), self.x.numpy()) + + def test_assign(self): + out = paddle.assign(self.x) + np.testing.assert_array_equal(out.numpy(), self.x.numpy()) + + def test_item(self): + x = paddle.full([], 0.5) + self.assertEqual(x.item(), 0.5) + + def test_tolist(self): + x = paddle.full([], 0.5) + self.assertEqual(x.tolist(), 0.5) + + def test_numpy(self): + x = paddle.full([], 0.5) + np.testing.assert_array_equal(x.numpy(), np.array(0.5)) + + def test_numel(self): + out = paddle.numel(self.x) + self.assertEqual(out.shape, []) + np.testing.assert_array_equal(out.numpy(), np.array(1)) + + def test_rank(self): + out = paddle.rank(self.x) + self.assertEqual(out.shape, []) + np.testing.assert_array_equal(out.numpy(), np.array(0)) + + def test_shape(self): + out = paddle.shape(self.x) + self.assertEqual(out.shape, [0]) + np.testing.assert_array_equal(out.numpy(), np.array([])) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index 0532ade86c65f..1efa41850c27d 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -605,6 +605,7 @@ def batch_norm_orig2prim( @REGISTER_ORIG2PRIM('size') def size_orig2prim(op, x): + # TODO(zhouwei): will change shape [1] to [] to support zero-dim return fill_const( functools.reduce(operator.mul, x.shape), (1,), paddle.int64 ) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index ad061673ab9f4..f77c5a5a96271 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -159,7 +159,7 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None): ) n = n.astype(dtype) if unbiased: - one_const = paddle.ones([1], x.dtype) + one_const = paddle.ones([], x.dtype) n = where(n > one_const, n - 1.0, one_const) out /= n return out