Skip to content

Commit 1dafe95

Browse files
committed
[API Compatibility] Add paddle.compat.nn.Linear
1 parent 84d14f8 commit 1dafe95

File tree

3 files changed

+619
-4
lines changed

3 files changed

+619
-4
lines changed

python/paddle/compat/nn/__init__.py

Lines changed: 157 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from math import sqrt
1718
from typing import TYPE_CHECKING
1819

1920
import paddle
@@ -23,18 +24,18 @@
2324
)
2425
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
2526

26-
from . import functional # noqa: F401
27+
from . import functional
2728

2829
if TYPE_CHECKING:
2930
from paddle import Tensor
3031
from paddle._typing import (
32+
DTypeLike,
33+
PlaceLike,
3134
Size2,
3235
)
3336

3437

35-
__all__ = [
36-
'Unfold',
37-
]
38+
__all__ = ['Unfold', 'Linear']
3839

3940

4041
class Unfold(nn.Unfold):
@@ -114,3 +115,155 @@ def to_list_if_necessary(x, size_check=False):
114115
dilations=to_list_if_necessary(self.dilations),
115116
name=self.name,
116117
)
118+
119+
120+
class Linear(nn.Layer):
    r"""

    PyTorch-compatible fully-connected (affine) layer. For each input :math:`X` ,
    the layer computes:

    .. math::

        Out = XW^T + b

    where :math:`W` is the weight and :math:`b` is the bias.

    The layer accepts a single multi-dimensional tensor of
    shape :math:`[*, in\_features]` , where :math:`*` stands for any number of
    leading dimensions. The input is multiplied by the transpose of the weight
    (a 2-D tensor of shape :math:`[out\_features, in\_features]` ), yielding an
    output of shape :math:`[*, out\_features]` . When ``bias`` is truthy, a 1-D
    bias of shape :math:`[out\_features]` is created and added to the output.
    ``reset_parameters`` is invoked at the end of construction to randomly
    initialize ``weight`` and, if present, ``bias``.

    Parameters:
        in_features (int):
            The number of input units.
        out_features (int):
            The number of output units.
        bias (bool): If True, the bias (a 1-D tensor of shape :math:`[out\_features]` ) will be
            created and added to the output. Default: True.
        device (PlaceLike): The device of the parameters created. Default: None,
            representing the default paddle device.
        dtype (DTypeLike): The dtype of the parameters created. Default: None, in which case
            the default dtype of Linear (float32) is used.

    Variables:
        weight (paddle.Tensor): learnable parameters of the module of shape :math:`[out\_features, in\_features]`.
            The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k` is :math:`\frac{1}{in\_features}`.
        bias (paddle.Tensor): learnable parameters of the module of shape :math:`[out\_features]`. If ``bias`` is True,
            the values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k` is :math:`\frac{1}{in\_features}`.

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[*, in\_features]` . Its data types are float16, float32, float64. The default is float32.
        - output: Multi-dimensional tensor with shape :math:`[*, out\_features]` . The data type is the same as the input.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.seed(100)

            >>> # Define the linear layer.
            >>> linear = paddle.compat.nn.Linear(2, 4, bias=True)
            >>> print(linear.weight)
            Parameter containing:
            Tensor(shape=[4, 2], dtype=float32, place=Place(cpu), stop_gradient=False,
                   [[-0.49191639,  0.28120756],
                    [-0.17887023,  0.40572405],
                    [ 0.35139430,  0.45717543],
                    [-0.06135514, -0.21088189]])

            >>> print(linear.bias)
            Parameter containing:
            Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=False,
                   [ 0.49166456, -0.06108528, -0.14973064,  0.31168410])

            >>> x = paddle.arange(6, dtype="float32").reshape([3, 2])
            >>> y = linear(x)
            >>> print(y)
            Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=False,
                   [[ 0.77287209,  0.34463876,  0.30744481,  0.10080221],
                    [ 0.35145447,  0.79834640,  1.92458415, -0.44367185],
                    [-0.06996319,  1.25205410,  3.54172373, -0.98814595]])
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor

    # The keywords listed here belong to the ``paddle.nn.Linear`` signature;
    # the decorator is configured to point callers at that class instead.
    @ForbidKeywordsDecorator(
        illegal_keys={"weight_attr", "bias_attr", "name"},
        func_name="paddle.compat.nn.Linear",
        correct_name="paddle.nn.Linear",
    )
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: PlaceLike | None = None,
        dtype: DTypeLike | None = None,
    ) -> None:
        super().__init__()

        # Fall back to the layer helper's default dtype when none is given.
        if dtype is None:
            dtype = self._helper.get_default_dtype()
        self._dtype = dtype

        self.in_features = in_features
        self.out_features = out_features

        # Weight is stored as [out_features, in_features]; forward() hands it
        # to functional.linear unchanged.
        self.weight = self.create_parameter(
            shape=[out_features, in_features],
            attr=None,
            dtype=self._dtype,
            is_bias=False,
            device=device,
        )
        # The bias parameter exists only when requested; otherwise the
        # attribute is None, which extra_repr() and reset_parameters() check.
        self.bias = (
            self.create_parameter(
                shape=[out_features],
                attr=None,
                dtype=self._dtype,
                is_bias=True,
                device=device,
            )
            if bias
            else None
        )

        # The same parameter initialization as PyTorch
        self.reset_parameters()

    def forward(self, input: Tensor) -> Tensor:
        """
        Apply the linear transformation to ``input`` using this layer's
        ``weight`` and ``bias``.
        """
        # Call the undecorated function to bypass ForbidKeywordsDecorator.
        linear_impl = functional.linear.__wrapped__
        return linear_impl(input=input, weight=self.weight, bias=self.bias)

    def extra_repr(self) -> str:
        """
        Return the extra representation of the module.
        """
        has_bias = self.bias is not None
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"bias={has_bias}"
        )

    def reset_parameters(self) -> None:
        """
        Resets parameters based on their initialization used in ``__init__``.
        """
        # The initializers take no place argument, so remember where the
        # parameters live now and move them back afterwards if needed.
        # KaimingUniform initializer should be more flexible: user should be
        # able to specify place.
        default_place = paddle.base.framework._current_expected_place()
        param_place = self.weight.place

        nn.init.kaiming_uniform_(self.weight, a=sqrt(5))

        # Restoring the place is only done in dynamic mode, and only when the
        # expected place differs from the one recorded above.
        must_restore = default_place != param_place and in_dynamic_mode()
        if must_restore:
            self.weight = self.weight.to(param_place)

        if self.bias is None:
            return

        # nn.init._calculate_fan_in_and_fan_out(self.weight) for 2D array
        # is equivalent to returning (weight.shape[1], weight.shape[0])
        fan_in = self.weight.shape[1]
        limit = 0 if fan_in <= 0 else 1 / sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

        if must_restore:
            self.bias = self.bias.to(param_place)

python/paddle/nn/layer/common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,12 @@ class Linear(Layer):
198198
bias: Tensor
199199
name: str | None
200200

201+
@ForbidKeywordsDecorator(
202+
illegal_keys={"bias", "device", "dtype"},
203+
func_name="paddle.nn.Linear",
204+
correct_name="paddle.compat.nn.Linear",
205+
url_suffix="nn/torch.nn.Linear",
206+
)
201207
def __init__(
202208
self,
203209
in_features: int,

0 commit comments

Comments
 (0)