Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][B-49] Add type annotations for python/paddle/base/layer_helper.py #66639

Merged
merged 3 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions python/paddle/base/layer_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any, Generator

import paddle
from paddle import _C_ops
Expand All @@ -28,9 +30,13 @@
from .layer_helper_base import LayerHelperBase
from .param_attr import ParamAttr

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.framework import Operator


class LayerHelper(LayerHelperBase):
def __init__(self, layer_type, **kwargs):
def __init__(self, layer_type: str, **kwargs: Any) -> None:
self.kwargs = kwargs
name = self.kwargs.get('name', None)
# TODO(panyx0718, minqiyang): dygraph mode
Expand All @@ -46,10 +52,10 @@ def __init__(self, layer_type, **kwargs):

super().__init__(self.kwargs['name'], layer_type=layer_type)

def append_op(self, *args, **kwargs):
def append_op(self, *args: Any, **kwargs: Any) -> Operator:
return self.main_program.current_block().append_op(*args, **kwargs)

def multiple_input(self, input_param_name='input'):
def multiple_input(self, input_param_name: str = 'input') -> list[Tensor]:
inputs = self.kwargs.get(input_param_name, [])
ret = []
if isinstance(inputs, (list, tuple)):
Expand All @@ -59,22 +65,22 @@ def multiple_input(self, input_param_name='input'):
ret.append(self.to_variable(inputs))
return ret

def input(self, input_param_name='input'):
def input(self, input_param_name: str = 'input') -> Tensor:
inputs = self.multiple_input(input_param_name)
if len(inputs) != 1:
raise f"{self.layer_type} layer only takes one input"
return inputs[0]

@property
def param_attr(self):
def param_attr(self) -> ParamAttr:
return ParamAttr._to_attr(self.kwargs.get('param_attr', None))

@property
def bias_attr(self):
def bias_attr(self) -> ParamAttr:
return ParamAttr._to_attr(self.kwargs.get('bias_attr', None))

# TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of param_attr
def multiple_param_attr(self, length):
def multiple_param_attr(self, length: int) -> list[ParamAttr]:
param_attr = self.param_attr
if isinstance(param_attr, ParamAttr):
param_attr = [param_attr]
Expand All @@ -88,12 +94,16 @@ def multiple_param_attr(self, length):
param_attr = tmp
return param_attr

def iter_inputs_and_params(self, input_param_name='input'):
def iter_inputs_and_params(
self, input_param_name: str = 'input'
) -> Generator[tuple[Tensor, ParamAttr]]:
inputs = self.multiple_input(input_param_name)
param_attrs = self.multiple_param_attr(len(inputs))
yield from zip(inputs, param_attrs)

def input_dtype(self, input_param_name='input'):
def input_dtype(
self, input_param_name: str = 'input'
) -> None | paddle.dtype:
inputs = self.multiple_input(input_param_name)
dtype = None
for each in inputs:
Expand All @@ -105,14 +115,16 @@ def input_dtype(self, input_param_name='input'):
)
return dtype

def get_parameter(self, name):
def get_parameter(self, name: str) -> Tensor:
param = self.main_program.global_block().var(name)
if not isinstance(param, Parameter):
raise ValueError(f"no Parameter name {name} found")
return param

# TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of bias_attr
def append_bias_op(self, input_var, dim_start=1, dim_end=None):
def append_bias_op(
self, input_var: Tensor, dim_start: int = 1, dim_end: int | None = None
) -> Tensor:
"""
Append bias operator and return its output. If the user does not set
bias_attr, append_bias_op will return input_var
Expand Down Expand Up @@ -146,7 +158,7 @@ def append_bias_op(self, input_var, dim_start=1, dim_end=None):
return tmp

# TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of act
def append_activation(self, input_var):
def append_activation(self, input_var: Tensor) -> Tensor:
act = self.kwargs.get('act', None)
if act is None:
return input_var
Expand Down Expand Up @@ -197,7 +209,7 @@ def _get_default_initializer(self, dtype):
return paddle.nn.initializer.Constant()

# TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of kwargs
def is_instance(self, param_name, cls):
def is_instance(self, param_name: str, cls: Any) -> None:
param = self.kwargs.get(param_name, None)
if not isinstance(param, cls):
raise TypeError(
Expand Down
21 changes: 21 additions & 0 deletions python/paddle/base/param_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, overload

import paddle
from paddle.base.data_feeder import check_type
from paddle.regularizer import WeightDecayRegularizer

if TYPE_CHECKING:
from paddle._typing import ParamAttrLike

__all__ = []


Expand Down Expand Up @@ -153,6 +159,21 @@ def _set_default_bias_initializer(self):
"""
self._set_default_initializer(paddle.nn.initializer.Constant(0.0))

@overload
@staticmethod
def _to_attr(arg: None) -> ParamAttr:
...

@overload
@staticmethod
def _to_attr(arg: ParamAttrLike) -> ParamAttr:
...

@overload
@staticmethod
def _to_attr(arg: Sequence[ParamAttrLike]) -> list[ParamAttr]:
...

@staticmethod
def _to_attr(arg):
"""
Expand Down