Skip to content

Commit e7530ff

Browse files
[API Compatibility] Fix range default dtype (#74772)
* fix range default dtype from int64 to float * fix range and its UT * use view_decorator * fix * fix
1 parent e8ca424 commit e7530ff

File tree

4 files changed

+118
-25
lines changed

4 files changed

+118
-25
lines changed

python/paddle/jit/dy2static/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ def get_new_globals(original_fn, generated_fn):
639639
argdefs=callable_func.__defaults__,
640640
closure=get_new_closure(dyfunc, callable_func),
641641
)
642+
new_fn.__kwdefaults__ = callable_func.__kwdefaults__
642643

643644
return new_fn, f.name
644645

python/paddle/tensor/creation.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
from paddle.utils import deprecated
2929
from paddle.utils.decorator_utils import (
3030
ParamAliasDecorator,
31-
SizeArgsDecorator,
3231
param_one_alias,
3332
param_two_alias,
33+
size_args_decorator,
3434
)
3535
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
3636

@@ -1386,7 +1386,7 @@ def fill_constant(
13861386
return out
13871387

13881388

1389-
@SizeArgsDecorator()
1389+
@size_args_decorator
13901390
def ones(
13911391
shape: ShapeLike,
13921392
dtype: DTypeLike | None = None,
@@ -1513,7 +1513,7 @@ def ones_like(
15131513
)
15141514

15151515

1516-
@SizeArgsDecorator()
1516+
@size_args_decorator
15171517
def zeros(
15181518
shape: ShapeLike,
15191519
dtype: DTypeLike | None = None,
@@ -2073,13 +2073,14 @@ def arange(
20732073
reason=(
20742074
"paddle.range is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin."
20752075
"Instead, use paddle.arange, which produces values in [start, end)"
2076-
)
2076+
),
2077+
level=1,
20772078
)
20782079
def range(
20792080
start: float | paddle.Tensor = 0,
20802081
end: float | paddle.Tensor | None = None,
20812082
step: float | paddle.Tensor = 1,
2082-
dtype=None,
2083+
dtype: DTypeLike = None,
20832084
*,
20842085
out: paddle.Tensor | None = None,
20852086
device: PlaceLike | None = None,
@@ -2158,19 +2159,7 @@ def range(
21582159
start = 0
21592160

21602161
if dtype is None:
2161-
for val in [start, end, step]:
2162-
if isinstance(val, (Variable, paddle.pir.Value)):
2163-
if not paddle.is_integer(val):
2164-
dtype = paddle.get_default_dtype()
2165-
break
2166-
else:
2167-
dtype = 'int64'
2168-
else:
2169-
if not isinstance(val, np.integer) and not isinstance(val, int):
2170-
dtype = paddle.get_default_dtype()
2171-
break
2172-
else:
2173-
dtype = 'int64'
2162+
dtype = paddle.get_default_dtype()
21742163

21752164
is_value_input = (
21762165
not isinstance(start, (Variable, paddle.pir.Value))
@@ -2951,7 +2940,7 @@ def diag(
29512940
return out
29522941

29532942

2954-
@SizeArgsDecorator()
2943+
@size_args_decorator
29552944
def empty(
29562945
shape: ShapeLike,
29572946
dtype: DTypeLike | None = None,

python/paddle/utils/decorator_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,35 @@ def process(
273273
return args, kwargs
274274

275275

276+
def size_args_decorator(func: Callable) -> Callable:
277+
"""
278+
A decorator that normalizes the 'size' argument to 'shape'.
279+
280+
Usage Example:
281+
282+
paddle.ones(1, dtype=paddle.float32)
283+
paddle.ones(1, 2, 3, dtype=paddle.float32)
284+
paddle.ones([1, 2, 3], dtype=paddle.float32)
285+
paddle.ones(size=[1, 2, 3], dtype=paddle.float32)
286+
paddle.ones([1, 2, 3], paddle.float32)
287+
paddle.ones(shape=[1, 2, 3], dtype=paddle.float32)
288+
"""
289+
290+
@functools.wraps(func)
291+
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
292+
if 'size' in kwargs:
293+
kwargs['shape'] = kwargs.pop('size')
294+
elif len(args) >= 1 and isinstance(args[0], int):
295+
kwargs['shape'] = list(args)
296+
args = ()
297+
298+
return func(*args, **kwargs)
299+
300+
wrapped_func.__signature__ = inspect.signature(func)
301+
302+
return wrapped_func
303+
304+
276305
class VariableArgsDecorator(DecoratorBase):
277306
def __init__(self, var: str) -> None:
278307
super().__init__()

test/legacy_test/test_creation.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def setUp(self):
3535
self.devices.append(paddle.device.IPUPlace())
3636

3737
self.requires_grads = [True, False]
38-
self.dtypes = ["float32", paddle.float32, "int32", paddle.int32]
38+
self.dtypes = [None, "float32", paddle.float32, "int32", paddle.int32]
3939

4040
def test_ones(self):
4141
for device, requires_grad, dtype in product(
@@ -53,11 +53,31 @@ def test_ones(self):
5353
self.assertEqual(x.stop_gradient, not requires_grad)
5454
if isinstance(dtype, paddle.dtype):
5555
self.assertEqual(x.dtype, dtype)
56+
57+
def wrapped_ones(
58+
shape,
59+
dtype=None,
60+
name=None,
61+
*,
62+
out=None,
63+
device=None,
64+
requires_grad=False,
65+
):
66+
return paddle.ones(
67+
shape,
68+
dtype,
69+
name,
70+
out=out,
71+
device=device,
72+
requires_grad=requires_grad,
73+
)
74+
5675
st_f = paddle.jit.to_static(
57-
paddle.ones, full_graph=True, backend=None
76+
wrapped_ones, full_graph=True, backend=None
5877
)
5978
x = st_f(
6079
[2],
80+
out=None,
6181
dtype=dtype,
6282
requires_grad=requires_grad,
6383
device=device,
@@ -84,11 +104,31 @@ def test_zeros(self):
84104
self.assertEqual(x.stop_gradient, not requires_grad)
85105
if isinstance(dtype, paddle.dtype):
86106
self.assertEqual(x.dtype, dtype)
107+
108+
def wrapped_zeros(
109+
shape,
110+
dtype=None,
111+
name=None,
112+
*,
113+
out=None,
114+
device=None,
115+
requires_grad=False,
116+
):
117+
return paddle.zeros(
118+
shape,
119+
dtype,
120+
name,
121+
out=out,
122+
device=device,
123+
requires_grad=requires_grad,
124+
)
125+
87126
st_f = paddle.jit.to_static(
88-
paddle.zeros, full_graph=True, backend=None
127+
wrapped_zeros, full_graph=True, backend=None
89128
)
90129
x = st_f(
91130
[2],
131+
out=None,
92132
dtype=dtype,
93133
requires_grad=requires_grad,
94134
device=device,
@@ -148,11 +188,31 @@ def test_empty(self):
148188
self.assertEqual(x.stop_gradient, not requires_grad)
149189
if isinstance(dtype, paddle.dtype):
150190
self.assertEqual(x.dtype, dtype)
191+
192+
def wrapped_empty(
193+
shape,
194+
dtype=None,
195+
name=None,
196+
*,
197+
out=None,
198+
device=None,
199+
requires_grad=False,
200+
):
201+
return paddle.empty(
202+
shape,
203+
dtype,
204+
name,
205+
out=out,
206+
device=device,
207+
requires_grad=requires_grad,
208+
)
209+
151210
st_f = paddle.jit.to_static(
152-
paddle.empty, full_graph=True, backend=None
211+
wrapped_empty, full_graph=True, backend=None
153212
)
154213
x = st_f(
155214
[2],
215+
out=None,
156216
dtype=dtype,
157217
requires_grad=requires_grad,
158218
device=device,
@@ -368,6 +428,8 @@ def range_manual(start, end, step, dtype, device, requires_grad):
368428
if end is None:
369429
end = start
370430
start = 0
431+
if dtype is None:
432+
dtype = paddle.get_default_dtype()
371433
size_ = int(np.abs(np.trunc((end - start) / step))) + 1
372434
out = paddle.empty([size_])
373435

@@ -430,14 +492,26 @@ def range_manual(start, end, step, dtype, device, requires_grad):
430492
err_msg=f"[FAILED] wrong result when testing: range({start},{end},{step})",
431493
)
432494

495+
def wrapped_range(
496+
start, end, step, dtype, device, requires_grad
497+
):
498+
return paddle.range(
499+
start,
500+
end,
501+
step,
502+
dtype,
503+
device=device,
504+
requires_grad=requires_grad,
505+
)
506+
433507
st_f = paddle.jit.to_static(
434-
paddle.range, full_graph=True, backend=None
508+
wrapped_range, full_graph=True, backend=None
435509
)
436510
x = st_f(
437511
start,
438512
end,
439513
step,
440-
dtype=dtype,
514+
dtype,
441515
device=device,
442516
requires_grad=requires_grad,
443517
)

0 commit comments

Comments
 (0)