Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][B-49] Add type annotations for python/paddle/base/layer_helper.py #66639

Merged
merged 3 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions python/paddle/base/layer_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any, Generator

import paddle
from paddle import _C_ops
Expand All @@ -28,9 +30,14 @@
from .layer_helper_base import LayerHelperBase
from .param_attr import ParamAttr

if TYPE_CHECKING:
from paddle import Tensor
from paddle._typing import DTypeLike, ParamAttrLike
from paddle.base.framework import Operator


class LayerHelper(LayerHelperBase):
def __init__(self, layer_type, **kwargs):
def __init__(self, layer_type: str, **kwargs: Any) -> None:
self.kwargs = kwargs
name = self.kwargs.get('name', None)
# TODO(panyx0718, minqiyang): dygraph mode
Expand All @@ -46,10 +53,10 @@ def __init__(self, layer_type, **kwargs):

super().__init__(self.kwargs['name'], layer_type=layer_type)

def append_op(self, *args, **kwargs):
def append_op(self, *args: Any, **kwargs: Any) -> Operator:
return self.main_program.current_block().append_op(*args, **kwargs)

def multiple_input(self, input_param_name='input'):
def multiple_input(self, input_param_name: str = 'input') -> list[Tensor]:
inputs = self.kwargs.get(input_param_name, [])
ret = []
if isinstance(inputs, (list, tuple)):
Expand All @@ -59,22 +66,22 @@ def multiple_input(self, input_param_name='input'):
ret.append(self.to_variable(inputs))
return ret

def input(self, input_param_name='input'):
def input(self, input_param_name: str = 'input') -> paddle.Tensor:
ooooo-create marked this conversation as resolved.
Show resolved Hide resolved
inputs = self.multiple_input(input_param_name)
if len(inputs) != 1:
raise f"{self.layer_type} layer only takes one input"
return inputs[0]

@property
def param_attr(self):
def param_attr(self) -> ParamAttrLike | list[ParamAttrLike]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里用的 ParamAttrLike 是不是都应该是 ParamAttr

Copy link
Contributor Author

@ooooo-create ooooo-create Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,已修改~

return ParamAttr._to_attr(self.kwargs.get('param_attr', None))

@property
def bias_attr(self):
def bias_attr(self) -> ParamAttrLike | list[ParamAttrLike]:
return ParamAttr._to_attr(self.kwargs.get('bias_attr', None))

# TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of param_attr
def multiple_param_attr(self, length):
def multiple_param_attr(self, length: int) -> list[ParamAttrLike]:
param_attr = self.param_attr
if isinstance(param_attr, ParamAttr):
param_attr = [param_attr]
Expand All @@ -88,12 +95,14 @@ def multiple_param_attr(self, length):
param_attr = tmp
return param_attr

def iter_inputs_and_params(self, input_param_name='input'):
def iter_inputs_and_params(
self, input_param_name: str = 'input'
) -> Generator[tuple[Tensor, ParamAttrLike]]:
inputs = self.multiple_input(input_param_name)
param_attrs = self.multiple_param_attr(len(inputs))
yield from zip(inputs, param_attrs)

def input_dtype(self, input_param_name='input'):
def input_dtype(self, input_param_name: str = 'input') -> None | DTypeLike:
ooooo-create marked this conversation as resolved.
Show resolved Hide resolved
inputs = self.multiple_input(input_param_name)
dtype = None
for each in inputs:
Expand All @@ -105,14 +114,16 @@ def input_dtype(self, input_param_name='input'):
)
return dtype

def get_parameter(self, name):
def get_parameter(self, name: str) -> Parameter:
ooooo-create marked this conversation as resolved.
Show resolved Hide resolved
param = self.main_program.global_block().var(name)
if not isinstance(param, Parameter):
raise ValueError(f"no Parameter name {name} found")
return param

# TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of bias_attr
def append_bias_op(self, input_var, dim_start=1, dim_end=None):
def append_bias_op(
self, input_var: Tensor, dim_start: int = 1, dim_end: int | None = None
) -> Tensor:
"""
Append bias operator and return its output. If the user does not set
bias_attr, append_bias_op will return input_var
Expand Down Expand Up @@ -146,7 +157,7 @@ def append_bias_op(self, input_var, dim_start=1, dim_end=None):
return tmp

# TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of act
def append_activation(self, input_var):
def append_activation(self, input_var: Tensor) -> Tensor:
act = self.kwargs.get('act', None)
if act is None:
return input_var
Expand Down Expand Up @@ -197,7 +208,7 @@ def _get_default_initializer(self, dtype):
return paddle.nn.initializer.Constant()

# TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of kwargs
def is_instance(self, param_name, cls):
def is_instance(self, param_name: str, cls: Any) -> None:
param = self.kwargs.get(param_name, None)
if not isinstance(param, cls):
raise TypeError(
Expand Down
10 changes: 9 additions & 1 deletion python/paddle/base/param_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

import paddle
from paddle.base.data_feeder import check_type
from paddle.regularizer import WeightDecayRegularizer

if TYPE_CHECKING:
from paddle._typing import ParamAttrLike

__all__ = []


Expand Down Expand Up @@ -154,7 +160,9 @@ def _set_default_bias_initializer(self):
self._set_default_initializer(paddle.nn.initializer.Constant(0.0))

@staticmethod
def _to_attr(arg):
def _to_attr(
arg: ParamAttrLike | Sequence[ParamAttrLike],
) -> ParamAttrLike | list[ParamAttrLike]:
ooooo-create marked this conversation as resolved.
Show resolved Hide resolved
"""
Create ParamAttr[s].

Expand Down