-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Typing][A-27,A-28,A-32,A-33,A-35,A-36] Add type annotations for paddle/nn/initializer/*
#65206
[Typing][A-27,A-28,A-32,A-33,A-35,A-36] Add type annotations for paddle/nn/initializer/*
#65206
Conversation
…addle/nn/initializer/*`
你的PR提交成功,感谢你对开源项目的贡献! |
@@ -23,6 +28,9 @@ | |||
) | |||
from .initializer import Initializer | |||
|
|||
if TYPE_CHECKING: | |||
import numpy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import numpy | |
import numpy.typing as npt |
@@ -38,19 +46,21 @@ class NumpyArrayInitializer(Initializer): | |||
|
|||
""" | |||
|
|||
def __init__(self, value): | |||
def __init__(self, value: numpy.ndarray): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def __init__(self, value: numpy.ndarray): | |
def __init__(self, value: npt.NDArray[Any]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__init__
标注一下返回值吧,都是 None
def __init__(self, value, name=None): | ||
def __init__( | ||
self, | ||
value: numpy.ndarray | list[int] | tuple[int] | paddle.Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tuple[int]
表示 (1, )
这种只有一个元素的,如果是不定长元素,应该用 tuple[int, ...]
,在支持 list[T] | tuple[T, ...]
的情况下,更常见的是直接写 Sequence[T]
fan_in: float | None = None, | ||
seed: int = 0, | ||
negative_slope: float = 0, | ||
nonlinearity: str = 'relu', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是否直接用 Literal?见
Paddle/python/paddle/nn/initializer/initializer.py
Lines 167 to 180 in 8261cde
recommended_gain = { | |
'sigmoid': 1, | |
'linear': 1, | |
'conv1d': 1, | |
'conv2d': 1, | |
'conv3d': 1, | |
'conv1d_transpose': 1, | |
'conv2d_transpose': 1, | |
'conv3d_transpose': 1, | |
'tanh': 5.0 / 3, | |
'relu': math.sqrt(2.0), | |
'leaky_relu': math.sqrt(2.0 / (1 + param**2)), | |
'selu': 3.0 / 4, | |
} |
使用 TypeAlias 写在 initializer.py
里这边 import 过来就好
def forward(self, var, block=None): | ||
def forward( | ||
self, var: paddle.Tensor, block: paddle.pir.Block | None = None | ||
) -> paddle.Tensor | None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个返回值,基类以及标注过的几个文件也都统一一下吧
@@ -28,6 +29,22 @@ | |||
) | |||
from .lazy_init import lazy_init_helper | |||
|
|||
if TYPE_CHECKING: | |||
NonLinearityT: TypeAlias = Literal[ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NonLinearityT: TypeAlias = Literal[ | |
_NonLinearity: TypeAlias = Literal[ |
这样吧,T
还是统一表示泛型参数吧
@@ -16,6 +16,7 @@ | |||
|
|||
import functools | |||
import math | |||
from typing import TYPE_CHECKING, Literal, TypeAlias |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TypeAlias
要从 typing_extensions
import
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle/nn/initializer/*
paddle/nn/initializer/*
…dle/nn/initializer/*` (PaddlePaddle#65206)
PR Category
User Experience
PR Types
Improvements
Description
类型标注:
Related links
@SigureMo @megemini