Commit baabd60

[Fix] Fixed compat.nn.functional import
1 parent d6233c0 commit baabd60

2 files changed: +16 −10 lines

2 files changed

+16
-10
lines changed

python/paddle/tensor/compat_softmax.py

Lines changed: 4 additions & 4 deletions

@@ -30,7 +30,7 @@
 @ForbidKeywordsIgnoreOneParamDecorator(
     illegal_keys={"x", "axis", "name"},
     ignore_param=('_stacklevel', 2, int),
-    func_name="paddle.compat.softmax",
+    func_name="paddle.compat.nn.functional.softmax",
     correct_name="paddle.nn.functional.softmax",
 )
 def softmax(
@@ -41,7 +41,7 @@ def softmax(
     out: Tensor | None = None,
 ) -> Tensor:
     r"""
-    This operator implements the compat.softmax. The calculation process is as follows:
+    This operator implements PyTorch compatible softmax. The calculation process is as follows:
 
     1. The dimension :attr:`dim` of ``input`` will be permuted to the last.
 
@@ -139,8 +139,8 @@ def softmax(
             ... [[1.0, 2.0, 3.0, 4.0],
             ... [5.0, 6.0, 7.0, 8.0],
             ... [6.0, 7.0, 8.0, 9.0]]],dtype='float32')
-            >>> out1 = paddle.compat.softmax(x, -1)
-            >>> out2 = paddle.compat.softmax(x, -1, dtype='float64')
+            >>> out1 = paddle.compat.nn.functional.softmax(x, -1)
+            >>> out2 = paddle.compat.nn.functional.softmax(x, -1, dtype='float64')
             >>> #out1's data type is float32; out2's data type is float64
             >>> #out1 and out2's value is as follows:
             >>> print(out1)
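
After this rename, the PyTorch-style entry point is reached as paddle.compat.nn.functional.softmax rather than paddle.compat.softmax. A minimal usage sketch adapted from the docstring example above, assuming a Paddle build that ships the compat module:

    import paddle

    x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                          [5.0, 6.0, 7.0, 8.0]], dtype='float32')

    # PyTorch-style positional signature: softmax(input, dim, dtype=None)
    out1 = paddle.compat.nn.functional.softmax(x, -1)
    out2 = paddle.compat.nn.functional.softmax(x, -1, dtype='float64')
    print(out1.dtype, out2.dtype)  # out1 keeps float32; out2 is cast to float64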

test/legacy_test/test_softmax_op.py

Lines changed: 12 additions & 6 deletions

@@ -804,7 +804,7 @@ def setUp(self):
     def test_static_check(self):
         with static_guard():
             for x_np, out_ref in zip(self.x_np_list, self.out_ref_list):
-                func = compat.softmax
+                func = compat.nn.functional.softmax
                 with paddle.static.program_guard(paddle.static.Program()):
                     x = paddle.static.data('X', x_np.shape, 'float32')
                     out1 = func(input=x, dim=None, _stacklevel=3)
@@ -854,7 +854,7 @@ def test_static_check(self):
     def test_dygraph_check(self):
         paddle.disable_static(self.place)
         for x_np, out_ref in zip(self.x_np_list, self.out_ref_list):
-            func = compat.softmax
+            func = compat.nn.functional.softmax
             x = paddle.to_tensor(x_np)
             out1 = func(input=x, dim=None, _stacklevel=3)
             x = paddle.to_tensor(x_np)
@@ -964,12 +964,18 @@ def test_forbid_keywords(self):
             paddle.static.program_guard(paddle.static.Program()),
         ):
             x = paddle.static.data('X', [2, 3], 'float32')
-            self.assertRaises(TypeError, compat.softmax, x=x, axis=-1)
-            self.assertRaises(TypeError, compat.softmax, x=x, dim=-1)
-            self.assertRaises(TypeError, compat.softmax, input=x, axis=-1)
+            self.assertRaises(
+                TypeError, compat.nn.functional.softmax, x=x, axis=-1
+            )
+            self.assertRaises(
+                TypeError, compat.nn.functional.softmax, x=x, dim=-1
+            )
+            self.assertRaises(
+                TypeError, compat.nn.functional.softmax, input=x, axis=-1
+            )
 
             if core.is_compiled_with_cuda() or is_custom_device():
-                compat.softmax(input=x, dim=-1)
+                compat.nn.functional.softmax(input=x, dim=-1)
 
 
 if __name__ == "__main__":
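
For context on why the func_name argument matters: a decorator of this shape rejects keywords from the Paddle-native signature and names the compat function in its error message, which is what the assertRaises checks above exercise. A minimal illustrative sketch of that idea — a hypothetical forbid_keywords helper, not Paddle's actual ForbidKeywordsIgnoreOneParamDecorator:

    import functools

    def forbid_keywords(illegal_keys, func_name, correct_name):
        # Hypothetical stand-in: reject keyword arguments that belong to the
        # native Paddle API and point the caller at the right function name.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                bad = illegal_keys & kwargs.keys()
                if bad:
                    raise TypeError(
                        f"{func_name} got unexpected keyword(s) {sorted(bad)}; "
                        f"use {correct_name} for the Paddle-native signature"
                    )
                return fn(*args, **kwargs)
            return wrapper
        return decorator

Under this sketch, the renamed func_name means the TypeError now reports paddle.compat.nn.functional.softmax, matching the module path users actually call.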
