
Commit 50e889f

DongBaiYue authored and Luckycheng222 committed
[API compatibility] softmax, nonzero, randn (PaddlePaddle#74623)
* [API compatibility] softmax, nonzero, randn
* delete chinese
* delete *shape
* fix comment example
* fix
* fix
1 parent 8b2eddb commit 50e889f

File tree: 6 files changed, +177 -19 lines
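The patch leans on small helpers in ``paddle.utils.decorator_utils`` (``ParamAliasDecorator``, ``param_one_alias``, ``SizeArgsDecorator``). Their implementations are not part of this diff; the following is only a hand-written sketch of the alias mechanism, not Paddle's actual code:

    import functools

    def param_alias(alias_map):
        # Sketch only: remap alias keyword names (e.g. ``input``) to their
        # canonical parameter names (e.g. ``x``) before calling the API.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                for canonical, aliases in alias_map.items():
                    for alias in aliases:
                        if alias in kwargs and canonical not in kwargs:
                            kwargs[canonical] = kwargs.pop(alias)
                return fn(*args, **kwargs)
            return wrapper
        return decorator

    @param_alias({"x": ["input"], "axis": ["dim"]})
    def softmax(x, axis=-1):
        ...  # the real implementation lives in paddle.nn.functional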

python/paddle/nn/functional/activation.py

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,7 @@
 import paddle
 from paddle import _C_ops, in_dynamic_mode
 from paddle.framework import core, in_dynamic_or_pir_mode
+from paddle.utils.decorator_utils import ParamAliasDecorator
 from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

 from ...base.data_feeder import check_dtype, check_variable_and_dtype
@@ -1127,6 +1128,7 @@ def silu(x: Tensor, name: str | None = None) -> Tensor:
     return out


+@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
 def softmax(
     x: Tensor,
     axis: int = -1,
@@ -1208,12 +1210,18 @@ def softmax(
             [0.26762315, 0.26762315, 0.26762315, 0.26762315],
             [0.72747516, 0.72747516, 0.72747516, 0.72747516]]]

+    .. note::
+        Alias Support: The parameter name ``input`` can be used as an alias for ``x``, and ``dim`` can be used as an alias for ``axis``.
+        For example, ``softmax(input=tensor_x, dim=1, ...)`` is equivalent to ``softmax(x=tensor_x, axis=1, ...)``.
+
     Parameters:
         x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64.
+            alias: ``input``.
         axis (int, optional): The axis along which to perform softmax
             calculations. It should be in range [-D, D), where D is the
             rank of ``x`` . If ``axis`` < 0, it works the same way as
             :math:`axis + D` . Default is -1.
+            alias: ``dim``.
         dtype (str, optional): The data type of the output tensor, can be bfloat16, float16, float32, float64.
         name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
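If the decorator behaves as the docstring note describes, the alias spellings are interchangeable in dygraph mode; a minimal sketch (tensor values are arbitrary):

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([2, 3, 4])

    out_a = F.softmax(x=x, axis=1)     # canonical parameter names
    out_b = F.softmax(input=x, dim=1)  # PyTorch-style aliases

    assert paddle.allclose(out_a, out_b)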

python/paddle/tensor/random.py

Lines changed: 5 additions & 2 deletions
@@ -29,7 +29,7 @@
     in_pir_mode,
     use_pir_api,
 )
-from paddle.utils.decorator_utils import param_one_alias
+from paddle.utils.decorator_utils import SizeArgsDecorator, param_one_alias

 from ..base.data_feeder import (
     check_dtype,
@@ -903,6 +903,7 @@ def standard_normal(
     return gaussian(shape=shape, mean=0.0, std=1.0, dtype=dtype, name=name)


+@SizeArgsDecorator()
 def randn(
     shape: ShapeLike, dtype: DTypeLike | None = None, name: str | None = None
 ) -> Tensor:
@@ -912,9 +913,11 @@
     and ``dtype``.

     Args:
-        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+        shape (tuple|list|Tensor|*shape): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
             If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
             If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
+            If ``shape`` is *shape, directly pass integers as variable-length arguments (e.g., `randn(2, 3)`).
+            alias: ``size``.
         dtype (str|np.dtype|paddle.dtype|None, optional): The data type of the output Tensor.
             Supported data types: float16, bfloat16, float32, float64, complex64, complex128.
             Default is None, use global default dtype (see ``get_default_dtype``
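A quick sketch of the three calling conventions ``randn`` accepts after this change (shape values are arbitrary):

    import paddle

    a = paddle.randn([2, 3])        # canonical list/tuple shape
    b = paddle.randn(size=[2, 3])   # PyTorch-style ``size`` alias
    c = paddle.randn(2, 3)          # variadic *shape integers

    assert a.shape == b.shape == c.shape == [2, 3]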

python/paddle/tensor/search.py

Lines changed: 16 additions & 14 deletions
@@ -22,7 +22,7 @@
 import paddle
 from paddle import _C_ops
 from paddle.common_ops_import import VarDesc, Variable
-from paddle.utils.decorator_utils import ParamAliasDecorator
+from paddle.utils.decorator_utils import ParamAliasDecorator, param_one_alias
 from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

 from ..base.data_feeder import check_dtype, check_variable_and_dtype
@@ -467,7 +467,8 @@ def nonzero(x: Tensor, as_tuple: Literal[True] = ...) -> tuple[Tensor, ...]: ...
 def nonzero(x: Tensor, as_tuple: bool = ...) -> Tensor | tuple[Tensor, ...]: ...


-def nonzero(x: Tensor, as_tuple=False):
+@param_one_alias(['x', 'input'])
+def nonzero(x: Tensor, as_tuple=False, *, out: Tensor | None = None):
     """
     Return a tensor containing the indices of all non-zero elements of the `input`
     tensor. If as_tuple is True, return a tuple of 1-D tensors, one for each dimension
@@ -477,9 +478,15 @@ def nonzero(x: Tensor, as_tuple=False):
     number of all non-zero elements in the `input` tensor. If as_tuple is True, we can get
     a 1-D tensor tuple of length `n`, and the shape of each 1-D tensor is [z, 1].

+    .. note::
+        Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
+        For example, ``nonzero(input=tensor_x)`` is equivalent to ``nonzero(x=tensor_x)``.
+
     Args:
         x (Tensor): The input tensor variable.
+            alias: ``input``.
         as_tuple (bool, optional): Return type, Tensor or tuple of Tensor.
+        out (Tensor|None, optional): The output tensor. Default: None.

     Returns:
         Tensor or tuple of Tensor, The data type is int64.
@@ -504,14 +511,10 @@ def nonzero(x: Tensor, as_tuple=False):
         >>> out_z1_tuple = paddle.nonzero(x1, as_tuple=True)
         >>> for out in out_z1_tuple:
         ...     print(out)
-        Tensor(shape=[3, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
-        [[0],
-         [1],
-         [2]])
-        Tensor(shape=[3, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
-        [[0],
-         [1],
-         [2]])
+        Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True,
+        [0, 1, 2])
+        Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True,
+        [0, 1, 2])

         >>> out_z2 = paddle.nonzero(x2)
         >>> print(out_z2)
@@ -522,13 +525,12 @@ def nonzero(x: Tensor, as_tuple=False):
         >>> out_z2_tuple = paddle.nonzero(x2, as_tuple=True)
         >>> for out in out_z2_tuple:
         ...     print(out)
-        Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
-        [[1],
-         [3]])
+        Tensor(shape=[2], dtype=int64, place=Place(cpu), stop_gradient=True,
+        [1, 3])

     """
     if in_dynamic_or_pir_mode():
-        outs = _C_ops.nonzero(x)
+        outs = _C_ops.nonzero(x, out=out)
     else:
         check_variable_and_dtype(
             x,
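A minimal dygraph sketch of the ``input`` alias and the new ``out=`` keyword (the input tensor and preallocated shape are illustrative; with two non-zeros in a 1-D input, the result has shape [2, 1]):

    import paddle

    x = paddle.to_tensor([0.0, 2.0, 0.0, 5.0])

    idx_a = paddle.nonzero(x=x)      # canonical name
    idx_b = paddle.nonzero(input=x)  # ``input`` alias
    assert bool((idx_a == idx_b).all())

    # Write the indices into a preallocated tensor via ``out=``.
    out = paddle.zeros([2, 1], dtype='int64')
    paddle.nonzero(x, out=out)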

test/legacy_test/test_nonzero_api.py

Lines changed: 71 additions & 0 deletions
@@ -16,6 +16,7 @@

 import numpy as np
 from op_test import OpTest, convert_float_to_uint16
+from utils import dygraph_guard

 import paddle
 from paddle import base
@@ -228,5 +229,75 @@ def test_check_output(self):
         self.check_output(check_pir=True, check_symbol_infer=True)


+class TestNonzeroCompatibility(unittest.TestCase):
+    def setUp(self):
+        self.places = [paddle.CPUPlace()]
+        if paddle.base.core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+        self.input_data = [[1, 0, 3], [0, 5, 0], [7, 0, 9]]
+        self.expected_indices = np.array(
+            [[0, 0], [0, 2], [1, 1], [2, 0], [2, 2]]
+        )
+
+    def test_nonzero_with_param_aliases(self):
+        with dygraph_guard():
+            for place in self.places:
+                paddle.device.set_device(place)
+                input_tensor = paddle.to_tensor(
+                    self.input_data, dtype='float32'
+                )
+                for param_name in ['x', 'input']:
+                    for as_tuple in [False, True]:
+                        kwargs = {
+                            param_name: input_tensor,
+                            'as_tuple': as_tuple,
+                        }
+                        result = paddle.nonzero(**kwargs)
+                        if as_tuple:
+                            combined = np.stack(
+                                [r.numpy() for r in result], axis=1
+                            )
+                            np.testing.assert_array_equal(
+                                combined, self.expected_indices
+                            )
+                        else:
+                            np.testing.assert_array_equal(
+                                result.numpy(), self.expected_indices
+                            )
+
+    def test_nonzero_with_out(self):
+        def run_nonzero(test_type):
+            x = paddle.to_tensor(self.input_data, dtype='float32')
+            x.stop_gradient = False
+            out_shape = [len(self.expected_indices), 2]
+            out = (
+                paddle.zeros(out_shape, dtype='int64')
+                if test_type in ["with_out", "both"]
+                else None
+            )
+            if test_type == "return":
+                out = paddle.nonzero(x, out=None)
+            elif test_type == "with_out":
+                paddle.nonzero(x, out=out)
+            elif test_type == "both":
+                out = paddle.nonzero(x, out=out)
+            expected = paddle._C_ops.nonzero(x)
+            np.testing.assert_array_equal(out.numpy(), expected.numpy())
+            loss = out.sum().astype('float32')
+            loss.backward()
+            return out, x.grad
+
+        with dygraph_guard():
+            for place in self.places:
+                paddle.device.set_device(place)
+                out1, _ = run_nonzero("return")
+                out2, _ = run_nonzero("with_out")
+                out3, _ = run_nonzero("both")
+                for out in [out2, out3]:
+                    np.testing.assert_allclose(
+                        out1.numpy(), out.numpy(), rtol=1e-10
+                    )
+
+
 if __name__ == "__main__":
     unittest.main()

test/legacy_test/test_randn_op.py

Lines changed: 35 additions & 2 deletions
@@ -16,6 +16,7 @@

 import numpy as np
 from op_test import get_device_place
+from utils import dygraph_guard

 import paddle
 from paddle.static import Program, program_guard
@@ -74,13 +75,45 @@ def test_api(self):
 class TestRandnOpError(unittest.TestCase):
     def test_error(self):
         with program_guard(Program(), Program()):
-            # The argument shape's type of randn_op should be list or tuple.
-            self.assertRaises(TypeError, paddle.randn, 1)

             # The argument dtype of randn_op should be float32 or float64.
             self.assertRaises(TypeError, paddle.randn, [1, 2], 'int32')


+class TestRandnOpCompatibility(unittest.TestCase):
+    def setUp(self):
+        self.places = [paddle.CPUPlace()]
+        if paddle.base.core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+        self.expected_shape = [2, 3]
+        self.dtype = paddle.float32
+
+    def test_gather_with_param_aliases(self):
+        with dygraph_guard():
+            for place in self.places:
+                paddle.device.set_device(place)
+                for param_name in ['shape', 'size']:
+
+                    tensor = paddle.randn(
+                        **{param_name: self.expected_shape}, dtype=self.dtype
+                    )
+                    self.assertEqual(tensor.shape, self.expected_shape)
+                    self.assertEqual(tensor.dtype, self.dtype)
+
+                    shape_tensor = paddle.to_tensor(
+                        self.expected_shape, dtype='int32'
+                    )
+                    tensor = paddle.randn(
+                        **{param_name: shape_tensor}, dtype=self.dtype
+                    )
+                    self.assertEqual(tensor.shape, self.expected_shape)
+                    self.assertEqual(tensor.dtype, self.dtype)
+
+                tensor = paddle.randn(*self.expected_shape, dtype=self.dtype)
+                self.assertEqual(tensor.shape, self.expected_shape)
+                self.assertEqual(tensor.dtype, self.dtype)
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()

test/legacy_test/test_softmax_op.py

Lines changed: 42 additions & 1 deletion
@@ -21,7 +21,7 @@
     get_device_place,
     get_places,
 )
-from utils import static_guard
+from utils import dygraph_guard, static_guard

 import paddle
 import paddle.nn.functional as F
@@ -662,5 +662,46 @@ def test_dygraph(self):
         paddle.enable_static()


+class TestSoftmaxCompatibility(unittest.TestCase):
+    def setUp(self):
+        self.input = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
+        self.axes = [0, 1]
+        self.places = [paddle.CPUPlace()]
+        if paddle.base.core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
+    def test_gather_with_param_aliases(self):
+        with dygraph_guard():
+            for place in self.places:
+                paddle.device.set_device(place)
+                for axis in self.axes:
+                    input_tensor = paddle.to_tensor(self.input, dtype='float32')
+                    for param_x in ['x', 'input']:
+                        for param_axis in ['axis', 'dim']:
+                            kwargs = {param_x: input_tensor, param_axis: axis}
+                            result = paddle.nn.functional.softmax(**kwargs)
+                            expected = np.exp(
+                                input_tensor.numpy()
+                                - np.max(
+                                    input_tensor.numpy(),
+                                    axis=axis,
+                                    keepdims=True,
+                                )
+                            )
+                            expected = expected / np.sum(
+                                expected, axis=axis, keepdims=True
+                            )
+                            np.testing.assert_allclose(
+                                (
+                                    result.numpy()
+                                    if place.is_cpu_place()
+                                    else result.cpu().numpy()
+                                ),
+                                expected,
+                                rtol=1e-5,
+                                err_msg=f"Failed at axis={axis}, param_x={param_x}, param_axis={param_axis}",
+                            )
+
+
 if __name__ == "__main__":
     unittest.main()
