Skip to content

Commit 84d14f8

Browse files
authored
[API Compatibility] Add paddle.compat.nn.functional.linear and remove paddle.compat.softmax (#76144)
* [API Compatibility] Add paddle.compat.nn.functional.linear; paddle.compat.softmax is removed; adjust the position of some import lines. * [Fix] Fixed compat.nn.functional import.
1 parent f6f5552 commit 84d14f8

File tree

6 files changed

+407
-22
lines changed

6 files changed

+407
-22
lines changed

python/paddle/compat/__init__.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,10 @@
1919
import warnings
2020
from contextlib import contextmanager
2121

22-
from paddle.tensor import softmax
23-
2422
from . import nn # noqa: F401
2523

2624
__all__ = [
2725
'slogdet',
28-
'softmax',
2926
'sort',
3027
'split',
3128
'min',
@@ -43,16 +40,15 @@
4340
from paddle.framework import (
4441
in_dynamic_mode,
4542
)
43+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
44+
45+
from .utils import _check_out_status
4646

4747
if TYPE_CHECKING:
4848
from collections.abc import Sequence
4949

5050
from paddle import Tensor
5151

52-
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
53-
54-
from .utils import _check_out_status
55-
5652

5753
class MedianRetType(NamedTuple):
5854
values: Tensor

python/paddle/compat/nn/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from typing import TYPE_CHECKING
1818

1919
import paddle
20+
from paddle import nn
2021
from paddle.framework import (
2122
in_dynamic_mode,
2223
)
24+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
2325

2426
from . import functional # noqa: F401
2527

@@ -30,9 +32,6 @@
3032
)
3133

3234

33-
from paddle import nn
34-
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
35-
3635
__all__ = [
3736
'Unfold',
3837
]

python/paddle/compat/nn/functional/__init__.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
in_dynamic_mode,
2424
)
2525
from paddle.tensor import softmax
26+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
2627

2728
if TYPE_CHECKING:
2829
from typing_extensions import TypeAlias
@@ -36,9 +37,8 @@
3637
"zeros", "constant", "reflect", "replicate", "circular"
3738
]
3839

39-
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
4040

41-
__all__ = ['pad', 'softmax']
41+
__all__ = ['pad', 'softmax', 'linear']
4242

4343

4444
def _check_valid_pad_len(pad_len, x_dim, is_constant):
@@ -191,3 +191,78 @@ def pad(
191191
if ndim_to_unsqueeze:
192192
return out.squeeze(axis=ndim_to_unsqueeze)
193193
return out
194+
195+
196+
@ForbidKeywordsDecorator(
197+
illegal_keys={"x", "name"},
198+
func_name="paddle.compat.nn.functional.linear",
199+
correct_name="paddle.nn.functional.linear",
200+
)
201+
def linear(input: Tensor, weight: Tensor, bias: Tensor | None = None) -> Tensor:
202+
r"""
203+
204+
Fully-connected linear transformation operator. For each input :math:`x` ,
205+
the equation is:
206+
207+
.. math::
208+
209+
Out = xW^T + b
210+
211+
where :math:`W` is the weight and :math:`b` is the bias.
212+
213+
If the weight is a 2-D tensor of shape :math:`[out\_features, in\_features]` ,
214+
input should be a multi-dimensional tensor of shape
215+
:math:`[*, in\_features]` , where :math:`*` means any number of
216+
additional dimensions. The linear operator multiplies input tensor with
217+
weight and produces an output tensor of shape :math:`[*, out\_features]` .
218+
If :math:`bias` is not None, the bias should be a 1-D tensor of shape
219+
:math:`[out\_features]` and will be added to the output.
220+
221+
This implementation is aligned with PyTorch's linear function which computes
222+
:math:`y = xW^T + b`.
223+
224+
Parameters:
225+
input (Tensor): Input tensor. The data type should be bfloat16, float16, float32 or float64.
226+
The input tensor should be of shape :math:`[*, in\_features]`, where :math:`*` means any number of additional dimensions, including none.
227+
weight (Tensor): Weight tensor. The data type should be float16, float32 or float64.
228+
Shape should be [out_features, in_features].
229+
bias (Tensor, optional): Bias tensor. The data type should be float16, float32 or float64.
230+
If it is set to None, no bias will be added to the output units.
231+
232+
Returns:
233+
Tensor, the shape is :math:`[*, out\_features]` and the
234+
data type is the same as that of :attr:`input` .
235+
236+
Examples:
237+
.. code-block:: python
238+
239+
>>> import paddle
240+
241+
>>> paddle.seed(2025)
242+
243+
>>> x = paddle.arange(6, dtype=paddle.float32).reshape([3, 2])
244+
>>> x
245+
Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
246+
[[0., 1.],
247+
[2., 3.],
248+
[4., 5.]])
249+
>>> weight = paddle.full(shape=[4, 2], fill_value=0.5, dtype="float32", name="weight")
250+
>>> weight
251+
Tensor(shape=[4, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
252+
[[0.50000000, 0.50000000],
253+
[0.50000000, 0.50000000],
254+
[0.50000000, 0.50000000],
255+
[0.50000000, 0.50000000]])
256+
>>> bias = paddle.ones(shape=[4], dtype="float32", name="bias")
257+
>>> y = paddle.compat.nn.functional.linear(x, weight, bias)
258+
>>> print(y)
259+
Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
260+
[[1.50000000, 1.50000000, 1.50000000, 1.50000000],
261+
[3.50000000, 3.50000000, 3.50000000, 3.50000000],
262+
[5.50000000, 5.50000000, 5.50000000, 5.50000000]])
263+
"""
264+
# Use matmul with transpose_y=True instead of _C_ops.linear(input, weight.T, bias), which can introduce more overhead. With CINN, matmul and add can be fused.
265+
out = _C_ops.matmul(input, weight, False, True)
266+
if bias is not None:
267+
out = _C_ops.add(out, bias)
268+
return out

python/paddle/tensor/compat_softmax.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
@ForbidKeywordsIgnoreOneParamDecorator(
3131
illegal_keys={"x", "axis", "name"},
3232
ignore_param=('_stacklevel', 2, int),
33-
func_name="paddle.compat.softmax",
33+
func_name="paddle.compat.nn.functional.softmax",
3434
correct_name="paddle.nn.functional.softmax",
3535
)
3636
def softmax(
@@ -41,7 +41,7 @@ def softmax(
4141
out: Tensor | None = None,
4242
) -> Tensor:
4343
r"""
44-
This operator implements the compat.softmax. The calculation process is as follows:
44+
This operator implements PyTorch compatible softmax. The calculation process is as follows:
4545
4646
1. The dimension :attr:`dim` of ``input`` will be permuted to the last.
4747
@@ -139,8 +139,8 @@ def softmax(
139139
... [[1.0, 2.0, 3.0, 4.0],
140140
... [5.0, 6.0, 7.0, 8.0],
141141
... [6.0, 7.0, 8.0, 9.0]]],dtype='float32')
142-
>>> out1 = paddle.compat.softmax(x, -1)
143-
>>> out2 = paddle.compat.softmax(x, -1, dtype='float64')
142+
>>> out1 = paddle.compat.nn.functional.softmax(x, -1)
143+
>>> out2 = paddle.compat.nn.functional.softmax(x, -1, dtype='float64')
144144
>>> #out1's data type is float32; out2's data type is float64
145145
>>> #out1 and out2's value is as follows:
146146
>>> print(out1)

0 commit comments

Comments
 (0)