Skip to content

Commit

Permalink
[Typing][A-16,A-19] Add type annotations for base Layer and containers (
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Jun 17, 2024
1 parent b29ab37 commit 4ca9b7b
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 138 deletions.
6 changes: 4 additions & 2 deletions python/paddle/base/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,9 +710,11 @@ def __impl__(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
# introducing compatibility issues, add this decorator
# NOTE(chenweihang): not using `wrap_decorator` here is because `wrap_decorator` will
# move kwargs to args, which doesn't work in this decorate case
def deprecate_stat_dict(func):
def deprecate_stat_dict(
func: Callable[_InputT, _RetT]
) -> Callable[_InputT, _RetT]:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
if "stat_dict" in kwargs:
warnings.warn(
"The argument `stat_dict` has deprecated, please change it to `state_dict`.",
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/base/unique_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __call__(self, name):
generator = UniqueNameGenerator()


def generate(key):
def generate(key: str) -> str:
"""
Generate unique name with prefix key. Currently, Paddle distinguishes the
names of the same key by numbering it from zero. For example, when key=fc,
Expand Down
91 changes: 56 additions & 35 deletions python/paddle/nn/layer/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import typing
from collections import OrderedDict
from collections.abc import Iterable, Mapping
from typing import Any, Iterator, Sequence

from typing_extensions import Self

from paddle import Tensor

from ...base.dygraph.base import param_guard
from ...base.framework import Parameter
Expand Down Expand Up @@ -67,30 +75,38 @@ class LayerDict(Layer):
"""

def __init__(self, sublayers=None):
def __init__(
self,
sublayers: (
LayerDict
| typing.Mapping[str, Layer]
| Sequence[tuple[str, Layer]]
| None
) = None,
) -> None:
super().__init__()
if sublayers is not None:
self.update(sublayers)

def __getitem__(self, key):
def __getitem__(self, key: str) -> Layer:
return self._sub_layers[key]

def __setitem__(self, key, sublayer):
def __setitem__(self, key: str, sublayer: Layer) -> Layer:
return self.add_sublayer(key, sublayer)

def __delitem__(self, key):
def __delitem__(self, key: str) -> None:
del self._sub_layers[key]

def __len__(self):
def __len__(self) -> int:
return len(self._sub_layers)

def __iter__(self):
def __iter__(self) -> Iterator[Layer]:
return iter(self._sub_layers)

def __contains__(self, key):
def __contains__(self, key: str) -> bool:
return key in self._sub_layers

def clear(self):
def clear(self) -> None:
"""
Clear all the sublayers in the LayerDict.
Expand Down Expand Up @@ -120,7 +136,7 @@ def clear(self):
"""
self._sub_layers.clear()

def pop(self, key):
def pop(self, key: str) -> Layer:
"""
Remove the key from the LayerDict and return the layer of the key.
Expand Down Expand Up @@ -152,7 +168,7 @@ def pop(self, key):
del self[key]
return v

def keys(self):
def keys(self) -> Iterable[str]:
"""
Return the iterable of the keys in LayerDict.
Expand Down Expand Up @@ -181,7 +197,7 @@ def keys(self):
"""
return self._sub_layers.keys()

def items(self):
def items(self) -> Iterable[tuple[str, Layer]]:
"""
Return the iterable of the key/value pairs in LayerDict.
Expand Down Expand Up @@ -210,7 +226,7 @@ def items(self):
"""
return self._sub_layers.items()

def values(self):
def values(self) -> Iterable[Layer]:
"""
Return the iterable of the values in LayerDict.
Expand Down Expand Up @@ -239,7 +255,12 @@ def values(self):
"""
return self._sub_layers.values()

def update(self, sublayers):
def update(
self,
sublayers: (
LayerDict | typing.Mapping[str, Layer] | Sequence[tuple[str, Layer]]
),
) -> None:
"""
Update the key/values pairs in sublayers to the LayerDict, overwriting the existing keys.
Expand Down Expand Up @@ -353,29 +374,29 @@ class ParameterList(Layer):
[5, 4]
"""

def __init__(self, parameters=None):
def __init__(self, parameters: Iterable[Tensor] | None = None) -> None:
super().__init__()
if parameters is not None:
for idx, param in enumerate(parameters):
assert isinstance(param, Parameter)
self.add_parameter(str(idx), param)

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Tensor:
with param_guard(self._parameters):
return self._parameters[str(idx)]

def __setitem__(self, idx, param):
def __setitem__(self, idx: int, param: Tensor) -> None:
assert isinstance(param, Parameter)
setattr(self, str(idx), param)

def __len__(self):
def __len__(self) -> int:
return len(self._parameters)

def __iter__(self):
def __iter__(self) -> Iterator[Tensor]:
with param_guard(self._parameters):
return iter(self._parameters.values())

def append(self, parameter):
def append(self, parameter: Tensor) -> Self:
"""Appends a given parameter at the end of the list.
Parameters:
Expand Down Expand Up @@ -412,13 +433,13 @@ class LayerList(Layer):
... return x
"""

def __init__(self, sublayers=None):
def __init__(self, sublayers: Iterable[Layer] | None = None) -> None:
super().__init__()
if sublayers is not None:
for idx, layer in enumerate(sublayers):
self.add_sublayer(str(idx), layer)

def _get_abs_idx(self, idx):
def _get_abs_idx(self, idx: int) -> int:
if isinstance(idx, int):
if not (-len(self) <= idx < len(self)):
raise IndexError(
Expand All @@ -428,18 +449,18 @@ def _get_abs_idx(self, idx):
idx += len(self)
return idx

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Layer:
if isinstance(idx, slice):
return self.__class__(list(self._sub_layers.values())[idx])
else:
idx = self._get_abs_idx(idx)
return self._sub_layers[str(idx)]

def __setitem__(self, idx, sublayer):
def __setitem__(self, idx: int, sublayer: Layer) -> None:
idx = self._get_abs_idx(idx)
return setattr(self, str(idx), sublayer)

def __delitem__(self, idx):
def __delitem__(self, idx: int) -> None:
if isinstance(idx, slice):
for k in range(len(self._sub_layers))[idx]:
delattr(self, str(k))
Expand All @@ -451,13 +472,13 @@ def __delitem__(self, idx):
list(zip(str_indices, self._sub_layers.values()))
)

def __len__(self):
def __len__(self) -> int:
return len(self._sub_layers)

def __iter__(self):
def __iter__(self) -> Iterator[Layer]:
return iter(self._sub_layers.values())

def append(self, sublayer):
def append(self, sublayer: Layer) -> Self:
"""
Appends a sublayer to the end of the list.
Expand All @@ -478,7 +499,7 @@ def append(self, sublayer):
self.add_sublayer(str(len(self)), sublayer)
return self

def insert(self, index, sublayer):
def insert(self, index: int, sublayer: Layer) -> None:
"""
Insert a sublayer before a given index in the list.
Expand Down Expand Up @@ -510,7 +531,7 @@ def insert(self, index, sublayer):
self._sub_layers[str(i)] = self._sub_layers[str(i - 1)]
self._sub_layers[str(index)] = sublayer

def extend(self, sublayers):
def extend(self, sublayers: Iterable[Layer]) -> Self:
"""
Appends sublayers to the end of the list.
Expand Down Expand Up @@ -575,7 +596,7 @@ class Sequential(Layer):
"""

def __init__(self, *layers):
def __init__(self, *layers: Layer | tuple[str, Layer] | list[Any]) -> None:
super().__init__()
if len(layers) > 0 and isinstance(layers[0], (list, tuple)):
for name, layer in layers:
Expand All @@ -584,7 +605,7 @@ def __init__(self, *layers):
for idx, layer in enumerate(layers):
self.add_sublayer(str(idx), layer)

def __getitem__(self, name):
def __getitem__(self, name: str) -> Layer:
if isinstance(name, slice):
return self.__class__(*(list(self._sub_layers.values())[name]))
elif isinstance(name, str):
Expand All @@ -598,19 +619,19 @@ def __getitem__(self, name):
raise IndexError(f'index {name} is out of range')
return list(self._sub_layers.values())[name]

def __setitem__(self, name, layer):
def __setitem__(self, name: str, layer: Layer) -> None:
assert isinstance(layer, Layer)
setattr(self, str(name), layer)

def __delitem__(self, name):
def __delitem__(self, name: str) -> None:
name = str(name)
assert name in self._sub_layers
del self._sub_layers[name]

def __len__(self):
def __len__(self) -> int:
return len(self._sub_layers)

def forward(self, input):
def forward(self, input: Any) -> Any:
for layer in self._sub_layers.values():
input = layer(input)
return input
Loading

0 comments on commit 4ca9b7b

Please sign in to comment.