Skip to content

Commit a59ebbc

Browse files
authored
api compatiblity: modify softmax decorator name, add test (#74952)
1 parent 7fa90ed commit a59ebbc

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

python/paddle/utils/decorator_utils.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -427,22 +427,7 @@ def __init__(
427427
def process(
428428
self, args: tuple[Any, ...], kwargs: dict[str, Any]
429429
) -> tuple[tuple[Any, ...], dict[str, Any]]:
430-
found_keys = [key for key in self.illegal_keys if key in kwargs]
431-
432-
if found_keys:
433-
found_keys.sort()
434-
keys_str = ", ".join(f"'{key}'" for key in found_keys)
435-
plural = "s" if len(found_keys) > 1 else ""
436-
437-
raise TypeError(
438-
f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. "
439-
f"\nDid you mean to use {self.correct_name}() instead?"
440-
)
441-
if self.warn_msg is not None:
442-
warnings.warn(
443-
self.warn_msg,
444-
category=Warning,
445-
)
430+
args, kwargs = super().process(args, kwargs)
446431

447432
if self.ignore_param:
448433
name, index, typ = self.ignore_param

test/legacy_test/test_softmax_op.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def test_static_check(self):
790790
with paddle.static.program_guard(paddle.static.Program()):
791791
x = paddle.static.data('X', x_np.shape, 'float32')
792792
out1 = func(input=x, dim=None, _stacklevel=3)
793-
out2 = func(x)
793+
out2 = func(x, None, 3)
794794
exe = paddle.static.Executor(self.place)
795795
res = exe.run(feed={'X': x_np}, fetch_list=[out1, out2])
796796
for rr in res:
@@ -840,7 +840,7 @@ def test_dygraph_check(self):
840840
x = paddle.to_tensor(x_np)
841841
out1 = func(input=x, dim=None, _stacklevel=3)
842842
x = paddle.to_tensor(x_np)
843-
out2 = func(x)
843+
out2 = func(x, None, 3)
844844
for r in [out1, out2]:
845845
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
846846

@@ -940,6 +940,19 @@ def test_dygraph_check(self):
940940

941941
paddle.enable_static()
942942

943+
def test_forbid_keywords(self):
944+
with (
945+
static_guard(),
946+
paddle.static.program_guard(paddle.static.Program()),
947+
):
948+
x = paddle.static.data('X', [2, 3], 'float32')
949+
self.assertRaises(TypeError, compat.softmax, x=x, axis=-1)
950+
self.assertRaises(TypeError, compat.softmax, x=x, dim=-1)
951+
self.assertRaises(TypeError, compat.softmax, input=x, axis=-1)
952+
953+
if core.is_compiled_with_cuda():
954+
compat.softmax(input=x, dim=-1)
955+
943956

944957
if __name__ == "__main__":
945958
unittest.main()

0 commit comments

Comments
 (0)