-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Typing][A-27,A-28,A-32,A-33,A-35,A-36] Add type annotations for paddle/nn/initializer/*
#65206
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -11,6 +11,11 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
|
||||||
from __future__ import annotations | ||||||
|
||||||
from typing import TYPE_CHECKING | ||||||
|
||||||
import paddle | ||||||
from paddle import _C_ops | ||||||
|
||||||
|
@@ -23,6 +28,9 @@ | |||||
) | ||||||
from .initializer import Initializer | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
import numpy | ||||||
|
||||||
__all__ = [] | ||||||
|
||||||
|
||||||
|
@@ -38,19 +46,21 @@ class NumpyArrayInitializer(Initializer): | |||||
|
||||||
""" | ||||||
|
||||||
def __init__(self, value): | ||||||
def __init__(self, value: numpy.ndarray): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
import numpy | ||||||
|
||||||
assert isinstance(value, numpy.ndarray) | ||||||
super().__init__() | ||||||
self._value = value | ||||||
|
||||||
def forward(self, var, block=None): | ||||||
def forward( | ||||||
self, var: paddle.Tensor, block: paddle.pir.Block | None = None | ||||||
) -> paddle.Tensor | None: | ||||||
"""Initialize the input tensor with Numpy array. | ||||||
|
||||||
Args: | ||||||
var(Tensor): Tensor that needs to be initialized. | ||||||
block(Block, optional): The block in which initialization ops | ||||||
block(Block|None, optional): The block in which initialization ops | ||||||
should be added. Used in static graph only, default None. | ||||||
|
||||||
Returns: | ||||||
|
@@ -172,7 +182,7 @@ class Assign(NumpyArrayInitializer): | |||||
|
||||||
Args: | ||||||
value (Tensor|numpy.ndarray|list|tuple): numpy array, list, tuple, or tensor to initialize the parameter. | ||||||
name(str, optional): Normally there is no need for user to set this | ||||||
name(str|None, optional): Normally there is no need for user to set this | ||||||
property. For more information, please refer to :ref:`api_guide_Name`. Default is None. | ||||||
|
||||||
Returns: | ||||||
|
@@ -239,7 +249,11 @@ class Assign(NumpyArrayInitializer): | |||||
[6.] | ||||||
""" | ||||||
|
||||||
def __init__(self, value, name=None): | ||||||
def __init__( | ||||||
self, | ||||||
value: numpy.ndarray | list[int] | tuple[int] | paddle.Tensor, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
name: str | None = None, | ||||||
): | ||||||
import numpy | ||||||
|
||||||
check_type( | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -12,6 +12,8 @@ | |||||||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
from __future__ import annotations | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# TODO: define the initializers of Kaiming functions in neural network | ||||||||||||||||||||||||||||||
import math | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -64,11 +66,11 @@ class MSRAInitializer(Initializer): | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||
uniform=True, | ||||||||||||||||||||||||||||||
fan_in=None, | ||||||||||||||||||||||||||||||
seed=0, | ||||||||||||||||||||||||||||||
negative_slope=0, | ||||||||||||||||||||||||||||||
nonlinearity='relu', | ||||||||||||||||||||||||||||||
uniform: bool = True, | ||||||||||||||||||||||||||||||
fan_in: float | None = None, | ||||||||||||||||||||||||||||||
seed: int = 0, | ||||||||||||||||||||||||||||||
negative_slope: float = 0, | ||||||||||||||||||||||||||||||
nonlinearity: str = 'relu', | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是否直接用 Literal?见 Paddle/python/paddle/nn/initializer/initializer.py Lines 167 to 180 in 8261cde
使用 TypeAlias 写在 |
||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
"""Constructor for MSRAInitializer""" | ||||||||||||||||||||||||||||||
assert uniform is not None | ||||||||||||||||||||||||||||||
|
@@ -80,12 +82,14 @@ def __init__( | |||||||||||||||||||||||||||||
self._negative_slope = negative_slope | ||||||||||||||||||||||||||||||
self._nonlinearity = nonlinearity | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def forward(self, var, block=None): | ||||||||||||||||||||||||||||||
def forward( | ||||||||||||||||||||||||||||||
self, var: paddle.Tensor, block: paddle.pir.Block | None = None | ||||||||||||||||||||||||||||||
) -> paddle.Tensor | None: | ||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个返回值,基类以及标注过的几个文件也都统一一下吧 |
||||||||||||||||||||||||||||||
"""Initialize the input tensor with MSRA initialization. | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||
var(Tensor): Tensor that needs to be initialized. | ||||||||||||||||||||||||||||||
block(Block, optional): The block in which initialization ops | ||||||||||||||||||||||||||||||
block(Block|None, optional): The block in which initialization ops | ||||||||||||||||||||||||||||||
should be added. Used in static graph only, default None. | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||
|
@@ -271,7 +275,12 @@ class KaimingNormal(MSRAInitializer): | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def __init__(self, fan_in=None, negative_slope=0.0, nonlinearity='relu'): | ||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||
fan_in: float | None = None, | ||||||||||||||||||||||||||||||
negative_slope: float = 0.0, | ||||||||||||||||||||||||||||||
nonlinearity: str = 'relu', | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
super().__init__( | ||||||||||||||||||||||||||||||
uniform=False, | ||||||||||||||||||||||||||||||
fan_in=fan_in, | ||||||||||||||||||||||||||||||
|
@@ -317,7 +326,12 @@ class KaimingUniform(MSRAInitializer): | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def __init__(self, fan_in=None, negative_slope=0.0, nonlinearity='relu'): | ||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||
fan_in: float | None = None, | ||||||||||||||||||||||||||||||
negative_slope: float = 0.0, | ||||||||||||||||||||||||||||||
nonlinearity: str = 'relu', | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
super().__init__( | ||||||||||||||||||||||||||||||
uniform=True, | ||||||||||||||||||||||||||||||
fan_in=fan_in, | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.