Skip to content

Commit

Permalink
[Zero-Dim] support input 0D Tensor for sundary api
Browse files Browse the repository at this point in the history
  • Loading branch information
zhwesky2010 committed Nov 7, 2022
1 parent 40cd527 commit b1c1c76
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 28 deletions.
7 changes: 6 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3249,7 +3249,12 @@ void ShardIndexInferMeta(const MetaTensor& in,

void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
out->set_dtype(DataType::INT64);
out->set_dims({1});
if (input.dims().size() == 0) {
out->set_dims(phi::make_ddim({}));
} else {
// TODO(zhouwei): will change shape [1] to [] to support zero-dim
out->set_dims(phi::make_ddim({1}));
}
}

void SliceRawInferMeta(const MetaTensor& input,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/impl/size_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ void SizeKernel(const Context& ctx,
DenseTensor* out) {
auto place = ctx.GetPlace();
auto out_data = ctx.template Alloc<int64_t>(out);
auto cpu_place = phi::CPUPlace();
if (place == cpu_place) {

if (place == phi::CPUPlace()) {
out_data[0] = input.numel();
} else {
DenseTensor cpu_tensor;
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def create_tensor(value, dtype, shape):
return out

def create_scalar(value, dtype):
return create_tensor(value, dtype, shape=[1])
return create_tensor(value, dtype, shape=[])

def astype(self, dtype):
"""
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/fluid/layers/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def create_tensor(block, value, dtype, shape):
return var

def create_scalar(block, value, dtype):
return create_tensor(block, value, dtype, shape=[1])
return create_tensor(block, value, dtype, shape=[])

def create_tensor_with_batchsize(ref_var, value, dtype):
assert isinstance(ref_var, Variable)
Expand Down Expand Up @@ -417,7 +417,7 @@ def __impl__(self, other_var):
out = create_new_tmp_var(current_block(self), dtype=lhs_dtype)

axis = -1
if other_var.shape[0] == -1:
if other_var.ndim > 0 and other_var.shape[0] == -1:
stack = inspect.stack()[1]
file_name = stack[1]
line_num = stack[2]
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_numel_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def setUp(self):
self.inputs = {
'Input': x,
}
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
self.outputs = {'Out': np.array([np.size(x)])}

def test_check_output(self):
Expand Down Expand Up @@ -67,6 +68,7 @@ def test_numel_static(self):
},
fetch_list=[out_1, out_2],
)
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
assert np.array_equal(
res_1, np.array([np.size(input_1)]).astype("int64")
)
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/test_size_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_size_static(self):
},
fetch_list=[out_1, out_2],
)
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
assert np.array_equal(
res_1, np.array([np.size(input_1)]).astype("int64")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
import numpy as np
import unittest

Expand Down Expand Up @@ -67,7 +68,7 @@
]


# Use to test zero-dim in the whole API
# Use to test zero-dim in unary API.
class TestUnaryAPI(unittest.TestCase):
def test_dygraph_unary(self):
paddle.disable_static()
Expand Down Expand Up @@ -176,6 +177,7 @@ def test_static_unary(self):
]


# Use to test zero-dim of reduce API
class TestReduceAPI(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
Expand Down Expand Up @@ -232,31 +234,32 @@ def test_static(self):
{'func': paddle.multiply, 'cls_method': '__mul__'},
{'func': paddle.divide, 'cls_method': '__div__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
paddle.pow,
]

binary_api_list_without_grad = [
{'func': paddle.pow, 'cls_method': '__pow__'},
{'func': paddle.add, 'cls_method': '__add__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
{'func': paddle.multiply, 'cls_method': '__mul__'},
{'func': paddle.divide, 'cls_method': '__div__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
paddle.pow,
{'func': paddle.mod, 'cls_method': '__mod__'},
paddle.floor_mod,
paddle.remainder,
{'func': paddle.pow, 'cls_method': '__pow__'},
]

binary_api_list_without_grad = [
{'func': paddle.equal, 'cls_method': '__eq__'},
{'func': paddle.not_equal, 'cls_method': '__ne__'},
{'func': paddle.greater_equal, 'cls_method': '__ge__'},
{'func': paddle.greater_than, 'cls_method': '__gt__'},
{'func': paddle.less_equal, 'cls_method': '__le__'},
{'func': paddle.less_than, 'cls_method': '__lt__'},
{'func': paddle.remainder, 'cls_method': '__mod__'},
paddle.mod,
paddle.floor_mod,
paddle.logical_and,
paddle.logical_or,
paddle.logical_xor,
]


# Use to test zero-dim of binary API
class TestBinaryAPI(unittest.TestCase):
def test_dygraph_binary(self):
paddle.disable_static()
Expand All @@ -274,8 +277,6 @@ def test_dygraph_binary(self):
else:
out = api(x, y)

self.assertEqual(x.shape, [])
self.assertEqual(y.shape, [])
self.assertEqual(out.shape, [])

if api not in binary_api_list_without_grad:
Expand All @@ -296,8 +297,6 @@ def test_dygraph_binary(self):
else:
out = api(x, y)

self.assertEqual(x.shape, [2, 3, 4])
self.assertEqual(y.shape, [])
self.assertEqual(out.shape, [2, 3, 4])

if api not in binary_api_list_without_grad:
Expand All @@ -317,10 +316,7 @@ def test_dygraph_binary(self):
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
out.backward()

self.assertEqual(x.shape, [])
self.assertEqual(y.shape, [2, 3, 4])
self.assertEqual(out.shape, [2, 3, 4])

if api not in binary_api_list_without_grad:
Expand All @@ -329,19 +325,32 @@ def test_dygraph_binary(self):
self.assertEqual(y.grad.shape, [2, 3, 4])
self.assertEqual(out.grad.shape, [2, 3, 4])

# 4) x is 0D , y is scalar
x = paddle.rand([])
y = 0.5
x.stop_gradient = False
if isinstance(api, dict):
out = getattr(paddle.Tensor, api['cls_method'])(x, y)
self.assertEqual(out.shape, [])

paddle.enable_static()

def test_static_unary(self):
paddle.enable_static()
for api in binary_api_list:
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
# 1) x/y is 0D
x = paddle.rand([])
y = paddle.rand([])
x.stop_gradient = False
y.stop_gradient = False
if isinstance(api, dict):
out = api['func'](x, y)
out_cls = getattr(
paddle.static.Variable, api['cls_method']
)(x, y)
self.assertEqual(out.shape, out_cls.shape)
else:
out = api(x, y)
fluid.backward.append_backward(out)
Expand All @@ -351,20 +360,112 @@ def test_static_unary(self):
block = prog.global_block()

# Test compile shape
self.assertEqual(x.shape, ())
self.assertEqual(y.shape, ())
self.assertEqual(out.shape, ())

exe = fluid.Executor()
result = exe.run(main_prog, fetch_list=[x, y, out])

# Test runtime shape
self.assertEqual(result[0].shape, ())
self.assertEqual(result[1].shape, ())
self.assertEqual(result[2].shape, ())

# 2) x is 0D , y is scalar
x = paddle.rand([])
y = 0.5
x.stop_gradient = False
if isinstance(api, dict):
out = getattr(paddle.static.Variable, api['cls_method'])(
x, y
)
self.assertEqual(out.shape, ())

paddle.disable_static()


# Use to test zero-dim of Sundry API, which is simple and do
# not have backward, or is not need to test backward in OpTest.
class TestSundryAPI(unittest.TestCase):
def setUp(self):
self.x = paddle.rand([])

def test_linear(self):
x = paddle.randn([3, 2])
w = paddle.full(shape=[2, 4], fill_value=0.5)
b = paddle.zeros([])

np.testing.assert_array_equal(
F.linear(x, w, b).numpy(), F.linear(x, w).numpy()
)

def test_is_complex(self):
x = paddle.rand([]) + 1j * paddle.rand([])
self.assertTrue(paddle.is_complex(x))

def test_is_floating_point(self):
self.assertTrue(paddle.is_floating_point(self.x))

def test_is_integer(self):
x = paddle.randint(0, 10, [])
self.assertTrue(paddle.is_integer(x))

def test_is_tensor(self):
self.assertTrue(paddle.is_tensor(self.x))

def test_is_empty(self):
x = paddle.rand([3, 0, 5])
self.assertTrue(paddle.is_empty(x))

def test_isfinite(self):
out = paddle.isfinite(self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))

def test_isinf(self):
x = paddle.to_tensor(np.array(float('-inf')))
out = paddle.isinf(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))

def test_isnan(self):
x = paddle.to_tensor(np.array(float('nan')))
out = paddle.isnan(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))

def test_isclose(self):
out = paddle.isclose(self.x, self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))

def test_clone(self):
out = paddle.clone(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())

def test_assign(self):
out = paddle.assign(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())

def test_item(self):
x = paddle.full([], 0.5)
self.assertEqual(x.item(), 0.5)

def test_tolist(self):
x = paddle.full([], 0.5)
self.assertEqual(x.tolist(), 0.5)

def test_numpy(self):
x = paddle.full([], 0.5)
np.testing.assert_array_equal(x.numpy(), np.array(0.5))

def test_numel(self):
out = paddle.numel(self.x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(1))

def test_rank(self):
out = paddle.rank(self.x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(0))

def test_shape(self):
out = paddle.shape(self.x)
self.assertEqual(out.shape, [0])
np.testing.assert_array_equal(out.numpy(), np.array([]))


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions python/paddle/incubate/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ def batch_norm_orig2prim(

@REGISTER_ORIG2PRIM('size')
def size_orig2prim(op, x):
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
return fill_const(
functools.reduce(operator.mul, x.shape), (1,), paddle.int64
)
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
)
n = n.astype(dtype)
if unbiased:
one_const = paddle.ones([1], x.dtype)
one_const = paddle.ones([], x.dtype)
n = where(n > one_const, n - 1.0, one_const)
out /= n
return out
Expand Down

0 comments on commit b1c1c76

Please sign in to comment.