
Commit 544882d

fix
1 parent 4f7efea commit 544882d

3 files changed: +99 −7 lines changed


python/paddle/tensor/creation.py

Lines changed: 7 additions & 4 deletions
@@ -26,7 +26,10 @@
 import paddle
 from paddle import _C_ops
 from paddle.utils import deprecated
-from paddle.utils.decorator_utils import ParamAliasDecorator, view_decorator
+from paddle.utils.decorator_utils import (
+    ParamAliasDecorator,
+    size_args_decorator,
+)
 from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

 from ..base.data_feeder import (
@@ -1278,7 +1281,7 @@ def fill_constant(
     return out


-@view_decorator()
+@size_args_decorator
 def ones(
     shape: ShapeLike,
     dtype: DTypeLike | None = None,
@@ -1405,7 +1408,7 @@ def ones_like(
     )


-@view_decorator()
+@size_args_decorator
 def zeros(
     shape: ShapeLike,
     dtype: DTypeLike | None = None,
@@ -2825,7 +2828,7 @@ def diag(
     return out


-@view_decorator()
+@size_args_decorator
 def empty(
     shape: ShapeLike,
     dtype: DTypeLike | None = None,
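
With this commit applied, the call styles listed in the new decorator's docstring should all be accepted by paddle.ones (and likewise paddle.zeros and paddle.empty). A quick illustrative sketch, assuming a build that includes this change:

    import paddle

    # All four spellings should yield the same 1 x 2 x 3 tensor of ones:
    a = paddle.ones(1, 2, 3)          # bare ints collected into a shape
    b = paddle.ones([1, 2, 3])        # existing list form
    c = paddle.ones(size=[1, 2, 3])   # torch-style 'size' keyword
    d = paddle.ones(shape=[1, 2, 3])  # original 'shape' keyword
    assert a.shape == b.shape == c.shape == d.shape == [1, 2, 3]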

python/paddle/utils/decorator_utils.py

Lines changed: 29 additions & 0 deletions
@@ -246,6 +246,35 @@ def process(
         return args, kwargs


+def size_args_decorator(func: Callable) -> Callable:
+    """
+    A decorator that normalizes the 'size' argument to 'shape'.
+
+    Usage Example:
+
+        paddle.ones(1, dtype=paddle.float32)
+        paddle.ones(1, 2, 3, dtype=paddle.float32)
+        paddle.ones([1, 2, 3], dtype=paddle.float32)
+        paddle.ones(size=[1, 2, 3], dtype=paddle.float32)
+        paddle.ones([1, 2, 3], paddle.float32)
+        paddle.ones(shape=[1, 2, 3], dtype=paddle.float32)
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        # If 'size' is present in kwargs, rename it to 'shape'
+        if 'size' in kwargs:
+            kwargs['shape'] = kwargs.pop('size')
+        # If the positional arguments are all integers, collect them into 'shape'
+        elif len(args) >= 1 and all(isinstance(a, int) for a in args):
+            kwargs['shape'] = list(args)
+            args = ()
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 class VariableArgsDecorator(DecoratorBase):
     def __init__(self, var: str) -> None:
        super().__init__()
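
The normalization logic itself has no paddle dependency, so it can be checked in isolation. Below is a minimal standalone sketch; the imports are added here to make it self-contained (the real module already has them), and make_shape is a hypothetical toy function, not part of the commit:

    import functools
    from typing import Any, Callable

    def size_args_decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if 'size' in kwargs:                 # keyword 'size' -> 'shape'
                kwargs['shape'] = kwargs.pop('size')
            elif len(args) >= 1 and all(isinstance(a, int) for a in args):
                kwargs['shape'] = list(args)     # bare ints -> shape list
                args = ()
            return func(*args, **kwargs)
        return wrapper

    @size_args_decorator
    def make_shape(shape, dtype=None):
        return shape, dtype

    assert make_shape(1, 2, 3) == ([1, 2, 3], None)
    assert make_shape(size=[4, 5]) == ([4, 5], None)
    assert make_shape([6], dtype='float32') == ([6], 'float32')

Note that a call like paddle.ones(1, dtype=paddle.float32) still hits the bare-int branch, since dtype arrives as a keyword and the positional arguments are then all integers.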

test/legacy_test/test_creation.py

Lines changed: 63 additions & 3 deletions
@@ -53,11 +53,31 @@ def test_ones(self):
         self.assertEqual(x.stop_gradient, not requires_grad)
         if isinstance(dtype, paddle.dtype):
             self.assertEqual(x.dtype, dtype)
+
+        def wrapped_ones(
+            shape,
+            dtype=None,
+            name=None,
+            *,
+            out=None,
+            device=None,
+            requires_grad=False,
+        ):
+            return paddle.ones(
+                shape,
+                dtype,
+                name,
+                out=out,
+                device=device,
+                requires_grad=requires_grad,
+            )
+
         st_f = paddle.jit.to_static(
-            paddle.ones, full_graph=True, backend=None
+            wrapped_ones, full_graph=True, backend=None
         )
         x = st_f(
             [2],
+            out=None,
             dtype=dtype,
             requires_grad=requires_grad,
             device=device,
@@ -84,11 +104,31 @@ def test_zeros(self):
         self.assertEqual(x.stop_gradient, not requires_grad)
         if isinstance(dtype, paddle.dtype):
             self.assertEqual(x.dtype, dtype)
+
+        def wrapped_zeros(
+            shape,
+            dtype=None,
+            name=None,
+            *,
+            out=None,
+            device=None,
+            requires_grad=False,
+        ):
+            return paddle.zeros(
+                shape,
+                dtype,
+                name,
+                out=out,
+                device=device,
+                requires_grad=requires_grad,
+            )
+
         st_f = paddle.jit.to_static(
-            paddle.zeros, full_graph=True, backend=None
+            wrapped_zeros, full_graph=True, backend=None
         )
         x = st_f(
             [2],
+            out=None,
             dtype=dtype,
             requires_grad=requires_grad,
             device=device,
@@ -148,11 +188,31 @@ def test_empty(self):
         self.assertEqual(x.stop_gradient, not requires_grad)
         if isinstance(dtype, paddle.dtype):
             self.assertEqual(x.dtype, dtype)
+
+        def wrapped_empty(
+            shape,
+            dtype=None,
+            name=None,
+            *,
+            out=None,
+            device=None,
+            requires_grad=False,
+        ):
+            return paddle.empty(
+                shape,
+                dtype,
+                name,
+                out=out,
+                device=device,
+                requires_grad=requires_grad,
+            )
+
         st_f = paddle.jit.to_static(
-            paddle.empty, full_graph=True, backend=None
+            wrapped_empty, full_graph=True, backend=None
         )
         x = st_f(
             [2],
+            out=None,
             dtype=dtype,
             requires_grad=requires_grad,
             device=device,
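
The tests route each API through an explicit-signature wrapper before paddle.jit.to_static, presumably because the decorated functions now expose only the generic (*args, **kwargs) signature of the wrapper. This reading is an assumption, not stated in the commit; a sketch of the pattern, using the compat keyword-only arguments (out, device, requires_grad) exercised by this test suite:

    import paddle

    def wrapped_ones(shape, dtype=None, name=None, *,
                     out=None, device=None, requires_grad=False):
        # Re-expose an explicit signature in front of the decorated API.
        return paddle.ones(shape, dtype, name, out=out, device=device,
                           requires_grad=requires_grad)

    st_ones = paddle.jit.to_static(wrapped_ones, full_graph=True, backend=None)
    x = st_ones([2], out=None, dtype='float32',
                requires_grad=False, device=None)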
