Skip to content

Commit

Permalink
[Typing][B-45,B-46] Add type annotations for `python/paddle/framework…
Browse files Browse the repository at this point in the history
…/io.py` (PaddlePaddle#66654)
  • Loading branch information
SigureMo authored and lixcli committed Aug 5, 2024
1 parent 5fd9708 commit 6ae50b0
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 23 deletions.
46 changes: 42 additions & 4 deletions python/paddle/framework/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections
import copyreg
import os
Expand All @@ -20,6 +22,7 @@
import threading
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING

import numpy as np

Expand Down Expand Up @@ -50,11 +53,35 @@
_unpack_saved_dict,
)

if TYPE_CHECKING:
from io import BytesIO
from typing import Any, Literal, TypedDict

from typing_extensions import NotRequired, Unpack

from paddle import Tensor
from paddle._typing import NestedStructure
from paddle.nn.layer.layers import _StateDict

class _EmptyDict(TypedDict):
pass

class _LoadOptions(TypedDict):
model_filename: NotRequired[str]
params_filename: NotRequired[str]
keep_name_table: NotRequired[bool]
return_numpy: NotRequired[bool]

class _SaveOptions(TypedDict):
use_binary_format: NotRequired[bool]
pickle_protocol: NotRequired[Literal[2, 3, 4]]


__all__ = []
async_save_queue = []


def clear_async_save_task_queue():
def clear_async_save_task_queue() -> None:
'''
wait until all async save task to be done.
'''
Expand All @@ -64,7 +91,13 @@ def clear_async_save_task_queue():
task.join()


def async_save(obj, path, protocol=4, sync_other_task=False, **configs):
def async_save(
obj: object,
path: str | BytesIO,
protocol: Literal[2, 3, 4] = 4,
sync_other_task: bool = False,
**configs: Unpack[_EmptyDict],
) -> None:
'''
async version of paddle.save.
Note:
Expand Down Expand Up @@ -737,7 +770,12 @@ def _save_binary_var(obj, path):
)


def save(obj, path, protocol=4, **configs):
def save(
obj: _StateDict | NestedStructure[Tensor] | Program,
path: str | BytesIO,
protocol: Literal[2, 3, 4] = 4,
**configs: Unpack[_SaveOptions],
) -> None:
'''
Save an object to the specified path.
Expand Down Expand Up @@ -979,7 +1017,7 @@ def _legacy_save(obj, path, protocol=2):
pickle.dump(saved_obj, f, protocol=protocol)


def load(path, **configs):
def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any:
'''
Load an object can be used in paddle from specified path.
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ class DataLoader:
collate_fn: _CollateFn | None
use_buffer_reader: bool
prefetch_factor: int
worker_init_fn: Callable[[int], None]
worker_init_fn: Callable[[int], None] | None
dataset: Dataset
feed_list: Sequence[Tensor] | None
places: Sequence[PlaceLike] | None
Expand All @@ -449,7 +449,7 @@ def __init__(
self,
dataset: Dataset,
feed_list: Sequence[Tensor] | None = None,
places: Sequence[PlaceLike] | None = None,
places: PlaceLike | Sequence[PlaceLike] | None = None,
return_list: bool = True,
batch_sampler: BatchSampler | None = None,
batch_size: int = 1,
Expand All @@ -461,7 +461,7 @@ def __init__(
prefetch_factor: int = 2,
use_shared_memory: bool = True,
timeout: int = 0,
worker_init_fn: Callable[[int], None] = None,
worker_init_fn: Callable[[int], None] | None = None,
persistent_workers: bool = False,
) -> None:
self.return_list = return_list
Expand Down
36 changes: 20 additions & 16 deletions python/paddle/jit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@
from paddle._typing import NestedStructure
from paddle.static import InputSpec

class _SaveOptions(TypedDict):
output_spec: NotRequired[Sequence[InputSpec]]
with_hook: NotRequired[bool]
combine_params: NotRequired[bool]
clip_extra: NotRequired[bool]
skip_forward: NotRequired[bool]
input_names_after_prune: NotRequired[list[str]]
skip_prune_program: NotRequired[bool]

class _LoadOptions(TypedDict):
model_filename: NotRequired[str]
params_filename: NotRequired[str]


ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True)


Expand Down Expand Up @@ -493,17 +507,7 @@ def keep_name_table(self, value):
self._keep_name_table = value


class _SaveLoadOptions(TypedDict):
output_spec: NotRequired[Sequence[InputSpec]]
with_hook: NotRequired[bool]
combine_params: NotRequired[bool]
clip_extra: NotRequired[bool]
skip_forward: NotRequired[bool]
input_names_after_prune: NotRequired[list[str]]
skip_prune_program: NotRequired[bool]


def _parse_save_configs(configs: _SaveLoadOptions):
def _parse_save_configs(configs: _SaveOptions) -> _SaveLoadConfig:
supported_configs = [
"output_spec",
"with_hook",
Expand Down Expand Up @@ -536,7 +540,7 @@ def _parse_save_configs(configs: _SaveLoadOptions):
return inner_config


def _parse_load_config(configs):
def _parse_load_config(configs: _LoadOptions) -> _SaveLoadConfig:
supported_configs = ['model_filename', 'params_filename']

# input check
Expand Down Expand Up @@ -877,7 +881,7 @@ def __call__(
layer: Layer | Callable[..., Any],
path: str,
input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = ...,
**configs: Unpack[_SaveLoadOptions],
**configs: Unpack[_SaveOptions],
) -> None:
...

Expand All @@ -888,7 +892,7 @@ def wrapper(
layer: Layer | Callable[..., Any],
path: str,
input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = None,
**configs: Unpack[_SaveLoadOptions],
**configs: Unpack[_SaveOptions],
) -> None:
global _save_pre_hooks
for hook in _save_pre_hooks:
Expand Down Expand Up @@ -950,7 +954,7 @@ def save(
layer: Layer | Callable[..., Any],
path: str,
input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = None,
**configs: Unpack[_SaveLoadOptions],
**configs: Unpack[_SaveOptions],
) -> None:
"""
Saves input Layer or function as ``paddle.jit.TranslatedLayer``
Expand Down Expand Up @@ -1517,7 +1521,7 @@ def save(

@dygraph_only
def load(
path: str, **configs: Unpack[_SaveLoadOptions]
path: str, **configs: Unpack[_LoadOptions]
) -> TranslatedLayer | PirTranslatedLayer:
"""
:api_attr: imperative
Expand Down

0 comments on commit 6ae50b0

Please sign in to comment.