From ddeabbe5deb01abf3ea9241baa2d376cbfee7181 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Sat, 27 Jul 2024 17:41:59 +0800 Subject: [PATCH 1/2] [Typing][B-45,B-46] Add type annotations for `python/paddle/framework/io.py` --- python/paddle/framework/io.py | 46 ++++++++++++++++++++++++++++++++--- python/paddle/jit/api.py | 36 +++++++++++++++------------ 2 files changed, 62 insertions(+), 20 deletions(-) diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index ccacf6064c804..3bfaf36e9820d 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import collections import copyreg import os @@ -20,6 +22,7 @@ import threading import warnings from collections.abc import Iterable +from typing import TYPE_CHECKING import numpy as np @@ -50,11 +53,35 @@ _unpack_saved_dict, ) +if TYPE_CHECKING: + from io import BytesIO + from typing import Any, Literal, TypedDict + + from typing_extensions import NotRequired, Unpack + + from paddle import Tensor + from paddle._typing import NestedStructure + from paddle.nn.layer.layers import _StateDict + + class _EmptyDict(TypedDict): + pass + + class _LoadOptions(TypedDict): + model_filename: NotRequired[str] + params_filename: NotRequired[str] + keep_name_table: NotRequired[bool] + return_numpy: NotRequired[bool] + + class _SaveOptions(TypedDict): + use_binary_format: NotRequired[bool] + pickle_protocol: NotRequired[Literal[2, 3, 4]] + + __all__ = [] async_save_queue = [] -def clear_async_save_task_queue(): +def clear_async_save_task_queue() -> None: ''' wait until all async save task to be done. ''' @@ -64,7 +91,13 @@ def clear_async_save_task_queue(): task.join() -def async_save(obj, path, protocol=4, sync_other_task=False, **configs): +def async_save( + obj: object, + path: str | BytesIO, + protocol: Literal[2, 3, 4] = 4, + sync_other_task: bool = False, + **configs: Unpack[_EmptyDict], +) -> None: ''' async version of paddle.save. Note: @@ -737,7 +770,12 @@ def _save_binary_var(obj, path): ) -def save(obj, path, protocol=4, **configs): +def save( + obj: _StateDict | NestedStructure[Tensor] | Program, + path: str | BytesIO, + protocol: Literal[2, 3, 4] = 4, + **configs: Unpack[_SaveOptions], +) -> None: ''' Save an object to the specified path. @@ -979,7 +1017,7 @@ def _legacy_save(obj, path, protocol=2): pickle.dump(saved_obj, f, protocol=protocol) -def load(path, **configs): +def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any: ''' Load an object can be used in paddle from specified path. diff --git a/python/paddle/jit/api.py b/python/paddle/jit/api.py index c112172e59652..56a6347ba91b6 100644 --- a/python/paddle/jit/api.py +++ b/python/paddle/jit/api.py @@ -91,6 +91,20 @@ from paddle._typing import NestedStructure from paddle.static import InputSpec + class _SaveOptions(TypedDict): + output_spec: NotRequired[Sequence[InputSpec]] + with_hook: NotRequired[bool] + combine_params: NotRequired[bool] + clip_extra: NotRequired[bool] + skip_forward: NotRequired[bool] + input_names_after_prune: NotRequired[list[str]] + skip_prune_program: NotRequired[bool] + + class _LoadOptions(TypedDict): + model_filename: NotRequired[str] + params_filename: NotRequired[str] + + ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True) @@ -493,17 +507,7 @@ def keep_name_table(self, value): self._keep_name_table = value -class _SaveLoadOptions(TypedDict): - output_spec: NotRequired[Sequence[InputSpec]] - with_hook: NotRequired[bool] - combine_params: NotRequired[bool] - clip_extra: NotRequired[bool] - skip_forward: NotRequired[bool] - input_names_after_prune: NotRequired[list[str]] - skip_prune_program: NotRequired[bool] - - -def _parse_save_configs(configs: _SaveLoadOptions): +def _parse_save_configs(configs: _SaveOptions) -> _SaveLoadConfig: supported_configs = [ "output_spec", "with_hook", @@ -536,7 +540,7 @@ def _parse_save_configs(configs: _SaveLoadOptions): return inner_config -def _parse_load_config(configs): +def _parse_load_config(configs: _LoadOptions) -> _SaveLoadConfig: supported_configs = ['model_filename', 'params_filename'] # input check @@ -877,7 +881,7 @@ def __call__( layer: Layer | Callable[..., Any], path: str, input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = ..., - **configs: Unpack[_SaveLoadOptions], + **configs: Unpack[_SaveOptions], ) -> None: ... @@ -888,7 +892,7 @@ def wrapper( layer: Layer | Callable[..., Any], path: str, input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = None, - **configs: Unpack[_SaveLoadOptions], + **configs: Unpack[_SaveOptions], ) -> None: global _save_pre_hooks for hook in _save_pre_hooks: @@ -950,7 +954,7 @@ def save( layer: Layer | Callable[..., Any], path: str, input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = None, - **configs: Unpack[_SaveLoadOptions], + **configs: Unpack[_SaveOptions], ) -> None: """ Saves input Layer or function as ``paddle.jit.TranslatedLayer`` @@ -1517,7 +1521,7 @@ def save( @dygraph_only def load( - path: str, **configs: Unpack[_SaveLoadOptions] + path: str, **configs: Unpack[_LoadOptions] ) -> TranslatedLayer | PirTranslatedLayer: """ :api_attr: imperative From f7840077e52406fe254c238ed153e5be1253c2b7 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Sat, 27 Jul 2024 19:14:33 +0800 Subject: [PATCH 2/2] fix DataLoader `places` typing --- python/paddle/io/reader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/io/reader.py b/python/paddle/io/reader.py index 5b7ab7597c5ba..520c903ea30cb 100644 --- a/python/paddle/io/reader.py +++ b/python/paddle/io/reader.py @@ -437,7 +437,7 @@ class DataLoader: collate_fn: _CollateFn | None use_buffer_reader: bool prefetch_factor: int - worker_init_fn: Callable[[int], None] + worker_init_fn: Callable[[int], None] | None dataset: Dataset feed_list: Sequence[Tensor] | None places: Sequence[PlaceLike] | None @@ -449,7 +449,7 @@ def __init__( self, dataset: Dataset, feed_list: Sequence[Tensor] | None = None, - places: Sequence[PlaceLike] | None = None, + places: PlaceLike | Sequence[PlaceLike] | None = None, return_list: bool = True, batch_sampler: BatchSampler | None = None, batch_size: int = 1, @@ -461,7 +461,7 @@ def __init__( prefetch_factor: int = 2, use_shared_memory: bool = True, timeout: int = 0, - worker_init_fn: Callable[[int], None] = None, + worker_init_fn: Callable[[int], None] | None = None, persistent_workers: bool = False, ) -> None: self.return_list = return_list