-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.14】Add combinations API to Paddle #57792
Changes from 14 commits
156a027
a02bc01
383f9bc
838630f
f530330
767cda5
ad486ff
7b45322
308a4f1
f85ed3f
798736c
d96cb13
ba2f97d
6976300
34c8f7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7071,3 +7071,69 @@ def hypot_(x, y, name=None): | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
out = x.pow_(2).add_(y.pow(2)).sqrt_() | ||||||||||||||||||||||||||||||||||||
return out | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def combinations(x, r=2, with_replacement=False, name=None): | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Compute combinations of length r of the given tensor. The behavior is similar to python's itertools.combinations | ||||||||||||||||||||||||||||||||||||
when with_replacement is set to False, and itertools.combinations_with_replacement when with_replacement is set to True. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
x (Tensor): 1-D input Tensor, the data type is float16, float32, float64, int32 or int64. | ||||||||||||||||||||||||||||||||||||
r (int, optional): number of elements to combine, default value is 2. | ||||||||||||||||||||||||||||||||||||
with_replacement (bool, optional): whether to allow duplication in combination, default value is False. | ||||||||||||||||||||||||||||||||||||
name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||
out (Tensor). Tensor concatenated by combinations, same dtype with x. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Examples: | ||||||||||||||||||||||||||||||||||||
.. code-block:: python | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
>>> import paddle | ||||||||||||||||||||||||||||||||||||
>>> x = paddle.to_tensor([1, 2, 3], dtype='int32') | ||||||||||||||||||||||||||||||||||||
>>> res = paddle.combinations(x) | ||||||||||||||||||||||||||||||||||||
>>> print(res) | ||||||||||||||||||||||||||||||||||||
Tensor(shape=[3, 2], dtype=int32, place=Place(gpu:0), stop_gradient=True, | ||||||||||||||||||||||||||||||||||||
[[1, 2], | ||||||||||||||||||||||||||||||||||||
[1, 3], | ||||||||||||||||||||||||||||||||||||
[2, 3]]) | ||||||||||||||||||||||||||||||||||||
Comment on lines
+7094
to
+7101
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
if len(x.shape) != 1: | ||||||||||||||||||||||||||||||||||||
raise TypeError(f"Expect a 1-D vector, but got x shape {x.shape}") | ||||||||||||||||||||||||||||||||||||
if not isinstance(r, int) or r < 0: | ||||||||||||||||||||||||||||||||||||
raise ValueError(f"Expect a non-negative int, but got r={r}") | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if r == 0: | ||||||||||||||||||||||||||||||||||||
return paddle.empty(shape=[0], dtype=x.dtype) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if (r > x.shape[0] and not with_replacement) or ( | ||||||||||||||||||||||||||||||||||||
x.shape[0] == 0 and with_replacement | ||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||
return paddle.empty(shape=[0, r], dtype=x.dtype) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if r > 1: | ||||||||||||||||||||||||||||||||||||
t_l = [x for i in range(r)] | ||||||||||||||||||||||||||||||||||||
grids = paddle.meshgrid(t_l) | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
grids = [x] | ||||||||||||||||||||||||||||||||||||
num_elements = x.numel() | ||||||||||||||||||||||||||||||||||||
t_range = paddle.arange(num_elements, dtype='int64') | ||||||||||||||||||||||||||||||||||||
if r > 1: | ||||||||||||||||||||||||||||||||||||
t_l = [t_range for i in range(r)] | ||||||||||||||||||||||||||||||||||||
index_grids = paddle.meshgrid(t_l) | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
index_grids = [t_range] | ||||||||||||||||||||||||||||||||||||
mask = paddle.full(x.shape * r, True, dtype='bool') | ||||||||||||||||||||||||||||||||||||
if with_replacement: | ||||||||||||||||||||||||||||||||||||
for i in range(r - 1): | ||||||||||||||||||||||||||||||||||||
mask *= index_grids[i] <= index_grids[i + 1] | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
for i in range(r - 1): | ||||||||||||||||||||||||||||||||||||
mask *= index_grids[i] < index_grids[i + 1] | ||||||||||||||||||||||||||||||||||||
for i in range(r): | ||||||||||||||||||||||||||||||||||||
grids[i] = grids[i].masked_select(mask) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return paddle.stack(grids, 1) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import random | ||
import unittest | ||
from itertools import combinations, combinations_with_replacement | ||
|
||
import numpy as np | ||
|
||
import paddle | ||
from paddle.base import Program | ||
|
||
paddle.enable_static() | ||
|
||
|
||
def convert_combinations_to_array(x, r=2, with_replacement=False): | ||
if r == 0: | ||
return np.array([]).astype(x.dtype) | ||
if with_replacement: | ||
combs = combinations_with_replacement(x, r) | ||
else: | ||
combs = combinations(x, r) | ||
combs = list(combs) | ||
res = [] | ||
for i in range(len(combs)): | ||
res.append(list(combs[i])) | ||
if len(res) != 0: | ||
return np.array(res).astype(x.dtype) | ||
else: | ||
return np.empty((0, r)) | ||
|
||
|
||
class TestCombinationsAPIBase(unittest.TestCase): | ||
def setUp(self): | ||
self.init_setting() | ||
self.modify_setting() | ||
self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) | ||
|
||
self.place = ['cpu'] | ||
if paddle.is_compiled_with_cuda(): | ||
self.place.append('gpu') | ||
|
||
def init_setting(self): | ||
self.dtype_np = 'float64' | ||
self.x_shape = [10] | ||
self.r = 5 | ||
self.with_replacement = False | ||
|
||
def modify_setting(self): | ||
pass | ||
|
||
def test_static_graph(self): | ||
paddle.enable_static() | ||
for place in self.place: | ||
with paddle.static.program_guard(Program()): | ||
x = paddle.static.data( | ||
name="x", shape=self.x_shape, dtype=self.dtype_np | ||
) | ||
out = paddle.combinations(x, self.r, self.with_replacement) | ||
exe = paddle.static.Executor(place=place) | ||
feed_list = {"x": self.x_np} | ||
pd_res = exe.run( | ||
paddle.static.default_main_program(), | ||
feed=feed_list, | ||
fetch_list=[out], | ||
)[0] | ||
ref_res = convert_combinations_to_array( | ||
self.x_np, self.r, self.with_replacement | ||
) | ||
np.testing.assert_allclose(ref_res, pd_res) | ||
|
||
def test_dygraph(self): | ||
paddle.disable_static() | ||
for place in self.place: | ||
paddle.device.set_device(place) | ||
x_pd = paddle.to_tensor(self.x_np) | ||
pd_res = paddle.combinations(x_pd, self.r, self.with_replacement) | ||
ref_res = convert_combinations_to_array( | ||
self.x_np, self.r, self.with_replacement | ||
) | ||
np.testing.assert_allclose(ref_res, pd_res) | ||
|
||
def test_errors(self): | ||
def test_input_not_1D(): | ||
data_np = np.random.random((10, 10)).astype(np.float32) | ||
res = paddle.combinations(data_np, self.r, self.with_replacement) | ||
|
||
self.assertRaises(TypeError, test_input_not_1D) | ||
|
||
def test_r_range(): | ||
res = paddle.combinations(self.x_np, -1, self.with_replacement) | ||
|
||
self.assertRaises(ValueError, test_r_range) | ||
|
||
|
||
class TestCombinationsAPI1(TestCombinationsAPIBase): | ||
def modify_setting(self): | ||
self.dtype_np = 'int32' | ||
self.x_shape = [10] | ||
self.r = 1 | ||
self.with_replacement = True | ||
|
||
|
||
class TestCombinationsAPI2(TestCombinationsAPIBase): | ||
def modify_setting(self): | ||
self.dtype_np = 'int64' | ||
self.x_shape = [10] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 缺少了输入为empty情况下的单测 |
||
self.r = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 缺少r>x_shape情况的单测 |
||
self.with_replacement = True | ||
|
||
|
||
class TestCombinationsEmpty(unittest.TestCase): | ||
def setUp(self): | ||
self.place = ['cpu'] | ||
if paddle.is_compiled_with_cuda(): | ||
self.place.append('gpu') | ||
|
||
def test_dygraph(self): | ||
paddle.disable_static() | ||
for place in self.place: | ||
paddle.device.set_device(place) | ||
a = paddle.rand([3], dtype='float32') | ||
c = paddle.combinations(a, r=4) | ||
expected = convert_combinations_to_array(a.numpy(), r=4) | ||
np.testing.assert_allclose(c, expected) | ||
|
||
# test empty input | ||
a = paddle.empty([random.randint(0, 8)]) | ||
c1 = paddle.combinations(a, r=2) | ||
c2 = paddle.combinations(a, r=2, with_replacement=True) | ||
expected1 = convert_combinations_to_array(a.numpy(), r=2) | ||
expected2 = convert_combinations_to_array( | ||
a.numpy(), r=2, with_replacement=True | ||
) | ||
np.testing.assert_allclose(c1, expected1) | ||
np.testing.assert_allclose(c2, expected2) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
严格按照模板(包括空行)