Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] support input 0D Tensor for sundary api #47734

Merged
merged 2 commits into from
Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3249,7 +3249,12 @@ void ShardIndexInferMeta(const MetaTensor& in,

void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
out->set_dtype(DataType::INT64);
out->set_dims({1});
if (input.dims().size() == 0) {
out->set_dims(phi::make_ddim({}));
} else {
// TODO(zhouwei): will change shape [1] to [] to support zero-dim
out->set_dims(phi::make_ddim({1}));
}
}

void SliceRawInferMeta(const MetaTensor& input,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/impl/size_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ void SizeKernel(const Context& ctx,
DenseTensor* out) {
auto place = ctx.GetPlace();
auto out_data = ctx.template Alloc<int64_t>(out);
auto cpu_place = phi::CPUPlace();
if (place == cpu_place) {

if (place == phi::CPUPlace()) {
out_data[0] = input.numel();
} else {
DenseTensor cpu_tensor;
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def create_tensor(value, dtype, shape):
return out

def create_scalar(value, dtype):
return create_tensor(value, dtype, shape=[1])
return create_tensor(value, dtype, shape=[])

def astype(self, dtype):
"""
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/fluid/layers/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def create_tensor(block, value, dtype, shape):
return var

def create_scalar(block, value, dtype):
return create_tensor(block, value, dtype, shape=[1])
return create_tensor(block, value, dtype, shape=[])

def create_tensor_with_batchsize(ref_var, value, dtype):
assert isinstance(ref_var, Variable)
Expand Down Expand Up @@ -417,7 +417,7 @@ def __impl__(self, other_var):
out = create_new_tmp_var(current_block(self), dtype=lhs_dtype)

axis = -1
if other_var.shape[0] == -1:
if other_var.ndim > 0 and other_var.shape[0] == -1:
stack = inspect.stack()[1]
file_name = stack[1]
line_num = stack[2]
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_numel_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def setUp(self):
self.inputs = {
'Input': x,
}
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
self.outputs = {'Out': np.array([np.size(x)])}

def test_check_output(self):
Expand Down Expand Up @@ -67,6 +68,7 @@ def test_numel_static(self):
},
fetch_list=[out_1, out_2],
)
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
assert np.array_equal(
res_1, np.array([np.size(input_1)]).astype("int64")
)
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/test_size_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_size_static(self):
},
fetch_list=[out_1, out_2],
)
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
assert np.array_equal(
res_1, np.array([np.size(input_1)]).astype("int64")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import paddle
import paddle.fluid as fluid
import paddle.nn.functional as F
import numpy as np
import unittest

Expand Down Expand Up @@ -67,7 +68,7 @@
]


# Use to test zero-dim in the whole API
# Use to test zero-dim in unary API.
class TestUnaryAPI(unittest.TestCase):
def test_dygraph_unary(self):
paddle.disable_static()
Expand Down Expand Up @@ -176,6 +177,7 @@ def test_static_unary(self):
]


# Use to test zero-dim of reduce API
class TestReduceAPI(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
Expand Down Expand Up @@ -231,32 +233,32 @@ def test_static(self):
{'func': paddle.subtract, 'cls_method': '__sub__'},
{'func': paddle.multiply, 'cls_method': '__mul__'},
{'func': paddle.divide, 'cls_method': '__div__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
paddle.pow,
{'func': paddle.pow, 'cls_method': '__pow__'},
]

binary_api_list_without_grad = [
{'func': paddle.add, 'cls_method': '__add__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
{'func': paddle.multiply, 'cls_method': '__mul__'},
{'func': paddle.divide, 'cls_method': '__div__'},
{'func': paddle.subtract, 'cls_method': '__sub__'},
paddle.pow,
{'func': paddle.mod, 'cls_method': '__mod__'},
paddle.floor_mod,
paddle.remainder,
{'func': paddle.equal, 'cls_method': '__eq__'},
{'func': paddle.not_equal, 'cls_method': '__ne__'},
{'func': paddle.greater_equal, 'cls_method': '__ge__'},
{'func': paddle.greater_than, 'cls_method': '__gt__'},
{'func': paddle.less_equal, 'cls_method': '__le__'},
{'func': paddle.less_than, 'cls_method': '__lt__'},
{'func': paddle.remainder, 'cls_method': '__mod__'},
paddle.mod,
paddle.floor_mod,
paddle.logical_and,
paddle.logical_or,
paddle.logical_xor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we add paddle.bitwise_and/or/xor ?

Copy link
Contributor Author

@zhwesky2010 zhwesky2010 Nov 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加

]

binary_int_api_list_without_grad = [
paddle.bitwise_and,
paddle.bitwise_or,
paddle.bitwise_xor,
]


# Use to test zero-dim of binary API
class TestBinaryAPI(unittest.TestCase):
def test_dygraph_binary(self):
paddle.disable_static()
Expand All @@ -274,10 +276,7 @@ def test_dygraph_binary(self):
else:
out = api(x, y)

self.assertEqual(x.shape, [])
self.assertEqual(y.shape, [])
self.assertEqual(out.shape, [])

if api not in binary_api_list_without_grad:
out.backward()
self.assertEqual(x.grad.shape, [])
Expand All @@ -296,10 +295,7 @@ def test_dygraph_binary(self):
else:
out = api(x, y)

self.assertEqual(x.shape, [2, 3, 4])
self.assertEqual(y.shape, [])
self.assertEqual(out.shape, [2, 3, 4])

if api not in binary_api_list_without_grad:
out.backward()
self.assertEqual(x.grad.shape, [2, 3, 4])
Expand All @@ -317,54 +313,190 @@ def test_dygraph_binary(self):
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
out.backward()

self.assertEqual(x.shape, [])
self.assertEqual(y.shape, [2, 3, 4])
self.assertEqual(out.shape, [2, 3, 4])

if api not in binary_api_list_without_grad:
out.backward()
self.assertEqual(x.grad.shape, [])
self.assertEqual(y.grad.shape, [2, 3, 4])
self.assertEqual(out.grad.shape, [2, 3, 4])

# 4) x is 0D , y is scalar
x = paddle.rand([])
y = 0.5
x.stop_gradient = False
if isinstance(api, dict):
out = getattr(paddle.Tensor, api['cls_method'])(x, y)
self.assertEqual(out.shape, [])

for api in binary_int_api_list_without_grad:
# 1) x/y is 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, [])

# 2) x is not 0D , y is 0D
x = paddle.randint(-10, 10, [3, 5])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, [3, 5])

# 3) x is 0D , y is not 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [3, 5])
out = api(x, y)
self.assertEqual(out.shape, [3, 5])

paddle.enable_static()

def test_static_unary(self):
paddle.enable_static()
for api in binary_api_list:
for api in binary_api_list + binary_api_list_without_grad:
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
# 1) x/y is 0D
x = paddle.rand([])
y = paddle.rand([])
x.stop_gradient = False
y.stop_gradient = False
if isinstance(api, dict):
out = api['func'](x, y)
out_cls = getattr(
paddle.static.Variable, api['cls_method']
)(x, y)
self.assertEqual(out.shape, out_cls.shape)
else:
out = api(x, y)
fluid.backward.append_backward(out)

# append_backward always set grad shape to [1]
prog = paddle.static.default_main_program()
block = prog.global_block()

# Test compile shape
self.assertEqual(x.shape, ())
self.assertEqual(y.shape, ())
self.assertEqual(out.shape, ())

exe = fluid.Executor()
result = exe.run(main_prog, fetch_list=[x, y, out])

out_np = exe.run(main_prog, fetch_list=[out])[0]
# Test runtime shape
self.assertEqual(result[0].shape, ())
self.assertEqual(result[1].shape, ())
self.assertEqual(result[2].shape, ())
self.assertEqual(out_np.shape, ())

# 2) x is 0D , y is scalar
x = paddle.rand([])
y = 0.5
x.stop_gradient = False
if isinstance(api, dict):
out = getattr(paddle.static.Variable, api['cls_method'])(
x, y
)
self.assertEqual(out.shape, ())

for api in binary_int_api_list_without_grad:
main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
# 1) x/y is 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, ())

# 2) x is not 0D , y is 0D
x = paddle.randint(-10, 10, [3, 5])
y = paddle.randint(-10, 10, [])
out = api(x, y)
self.assertEqual(out.shape, (3, 5))

# 3) x is 0D , y is not 0D
x = paddle.randint(-10, 10, [])
y = paddle.randint(-10, 10, [3, 5])
out = api(x, y)
self.assertEqual(out.shape, (3, 5))

paddle.disable_static()


# Use to test zero-dim of Sundry API, which is simple and do
# not have backward, or is not need to test backward in OpTest.
class TestSundryAPI(unittest.TestCase):
def setUp(self):
self.x = paddle.rand([])

def test_linear(self):
x = paddle.randn([3, 2])
w = paddle.full(shape=[2, 4], fill_value=0.5)
b = paddle.zeros([])

np.testing.assert_array_equal(
F.linear(x, w, b).numpy(), F.linear(x, w).numpy()
)

def test_is_complex(self):
x = paddle.rand([]) + 1j * paddle.rand([])
self.assertTrue(paddle.is_complex(x))

def test_is_floating_point(self):
self.assertTrue(paddle.is_floating_point(self.x))

def test_is_integer(self):
x = paddle.randint(0, 10, [])
self.assertTrue(paddle.is_integer(x))

def test_is_tensor(self):
self.assertTrue(paddle.is_tensor(self.x))

def test_is_empty(self):
x = paddle.rand([3, 0, 5])
self.assertTrue(paddle.is_empty(x))

def test_isfinite(self):
out = paddle.isfinite(self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))

def test_isinf(self):
x = paddle.to_tensor(np.array(float('-inf')))
out = paddle.isinf(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))

def test_isnan(self):
x = paddle.to_tensor(np.array(float('nan')))
out = paddle.isnan(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))

def test_isclose(self):
out = paddle.isclose(self.x, self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))

def test_clone(self):
out = paddle.clone(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())

def test_assign(self):
out = paddle.assign(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())

def test_item(self):
x = paddle.full([], 0.5)
self.assertEqual(x.item(), 0.5)

def test_tolist(self):
x = paddle.full([], 0.5)
self.assertEqual(x.tolist(), 0.5)

def test_numpy(self):
x = paddle.full([], 0.5)
np.testing.assert_array_equal(x.numpy(), np.array(0.5))

def test_numel(self):
out = paddle.numel(self.x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(1))

def test_rank(self):
out = paddle.rank(self.x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(0))

def test_shape(self):
out = paddle.shape(self.x)
self.assertEqual(out.shape, [0])
np.testing.assert_array_equal(out.numpy(), np.array([]))


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions python/paddle/incubate/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,7 @@ def batch_norm_orig2prim(

@REGISTER_ORIG2PRIM('size')
def size_orig2prim(op, x):
# TODO(zhouwei): will change shape [1] to [] to support zero-dim
return fill_const(
functools.reduce(operator.mul, x.shape), (1,), paddle.int64
)
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
)
n = n.astype(dtype)
if unbiased:
one_const = paddle.ones([1], x.dtype)
one_const = paddle.ones([], x.dtype)
n = where(n > one_const, n - 1.0, one_const)
out /= n
return out
Expand Down