Commit 9ecacf9

[API compatibility] paddle.nn.functional.one_hot (#74925)
* [API compatibility] one_hot
* fix
1 parent 69caf6a commit 9ecacf9

File tree

3 files changed: +76 lines, -46 lines


python/paddle/nn/functional/input.py

Lines changed: 11 additions & 3 deletions
@@ -30,9 +30,10 @@
 __all__ = []
 
 
+@param_one_alias(["x", "input"])
 def one_hot(
     x: Tensor,
-    num_classes: int,
+    num_classes: int = -1,
     name: str | None = None,
 ) -> Tensor:
     """
@@ -72,11 +73,17 @@ def one_hot(
             so it throws an exception.
 
 
+    .. note::
+        Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
+        For example, ``one_hot(input=tensor_x, ...)`` is equivalent to ``one_hot(x=tensor_x, ...)``.
+
+
     Args:
         x(Tensor): Tensor with shape :math:`[N_1, N_2, ..., N_k]` ,
             which contains at least one dimension. The data type is int32 or int64.
+            alias: ``input``.
         num_classes(int): An integer defining the `num_classes` of the one hot dimension. If input `x`
-            is word id, `num_classes` is generally the dictionary size.
+            is word id, `num_classes` is generally the dictionary size. Default value: -1.
         name(str|None, optional): For detailed information, please refer
             to :ref:`api_guide_Name`. Usually name is no need to set and
             None by default.
@@ -103,7 +110,8 @@ def one_hot(
             [1., 0., 0., 0.]])
 
     """
-
+    if not isinstance(num_classes, paddle.pir.Value) and num_classes == -1:
+        num_classes = x.max() + 1
     if in_dynamic_or_pir_mode():
         return _C_ops.one_hot(x, num_classes)
     else:
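Taken together, the change adds a ``@param_one_alias`` decorator so ``input`` works as an alias for ``x``, and gives ``num_classes`` a default of -1 that is resolved to ``x.max() + 1`` at call time. A minimal usage sketch, assuming a Paddle build that includes this commit (the tensor values are illustrative):

import paddle
import paddle.nn.functional as F

labels = paddle.to_tensor([0, 2, 1])   # int32/int64 labels, max value 2

# num_classes defaults to -1, so the depth is inferred as labels.max() + 1 == 3
a = F.one_hot(labels)

# `input` is accepted as an alias for `x`; both calls are equivalent
b = F.one_hot(input=labels, num_classes=3)

assert a.shape == b.shape == [3, 3]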

test/ir/pir/test_special_op_translator.py

Lines changed: 0 additions & 42 deletions
@@ -264,48 +264,6 @@ def test_op(self):
         _ = pir.translate_to_pir(main_program.desc)
 
 
-class TestOneHotOpTranscriber(unittest.TestCase):
-    def test_mutable_attribute(self):
-        with paddle.pir_utils.OldIrGuard():
-            place = core.Place()
-            place.set_place(paddle.CPUPlace())
-            new_scope = paddle.static.Scope()
-            main_program = paddle.static.Program()
-            with (
-                paddle.static.scope_guard(new_scope),
-                paddle.static.program_guard(main_program),
-            ):
-                depth = paddle.assign(np.array([10], dtype=np.int32))
-                label = paddle.static.data(
-                    name="label", shape=[-1, 1], dtype="int64"
-                )
-                one_hot_label = paddle.nn.functional.one_hot(
-                    x=label, num_classes=depth
-                )
-
-            _ = pir.translate_to_pir(main_program.desc)
-
-    def test_normal_attribute(self):
-        with paddle.pir_utils.OldIrGuard():
-            place = core.Place()
-            place.set_place(paddle.CPUPlace())
-            new_scope = paddle.static.Scope()
-            main_program = paddle.static.Program()
-            with (
-                paddle.static.scope_guard(new_scope),
-                paddle.static.program_guard(main_program),
-            ):
-                depth = 10
-                label = paddle.static.data(
-                    name="label", shape=[-1, 1], dtype="int64"
-                )
-                one_hot_label = paddle.nn.functional.one_hot(
-                    x=label, num_classes=depth
-                )
-
-            _ = pir.translate_to_pir(main_program.desc)
-
-
 class TestReduceOpTranscriber(unittest.TestCase):
     def test_reduce_all(self):
         place = core.Place()

test/legacy_test/test_one_hot_v2_op.py

Lines changed: 65 additions & 1 deletion
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, get_places
 
 import paddle
 from paddle import base
@@ -283,6 +283,70 @@ def test_check_output(self):
         self.check_output()
 
 
+class TestOneHotAPI_Compatibility(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(123)
+        paddle.enable_static()
+        self.places = get_places()
+        self.shape = [5]
+        self.dtype = 'int32'
+        self.init_data()
+
+    def init_data(self):
+        self.np_input = np.random.randint(0, 8, self.shape).astype(self.dtype)
+        self.num_classes = self.np_input.max() + 1
+        self.np_out = np.eye(self.num_classes)[self.np_input]
+
+    def test_dygraph_Compatibility(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.np_input)
+        paddle_dygraph_out = []
+        # Position args (args)
+        out1 = paddle.nn.functional.one_hot(x, self.num_classes)
+        paddle_dygraph_out.append(out1)
+        # Key words args (kwargs) for paddle
+        out2 = paddle.nn.functional.one_hot(x=x, num_classes=self.num_classes)
+        paddle_dygraph_out.append(out2)
+        # Key words args for torch
+        out3 = paddle.nn.functional.one_hot(
+            input=x, num_classes=self.num_classes
+        )
+        paddle_dygraph_out.append(out3)
+        # default args
+        out4 = paddle.nn.functional.one_hot(x, -1)
+        paddle_dygraph_out.append(out4)
+        # Check
+        for out in paddle_dygraph_out:
+            np.testing.assert_allclose(self.np_out, out.numpy())
+        paddle.enable_static()
+
+    def test_static_Compatibility(self):
+        main = paddle.static.Program()
+        startup = paddle.static.Program()
+        with base.program_guard(main, startup):
+            x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
+            # Position args (args)
+            out1 = paddle.nn.functional.one_hot(x, self.num_classes)
+            # Key words args (kwargs) for paddle
+            out2 = paddle.nn.functional.one_hot(
+                x=x, num_classes=self.num_classes
+            )
+            # Key words args for torch
+            out3 = paddle.nn.functional.one_hot(
+                input=x, num_classes=self.num_classes
+            )
+            # default args
+            out4 = paddle.nn.functional.one_hot(x, -1)
+        exe = base.Executor(paddle.CPUPlace())
+        fetches = exe.run(
+            main,
+            feed={"x": self.np_input},
+            fetch_list=[out1, out2, out3],
+        )
+        for out in fetches:
+            np.testing.assert_allclose(out, self.np_out)
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
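The test's NumPy reference output is built by row-indexing an identity matrix (``np.eye(num_classes)[labels]``). A quick standalone sketch of that equivalence, with hypothetical label values:

import numpy as np

labels = np.array([0, 2, 1], dtype=np.int32)
num_classes = labels.max() + 1         # same inference rule as the new API default
one_hot = np.eye(num_classes)[labels]  # row i of the identity matrix is the one-hot vector for class i
print(one_hot)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]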
