Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][B-45,B-46] Add type annotations for python/paddle/framework/io.py #66654

Merged
merged 2 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions python/paddle/framework/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections
import copyreg
import os
Expand All @@ -20,6 +22,7 @@
import threading
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING

import numpy as np

Expand Down Expand Up @@ -50,11 +53,35 @@
_unpack_saved_dict,
)

if TYPE_CHECKING:
from io import BytesIO
from typing import Any, Literal, TypedDict

from typing_extensions import NotRequired, Unpack

from paddle import Tensor
from paddle._typing import NestedStructure
from paddle.nn.layer.layers import _StateDict

class _EmptyDict(TypedDict):
pass

class _LoadOptions(TypedDict):
model_filename: NotRequired[str]
params_filename: NotRequired[str]
keep_name_table: NotRequired[bool]
return_numpy: NotRequired[bool]

class _SaveOptions(TypedDict):
use_binary_format: NotRequired[bool]
pickle_protocol: NotRequired[Literal[2, 3, 4]]


__all__ = []
async_save_queue = []


def clear_async_save_task_queue():
def clear_async_save_task_queue() -> None:
'''
wait until all async save task to be done.
'''
Expand All @@ -64,7 +91,13 @@ def clear_async_save_task_queue():
task.join()


def async_save(obj, path, protocol=4, sync_other_task=False, **configs):
def async_save(
obj: object,
path: str | BytesIO,
protocol: Literal[2, 3, 4] = 4,
sync_other_task: bool = False,
**configs: Unpack[_EmptyDict],
) -> None:
'''
async version of paddle.save.
Note:
Expand Down Expand Up @@ -737,7 +770,12 @@ def _save_binary_var(obj, path):
)


def save(obj, path, protocol=4, **configs):
def save(
obj: _StateDict | NestedStructure[Tensor] | Program,
path: str | BytesIO,
protocol: Literal[2, 3, 4] = 4,
**configs: Unpack[_SaveOptions],
) -> None:
'''
Save an object to the specified path.
Expand Down Expand Up @@ -979,7 +1017,7 @@ def _legacy_save(obj, path, protocol=2):
pickle.dump(saved_obj, f, protocol=protocol)


def load(path, **configs):
def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any:
'''
Load an object can be used in paddle from specified path.
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ class DataLoader:
collate_fn: _CollateFn | None
use_buffer_reader: bool
prefetch_factor: int
worker_init_fn: Callable[[int], None]
worker_init_fn: Callable[[int], None] | None
dataset: Dataset
feed_list: Sequence[Tensor] | None
places: Sequence[PlaceLike] | None
Expand All @@ -449,7 +449,7 @@ def __init__(
self,
dataset: Dataset,
feed_list: Sequence[Tensor] | None = None,
places: Sequence[PlaceLike] | None = None,
places: PlaceLike | Sequence[PlaceLike] | None = None,
return_list: bool = True,
batch_sampler: BatchSampler | None = None,
batch_size: int = 1,
Expand All @@ -461,7 +461,7 @@ def __init__(
prefetch_factor: int = 2,
use_shared_memory: bool = True,
timeout: int = 0,
worker_init_fn: Callable[[int], None] = None,
worker_init_fn: Callable[[int], None] | None = None,
persistent_workers: bool = False,
) -> None:
self.return_list = return_list
Expand Down
36 changes: 20 additions & 16 deletions python/paddle/jit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@
from paddle._typing import NestedStructure
from paddle.static import InputSpec

class _SaveOptions(TypedDict):
output_spec: NotRequired[Sequence[InputSpec]]
with_hook: NotRequired[bool]
combine_params: NotRequired[bool]
clip_extra: NotRequired[bool]
skip_forward: NotRequired[bool]
input_names_after_prune: NotRequired[list[str]]
skip_prune_program: NotRequired[bool]

class _LoadOptions(TypedDict):
model_filename: NotRequired[str]
params_filename: NotRequired[str]


ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True)


Expand Down Expand Up @@ -493,17 +507,7 @@ def keep_name_table(self, value):
self._keep_name_table = value


class _SaveLoadOptions(TypedDict):
output_spec: NotRequired[Sequence[InputSpec]]
with_hook: NotRequired[bool]
combine_params: NotRequired[bool]
clip_extra: NotRequired[bool]
skip_forward: NotRequired[bool]
input_names_after_prune: NotRequired[list[str]]
skip_prune_program: NotRequired[bool]


def _parse_save_configs(configs: _SaveLoadOptions):
def _parse_save_configs(configs: _SaveOptions) -> _SaveLoadConfig:
supported_configs = [
"output_spec",
"with_hook",
Expand Down Expand Up @@ -536,7 +540,7 @@ def _parse_save_configs(configs: _SaveLoadOptions):
return inner_config


def _parse_load_config(configs):
def _parse_load_config(configs: _LoadOptions) -> _SaveLoadConfig:
supported_configs = ['model_filename', 'params_filename']

# input check
Expand Down Expand Up @@ -877,7 +881,7 @@ def __call__(
layer: Layer | Callable[..., Any],
path: str,
input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = ...,
**configs: Unpack[_SaveLoadOptions],
**configs: Unpack[_SaveOptions],
) -> None:
...

Expand All @@ -888,7 +892,7 @@ def wrapper(
layer: Layer | Callable[..., Any],
path: str,
input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = None,
**configs: Unpack[_SaveLoadOptions],
**configs: Unpack[_SaveOptions],
) -> None:
global _save_pre_hooks
for hook in _save_pre_hooks:
Expand Down Expand Up @@ -950,7 +954,7 @@ def save(
layer: Layer | Callable[..., Any],
path: str,
input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = None,
**configs: Unpack[_SaveLoadOptions],
**configs: Unpack[_SaveOptions],
) -> None:
"""
Saves input Layer or function as ``paddle.jit.TranslatedLayer``
Expand Down Expand Up @@ -1517,7 +1521,7 @@ def save(

@dygraph_only
def load(
path: str, **configs: Unpack[_SaveLoadOptions]
path: str, **configs: Unpack[_LoadOptions]
) -> TranslatedLayer | PirTranslatedLayer:
"""
:api_attr: imperative
Expand Down