Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][A-16,A-19] Add type annotations for base Layer and containers #65190

Merged
merged 7 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/paddle/base/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,9 +710,11 @@ def __impl__(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
# introducing compatibility issues, add this decorator
# NOTE(chenweihang): not using `wrap_decorator` here is because `wrap_decorator` will
# move kwargs to args, which doesn't work in this decorate case
def deprecate_stat_dict(func):
def deprecate_stat_dict(
func: Callable[_InputT, _RetT]
) -> Callable[_InputT, _RetT]:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if "stat_dict" in kwargs:
warnings.warn(
"The argument `stat_dict` has deprecated, please change it to `state_dict`.",
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/base/unique_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __call__(self, name):
generator = UniqueNameGenerator()


def generate(key):
def generate(key: str) -> str:
"""
Generate unique name with prefix key. Currently, Paddle distinguishes the
names of the same key by numbering it from zero. For example, when key=fc,
Expand Down
91 changes: 56 additions & 35 deletions python/paddle/nn/layer/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import typing
from collections import OrderedDict
from collections.abc import Iterable, Mapping
from typing import Any, Iterator, Sequence

from typing_extensions import Self

from paddle import Tensor

from ...base.dygraph.base import param_guard
from ...base.framework import Parameter
Expand Down Expand Up @@ -67,30 +75,38 @@ class LayerDict(Layer):

"""

def __init__(self, sublayers=None):
def __init__(
self,
sublayers: (
LayerDict
| typing.Mapping[str, Layer]
| Sequence[tuple[str, Layer]]
| None
) = None,
) -> None:
super().__init__()
if sublayers is not None:
self.update(sublayers)

def __getitem__(self, key):
def __getitem__(self, key: str) -> Layer:
return self._sub_layers[key]

def __setitem__(self, key, sublayer):
def __setitem__(self, key: str, sublayer: Layer) -> Layer:
return self.add_sublayer(key, sublayer)

def __delitem__(self, key):
def __delitem__(self, key: str) -> None:
del self._sub_layers[key]

def __len__(self):
def __len__(self) -> int:
return len(self._sub_layers)

def __iter__(self):
def __iter__(self) -> Iterator[Layer]:
return iter(self._sub_layers)

def __contains__(self, key):
def __contains__(self, key: str) -> bool:
return key in self._sub_layers

def clear(self):
def clear(self) -> None:
"""
Clear all the sublayers in the LayerDict.

Expand Down Expand Up @@ -120,7 +136,7 @@ def clear(self):
"""
self._sub_layers.clear()

def pop(self, key):
def pop(self, key: str) -> Layer:
"""
Remove the key from the LayerDict and return the layer of the key.

Expand Down Expand Up @@ -152,7 +168,7 @@ def pop(self, key):
del self[key]
return v

def keys(self):
def keys(self) -> Iterable[str]:
"""
Return the iterable of the keys in LayerDict.

Expand Down Expand Up @@ -181,7 +197,7 @@ def keys(self):
"""
return self._sub_layers.keys()

def items(self):
def items(self) -> Iterable[tuple[str, Layer]]:
"""
Return the iterable of the key/value pairs in LayerDict.

Expand Down Expand Up @@ -210,7 +226,7 @@ def items(self):
"""
return self._sub_layers.items()

def values(self):
def values(self) -> Iterable[Layer]:
"""
Return the iterable of the values in LayerDict.

Expand Down Expand Up @@ -239,7 +255,12 @@ def values(self):
"""
return self._sub_layers.values()

def update(self, sublayers):
def update(
self,
sublayers: (
LayerDict | typing.Mapping[str, Layer] | Sequence[tuple[str, Layer]]
),
) -> None:
"""
Update the key/values pairs in sublayers to the LayerDict, overwriting the existing keys.

Expand Down Expand Up @@ -353,29 +374,29 @@ class ParameterList(Layer):
[5, 4]
"""

def __init__(self, parameters=None):
def __init__(self, parameters: Iterable[Tensor] | None = None) -> None:
super().__init__()
if parameters is not None:
for idx, param in enumerate(parameters):
assert isinstance(param, Parameter)
self.add_parameter(str(idx), param)

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Tensor:
with param_guard(self._parameters):
return self._parameters[str(idx)]

def __setitem__(self, idx, param):
def __setitem__(self, idx: int, param: Tensor) -> None:
assert isinstance(param, Parameter)
setattr(self, str(idx), param)

def __len__(self):
def __len__(self) -> int:
return len(self._parameters)

def __iter__(self):
def __iter__(self) -> Iterator[Tensor]:
with param_guard(self._parameters):
return iter(self._parameters.values())

def append(self, parameter):
def append(self, parameter: Tensor) -> Self:
"""Appends a given parameter at the end of the list.

Parameters:
Expand Down Expand Up @@ -412,13 +433,13 @@ class LayerList(Layer):
... return x
"""

def __init__(self, sublayers=None):
def __init__(self, sublayers: Iterable[Layer] | None = None) -> None:
super().__init__()
if sublayers is not None:
for idx, layer in enumerate(sublayers):
self.add_sublayer(str(idx), layer)

def _get_abs_idx(self, idx):
def _get_abs_idx(self, idx: int) -> int:
if isinstance(idx, int):
if not (-len(self) <= idx < len(self)):
raise IndexError(
Expand All @@ -428,18 +449,18 @@ def _get_abs_idx(self, idx):
idx += len(self)
return idx

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Layer:
if isinstance(idx, slice):
return self.__class__(list(self._sub_layers.values())[idx])
else:
idx = self._get_abs_idx(idx)
return self._sub_layers[str(idx)]

def __setitem__(self, idx, sublayer):
def __setitem__(self, idx: int, sublayer: Layer) -> None:
idx = self._get_abs_idx(idx)
return setattr(self, str(idx), sublayer)

def __delitem__(self, idx):
def __delitem__(self, idx: int) -> None:
if isinstance(idx, slice):
for k in range(len(self._sub_layers))[idx]:
delattr(self, str(k))
Expand All @@ -451,13 +472,13 @@ def __delitem__(self, idx):
list(zip(str_indices, self._sub_layers.values()))
)

def __len__(self):
def __len__(self) -> int:
return len(self._sub_layers)

def __iter__(self):
def __iter__(self) -> Iterator[Layer]:
return iter(self._sub_layers.values())

def append(self, sublayer):
def append(self, sublayer: Layer) -> Self:
"""
Appends a sublayer to the end of the list.

Expand All @@ -478,7 +499,7 @@ def append(self, sublayer):
self.add_sublayer(str(len(self)), sublayer)
return self

def insert(self, index, sublayer):
def insert(self, index: int, sublayer: Layer) -> None:
"""
Insert a sublayer before a given index in the list.

Expand Down Expand Up @@ -510,7 +531,7 @@ def insert(self, index, sublayer):
self._sub_layers[str(i)] = self._sub_layers[str(i - 1)]
self._sub_layers[str(index)] = sublayer

def extend(self, sublayers):
def extend(self, sublayers: Iterable[Layer]) -> Self:
"""
Appends sublayers to the end of the list.

Expand Down Expand Up @@ -575,7 +596,7 @@ class Sequential(Layer):

"""

def __init__(self, *layers):
def __init__(self, *layers: Layer | tuple[str, Layer] | list[Any]) -> None:
super().__init__()
if len(layers) > 0 and isinstance(layers[0], (list, tuple)):
for name, layer in layers:
Expand All @@ -584,7 +605,7 @@ def __init__(self, *layers):
for idx, layer in enumerate(layers):
self.add_sublayer(str(idx), layer)

def __getitem__(self, name):
def __getitem__(self, name: str) -> Layer:
if isinstance(name, slice):
return self.__class__(*(list(self._sub_layers.values())[name]))
elif isinstance(name, str):
Expand All @@ -598,19 +619,19 @@ def __getitem__(self, name):
raise IndexError(f'index {name} is out of range')
return list(self._sub_layers.values())[name]

def __setitem__(self, name, layer):
def __setitem__(self, name: str, layer: Layer) -> None:
assert isinstance(layer, Layer)
setattr(self, str(name), layer)

def __delitem__(self, name):
def __delitem__(self, name: str) -> None:
name = str(name)
assert name in self._sub_layers
del self._sub_layers[name]

def __len__(self):
def __len__(self) -> int:
return len(self._sub_layers)

def forward(self, input):
def forward(self, input: Any) -> Any:
for layer in self._sub_layers.values():
input = layer(input)
return input
Loading