From 156a027f46bddb78f5993a3b47accd58d34cf54a Mon Sep 17 00:00:00 2001 From: Netpunk <2327994230@qq.com> Date: Wed, 27 Sep 2023 10:52:57 +0800 Subject: [PATCH 1/7] add combinations API --- python/paddle/__init__.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/math.py | 70 ++++++++++++++++ test/legacy_test/test_combinations.py | 112 ++++++++++++++++++++++++++ 4 files changed, 186 insertions(+) create mode 100644 test/legacy_test/test_combinations.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 68ac8d3a8a577..8ff63e89a5bec 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -375,6 +375,7 @@ from .tensor.math import i1e # noqa: F401 from .tensor.math import polygamma # noqa: F401 from .tensor.math import polygamma_ # noqa: F401 +from .tensor.math import combinations # noqa: F401 from .tensor.random import bernoulli # noqa: F401 from .tensor.random import poisson # noqa: F401 @@ -843,4 +844,5 @@ 'i1e', 'polygamma', 'polygamma_', + 'combinations', ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 841925f8b7ff8..588cb4cd74f64 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -323,6 +323,7 @@ from .math import i1e # noqa: F401 from .math import polygamma # noqa: F401 from .math import polygamma_ # noqa: F401 +from .math import combinations # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -673,6 +674,7 @@ 'i1e', 'polygamma', 'polygamma_', + "combinations", ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 41b3cb38d036f..8a95df5ca422f 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -6940,3 +6940,73 @@ def ldexp_(x, y, name=None): y = paddle.cast(y, dtype=out_dtype) two = paddle.to_tensor(2, dtype=out_dtype) return paddle.multiply_(x, paddle.pow(two, y)) + + +def combinations(x, r=2, with_replacement=False, name=None): + """ + Compute combinations of length r of the given tensor. The behavior is similar to python’s itertools.combinations + when with_replacement is set to False, and itertools.combinations_with_replacement when with_replacement is set to True. + + Args: + x (Tensor): 1-D input Tensor, the data type is float16, float32, float64, int32 or int64. + r (int, optional): number of elements to combine, default value is 2. + with_replacement (bool, optional): whether to allow duplication in combination, default value is False. + name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): tensor concatenated by combinations, same dtype with x + + Examples: + + .. code-block:: python + + >>> import paddle + + >>> # example1 + >>> x = paddle.to_tensor([1, 2, 3], dtype='float32') + >>> y = paddle.to_tensor([2, 3, 4], dtype='int32') + >>> res = paddle.ldexp(x, y) + >>> print(res) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, + [4. , 16., 48.]) + + >>> # example2 + >>> x = paddle.to_tensor([1, 2, 3], dtype='float32') + >>> y = paddle.to_tensor([2], dtype='int32') + >>> res = paddle.ldexp(x, y) + >>> print(res) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, + [4. , 8. , 12.]) + + """ + if len(x.shape) != 1: + raise TypeError("Expect a 1-D vector, but got x shape {}".format(x.shape)) + if not isinstance(r, int) or r < 0: + raise ValueError("Expect a non-negative int, but got r={}".format(r)) + + if r == 0: + return paddle.empty([0], dtype=x.dtype) + + if r > 1: + t_l = [x for i in range(r)] + grids = paddle.meshgrid(t_l) + else: + grids = [x] + num_elements = x.numel() + t_range = paddle.arange(num_elements, dtype='int64') + if r > 1: + t_l = [t_range for i in range(r)] + index_grids = paddle.meshgrid(t_l) + else: + index_grids = [t_range] + mask = paddle.full(x.shape * r, True, dtype='bool') + if with_replacement: + for i in range(r - 1): + mask *= index_grids[i] <= index_grids[i + 1] + else: + for i in range(r - 1): + mask *= index_grids[i] < index_grids[i + 1] + for i in range(r): + grids[i] = grids[i].masked_select(mask) + + return paddle.stack(grids, 1) diff --git a/test/legacy_test/test_combinations.py b/test/legacy_test/test_combinations.py new file mode 100644 index 0000000000000..6b892e61afce1 --- /dev/null +++ b/test/legacy_test/test_combinations.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from itertools import combinations, combinations_with_replacement + +import numpy as np + +import paddle +from paddle.base import Program + +paddle.enable_static() + + +def convert_combinations_to_array(x, r, with_replacement): + if r == 0: + return np.array([]).astype(x.dtype) + if with_replacement: + combs = combinations_with_replacement(x, r) + else: + combs = combinations(x, r) + combs = list(combs) + res = [] + for i in range(len(combs)): + res.append(list(combs[i])) + return np.array(res).astype(x.dtype) + + +class TestCombinationsAPIBase(unittest.TestCase): + def setUp(self): + self.init_setting() + self.modify_setting() + self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) + + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def init_setting(self): + self.dtype_np = 'float64' + self.x_shape = [10] + self.r = 5 + self.with_replacement = False + + def modify_setting(self): + pass + + def test_static_graph(self): + paddle.enable_static() + for place in self.place: + with paddle.static.program_guard(Program()): + x = paddle.static.data( + name="x", shape=self.x_shape, dtype=self.dtype_np + ) + out = paddle.combinations(x, self.r, self.with_replacement) + exe = paddle.static.Executor(place=place) + feed_list = {"x": self.x_np} + pd_res = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + )[0] + ref_res = convert_combinations_to_array(self.x_np, self.r, self.with_replacement) + np.testing.assert_allclose(ref_res, pd_res, atol=1e-5) + + def test_dygraph(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + x_pd = paddle.to_tensor(self.x_np) + pd_res = paddle.combinations(x_pd, self.r, self.with_replacement) + ref_res = convert_combinations_to_array(self.x_np, self.r, self.with_replacement) + np.testing.assert_allclose(ref_res, pd_res, atol=1e-5) + + def test_errors(self): + def test_input_not_1D(): + data_np = np.random.random((10, 10)).astype(np.float32) + res = paddle.combinations(data_np, self.r, self.with_replacement) + + self.assertRaises(TypeError, test_input_not_1D) + + def test_r_range(): + res = paddle.combinations(self.x_np, -1, self.with_replacement) + + self.assertRaises(ValueError, test_r_range) + + +class TestIndexFillAPI1(TestCombinationsAPIBase): + def modify_setting(self): + self.dtype_np = 'int32' + self.x_shape = [10] + self.r = 1 + self.with_replacement = True + + +class TestIndexFillAPI2(TestCombinationsAPIBase): + def modify_setting(self): + self.dtype_np = 'int64' + self.x_shape = [10] + self.r = 0 + self.with_replacement = True From 838630ffac08b1faf584bcee4efb8a61b753bb3c Mon Sep 17 00:00:00 2001 From: Netpunk <2327994230@qq.com> Date: Wed, 11 Oct 2023 11:03:01 +0800 Subject: [PATCH 2/7] format code --- python/paddle/tensor/math.py | 6 ++++-- test/legacy_test/test_combinations.py | 8 ++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 9e70c823d828f..c5b7bd8a6b32a 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -6972,9 +6972,11 @@ def combinations(x, r=2, with_replacement=False, name=None): """ if len(x.shape) != 1: - raise TypeError("Expect a 1-D vector, but got x shape {}".format(x.shape)) + raise TypeError( + f"Expect a 1-D vector, but got x shape {x.shape}" + ) if not isinstance(r, int) or r < 0: - raise ValueError("Expect a non-negative int, but got r={}".format(r)) + raise ValueError(f"Expect a non-negative int, but got r={r}") if r == 0: return paddle.empty([0], dtype=x.dtype) diff --git a/test/legacy_test/test_combinations.py b/test/legacy_test/test_combinations.py index 6b892e61afce1..cf4a55e48dd70 100644 --- a/test/legacy_test/test_combinations.py +++ b/test/legacy_test/test_combinations.py @@ -71,7 +71,9 @@ def test_static_graph(self): feed=feed_list, fetch_list=[out], )[0] - ref_res = convert_combinations_to_array(self.x_np, self.r, self.with_replacement) + ref_res = convert_combinations_to_array( + self.x_np, self.r, self.with_replacement + ) np.testing.assert_allclose(ref_res, pd_res, atol=1e-5) def test_dygraph(self): @@ -80,7 +82,9 @@ def test_dygraph(self): paddle.device.set_device(place) x_pd = paddle.to_tensor(self.x_np) pd_res = paddle.combinations(x_pd, self.r, self.with_replacement) - ref_res = convert_combinations_to_array(self.x_np, self.r, self.with_replacement) + ref_res = convert_combinations_to_array( + self.x_np, self.r, self.with_replacement + ) np.testing.assert_allclose(ref_res, pd_res, atol=1e-5) def test_errors(self): From f5303303b5738b037428049a9d2a4a04674808c8 Mon Sep 17 00:00:00 2001 From: Netpunk <2327994230@qq.com> Date: Wed, 11 Oct 2023 13:31:19 +0800 Subject: [PATCH 3/7] format --- python/paddle/tensor/math.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index c5b7bd8a6b32a..3cd755f9a63ef 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -6972,9 +6972,7 @@ def combinations(x, r=2, with_replacement=False, name=None): """ if len(x.shape) != 1: - raise TypeError( - f"Expect a 1-D vector, but got x shape {x.shape}" - ) + raise TypeError(f"Expect a 1-D vector, but got x shape {x.shape}") if not isinstance(r, int) or r < 0: raise ValueError(f"Expect a non-negative int, but got r={r}") From 767cda556956b69c13e9c80704489f138ba4115c Mon Sep 17 00:00:00 2001 From: Netpunk <2327994230@qq.com> Date: Mon, 16 Oct 2023 14:16:43 +0800 Subject: [PATCH 4/7] add test sample --- python/paddle/tensor/math.py | 23 +++++++---------------- test/legacy_test/test_combinations.py | 10 +++++++++- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 3cd755f9a63ef..ac7536d3107bc 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -6953,22 +6953,13 @@ def combinations(x, r=2, with_replacement=False, name=None): .. code-block:: python >>> import paddle - - >>> # example1 - >>> x = paddle.to_tensor([1, 2, 3], dtype='float32') - >>> y = paddle.to_tensor([2, 3, 4], dtype='int32') - >>> res = paddle.ldexp(x, y) + >>> x = paddle.to_tensor([1, 2, 3], dtype='int32') + >>> res = paddle.combinations(x) >>> print(res) - Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, - [4. , 16., 48.]) - - >>> # example2 - >>> x = paddle.to_tensor([1, 2, 3], dtype='float32') - >>> y = paddle.to_tensor([2], dtype='int32') - >>> res = paddle.ldexp(x, y) - >>> print(res) - Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, - [4. , 8. , 12.]) + Tensor(shape=[3, 2], dtype=int32, place=Place(gpu:0), stop_gradient=True, + [[1, 2], + [1, 3], + [2, 3]]) """ if len(x.shape) != 1: @@ -6976,7 +6967,7 @@ def combinations(x, r=2, with_replacement=False, name=None): if not isinstance(r, int) or r < 0: raise ValueError(f"Expect a non-negative int, but got r={r}") - if r == 0: + if r == 0 or r > x.shape[0]: return paddle.empty([0], dtype=x.dtype) if r > 1: diff --git a/test/legacy_test/test_combinations.py b/test/legacy_test/test_combinations.py index cf4a55e48dd70..36b91584907b9 100644 --- a/test/legacy_test/test_combinations.py +++ b/test/legacy_test/test_combinations.py @@ -23,7 +23,7 @@ paddle.enable_static() -def convert_combinations_to_array(x, r, with_replacement): +def convert_combinations_to_array(x, r=2, with_replacement=False): if r == 0: return np.array([]).astype(x.dtype) if with_replacement: @@ -114,3 +114,11 @@ def modify_setting(self): self.x_shape = [10] self.r = 0 self.with_replacement = True + + +class TestIndexFillAPI3(TestCombinationsAPIBase): + def modify_setting(self): + self.dtype_np = 'float32' + self.x_shape = [0] + self.r = 10 + self.with_replacement = False From ad486ff14494d6c0239865a0129737217c023282 Mon Sep 17 00:00:00 2001 From: Netpunk <2327994230@qq.com> Date: Tue, 17 Oct 2023 11:07:25 +0800 Subject: [PATCH 5/7] add empty test --- python/paddle/tensor/math.py | 7 ++++-- test/legacy_test/test_combinations.py | 36 +++++++++++++++++++-------- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index ac7536d3107bc..12b91b6ea7ad3 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -6967,8 +6967,11 @@ def combinations(x, r=2, with_replacement=False, name=None): if not isinstance(r, int) or r < 0: raise ValueError(f"Expect a non-negative int, but got r={r}") - if r == 0 or r > x.shape[0]: - return paddle.empty([0], dtype=x.dtype) + if r == 0: + return paddle.empty(shape=[0], dtype=x.dtype) + + if r > x.shape[0]: + return paddle.empty(shape=[0, r], dtype=x.dtype) if r > 1: t_l = [x for i in range(r)] diff --git a/test/legacy_test/test_combinations.py b/test/legacy_test/test_combinations.py index 36b91584907b9..e1dbcac521893 100644 --- a/test/legacy_test/test_combinations.py +++ b/test/legacy_test/test_combinations.py @@ -74,7 +74,7 @@ def test_static_graph(self): ref_res = convert_combinations_to_array( self.x_np, self.r, self.with_replacement ) - np.testing.assert_allclose(ref_res, pd_res, atol=1e-5) + np.testing.assert_allclose(ref_res, pd_res) def test_dygraph(self): paddle.disable_static() @@ -85,7 +85,7 @@ def test_dygraph(self): ref_res = convert_combinations_to_array( self.x_np, self.r, self.with_replacement ) - np.testing.assert_allclose(ref_res, pd_res, atol=1e-5) + np.testing.assert_allclose(ref_res, pd_res) def test_errors(self): def test_input_not_1D(): @@ -100,7 +100,7 @@ def test_r_range(): self.assertRaises(ValueError, test_r_range) -class TestIndexFillAPI1(TestCombinationsAPIBase): +class TestCombinationsAPI1(TestCombinationsAPIBase): def modify_setting(self): self.dtype_np = 'int32' self.x_shape = [10] @@ -108,7 +108,7 @@ def modify_setting(self): self.with_replacement = True -class TestIndexFillAPI2(TestCombinationsAPIBase): +class TestCombinationsAPI2(TestCombinationsAPIBase): def modify_setting(self): self.dtype_np = 'int64' self.x_shape = [10] @@ -116,9 +116,25 @@ def modify_setting(self): self.with_replacement = True -class TestIndexFillAPI3(TestCombinationsAPIBase): - def modify_setting(self): - self.dtype_np = 'float32' - self.x_shape = [0] - self.r = 10 - self.with_replacement = False +class TestCombinationsEmpty(unittest.TestCase): + def setUp(self): + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def test_dygraph(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + a = paddle.to_tensor([1, 2, 3]) + c = paddle.combinations(a, r=4) + expected = paddle.empty([0, 4]) + np.testing.assert_allclose(c, expected) + + # test empty imput + a = paddle.empty([0]) + c1 = paddle.combinations(a) + c2 = paddle.combinations(a, with_replacement=True) + expected = paddle.empty([0, 2]) + np.testing.assert_allclose(c1, expected) + np.testing.assert_allclose(c2, expected) From 798736ce1c9e09938dac9a3d2754bf58b7eddc42 Mon Sep 17 00:00:00 2001 From: Netpunk <2327994230@qq.com> Date: Sat, 18 Nov 2023 12:27:59 +0800 Subject: [PATCH 6/7] modify test sample --- python/paddle/tensor/math.py | 4 +++- test/legacy_test/test_combinations.py | 32 ++++++++++++++++++--------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index dda758eeeba2a..4a63c938ffbf3 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7033,7 +7033,9 @@ def combinations(x, r=2, with_replacement=False, name=None): if r == 0: return paddle.empty(shape=[0], dtype=x.dtype) - if r > x.shape[0]: + if (r > x.shape[0] and not with_replacement) or ( + x.shape[0] == 0 and with_replacement + ): return paddle.empty(shape=[0, r], dtype=x.dtype) if r > 1: diff --git a/test/legacy_test/test_combinations.py b/test/legacy_test/test_combinations.py index e1dbcac521893..76bad819ff6c8 100644 --- a/test/legacy_test/test_combinations.py +++ b/test/legacy_test/test_combinations.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import random import unittest from itertools import combinations, combinations_with_replacement @@ -34,7 +34,10 @@ def convert_combinations_to_array(x, r=2, with_replacement=False): res = [] for i in range(len(combs)): res.append(list(combs[i])) - return np.array(res).astype(x.dtype) + if len(res) != 0: + return np.array(res).astype(x.dtype) + else: + return np.empty((0, r)) class TestCombinationsAPIBase(unittest.TestCase): @@ -126,15 +129,22 @@ def test_dygraph(self): paddle.disable_static() for place in self.place: paddle.device.set_device(place) - a = paddle.to_tensor([1, 2, 3]) + a = paddle.rand([3], dtype='float32') c = paddle.combinations(a, r=4) - expected = paddle.empty([0, 4]) + expected = convert_combinations_to_array(a.numpy(), r=4) np.testing.assert_allclose(c, expected) - # test empty imput - a = paddle.empty([0]) - c1 = paddle.combinations(a) - c2 = paddle.combinations(a, with_replacement=True) - expected = paddle.empty([0, 2]) - np.testing.assert_allclose(c1, expected) - np.testing.assert_allclose(c2, expected) + # test empty input + a = paddle.empty([random.randint(0, 8)]) + c1 = paddle.combinations(a, r=2) + c2 = paddle.combinations(a, r=2, with_replacement=True) + expected1 = convert_combinations_to_array(a.numpy(), r=2) + expected2 = convert_combinations_to_array( + a.numpy(), r=2, with_replacement=True + ) + np.testing.assert_allclose(c1, expected1) + np.testing.assert_allclose(c2, expected2) + + +if __name__ == '__main__': + unittest.main() From 6976300f11f75011c337579e0fdddf8e71c32aaa Mon Sep 17 00:00:00 2001 From: Netpunk <2327994230@qq.com> Date: Sun, 26 Nov 2023 11:25:12 +0800 Subject: [PATCH 7/7] format doc --- python/paddle/tensor/math.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 18106302bbad3..2f097b3c54460 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7075,17 +7075,22 @@ def hypot_(x, y, name=None): def combinations(x, r=2, with_replacement=False, name=None): """ - Compute combinations of length r of the given tensor. The behavior is similar to python’s itertools.combinations + + Compute combinations of length r of the given tensor. The behavior is similar to python's itertools.combinations when with_replacement is set to False, and itertools.combinations_with_replacement when with_replacement is set to True. + Args: x (Tensor): 1-D input Tensor, the data type is float16, float32, float64, int32 or int64. r (int, optional): number of elements to combine, default value is 2. with_replacement (bool, optional): whether to allow duplication in combination, default value is False. name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + Returns: - out (Tensor): tensor concatenated by combinations, same dtype with x + out (Tensor). Tensor concatenated by combinations, same dtype with x. + Examples: .. code-block:: python + >>> import paddle >>> x = paddle.to_tensor([1, 2, 3], dtype='int32') >>> res = paddle.combinations(x) @@ -7094,6 +7099,7 @@ def combinations(x, r=2, with_replacement=False, name=None): [[1, 2], [1, 3], [2, 3]]) + """ if len(x.shape) != 1: raise TypeError(f"Expect a 1-D vector, but got x shape {x.shape}")