diff --git a/python/paddle/nn/utils/transform_parameters.py b/python/paddle/nn/utils/transform_parameters.py index 8db65d61bb5ba..d75bc0a0467b8 100644 --- a/python/paddle/nn/utils/transform_parameters.py +++ b/python/paddle/nn/utils/transform_parameters.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from functools import reduce +from typing import TYPE_CHECKING, Iterable import paddle from paddle import _C_ops @@ -23,9 +26,13 @@ in_dygraph_mode, ) +if TYPE_CHECKING: + from paddle import Tensor + from paddle._typing import ShapeLike + # input==output, inplace strategy of reshape has no cost almostly -def _inplace_reshape_dygraph(x, shape): +def _inplace_reshape_dygraph(x: Tensor, shape: ShapeLike) -> None: x_shape = _create_tensor(dtype='int64') if in_dygraph_mode(): with paddle.base.dygraph.no_grad(): @@ -42,12 +49,12 @@ def _inplace_reshape_dygraph(x, shape): @dygraph_only -def _stride_column(param): +def _stride_column(param: Tensor) -> None: """ A tool function. Permute date of parameter as a 'columns' stride. Now, it only support 2-D parameter. Args: - param(Tensor]): The param that will be strided according to 'columns'. + param(Tensor): The param that will be strided according to 'columns'. Examples: .. code-block:: python @@ -75,7 +82,9 @@ def _stride_column(param): @dygraph_only -def parameters_to_vector(parameters, name=None): +def parameters_to_vector( + parameters: Iterable[Tensor], name: str | None = None +) -> Tensor: """ Flatten parameters to a 1-D Tensor. @@ -126,7 +135,9 @@ def parameters_to_vector(parameters, name=None): @dygraph_only -def vector_to_parameters(vec, parameters, name=None): +def vector_to_parameters( + vec: Tensor, parameters: Iterable[Tensor], name: str | None = None +) -> None: """ Transform a 1-D Tensor to the input ``parameters`` .