Commit 0d686ca

[Typing] Add type annotations for paddle.jit (#64867)

SigureMo authored Jun 5, 2024
1 parent 64ba20d commit 0d686ca

Showing 9 changed files with 230 additions and 82 deletions.
28 changes: 4 additions & 24 deletions python/paddle/_typing/basic.py
@@ -25,30 +25,10 @@
 Numberic: TypeAlias = Union[int, float, complex, np.number, "Tensor"]
 TensorLike: TypeAlias = Union[npt.NDArray[Any], "Tensor", Numberic]
 
-_T = TypeVar("_T", bound=Numberic)
-_SeqLevel1: TypeAlias = Sequence[_T]
-_SeqLevel2: TypeAlias = Sequence[Sequence[_T]]
-_SeqLevel3: TypeAlias = Sequence[Sequence[Sequence[_T]]]
-_SeqLevel4: TypeAlias = Sequence[Sequence[Sequence[Sequence[_T]]]]
-_SeqLevel5: TypeAlias = Sequence[Sequence[Sequence[Sequence[Sequence[_T]]]]]
-_SeqLevel6: TypeAlias = Sequence[
-    Sequence[Sequence[Sequence[Sequence[Sequence[_T]]]]]
-]
-
-IntSequence: TypeAlias = _SeqLevel1[int]
-
-NumbericSequence: TypeAlias = _SeqLevel1[Numberic]
-
-NestedSequence: TypeAlias = Union[
-    _T,
-    _SeqLevel1[_T],
-    _SeqLevel2[_T],
-    _SeqLevel3[_T],
-    _SeqLevel4[_T],
-    _SeqLevel5[_T],
-    _SeqLevel6[_T],
-]
-
+_T = TypeVar("_T")
+NestedSequence = Union[_T, Sequence["NestedSequence[_T]"]]
+IntSequence = Sequence[int]
+NumbericSequence = Sequence[Numberic]
 NestedNumbericSequence: TypeAlias = NestedSequence[Numberic]
 
 TensorOrTensors: TypeAlias = Union["Tensor", Sequence["Tensor"]]
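The recursive alias replaces six hand-rolled fixed-depth `_SeqLevel*` aliases with a single definition that type checkers expand on demand, so sequences of any nesting depth are accepted. A minimal self-contained sketch of how such an alias behaves, using plain `int` leaves instead of Paddle's `Numberic` (`flatten` is a hypothetical helper for illustration):

    from typing import Sequence, TypeVar, Union

    from typing_extensions import TypeAlias

    _T = TypeVar("_T")
    # Either a bare leaf, or a sequence of the same alias, to any depth.
    NestedSequence: TypeAlias = Union[_T, Sequence["NestedSequence[_T]"]]

    def flatten(x: NestedSequence[int]) -> list[int]:
        if isinstance(x, int):  # a leaf
            return [x]
        return [leaf for item in x for leaf in flatten(item)]

    print(flatten(3))                   # [3]
    print(flatten([1, [2, [3, [4]]]]))  # [1, 2, 3, 4]

At runtime the quoted reference stays an unevaluated string, so the module imports cleanly; recursive aliases do require a reasonably recent type checker (e.g. current mypy or pyright).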
13 changes: 7 additions & 6 deletions python/paddle/_typing/shape.py
@@ -13,23 +13,24 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import List, Tuple, Union
+from typing import TYPE_CHECKING, List, Tuple, Union
 
 from typing_extensions import TypeAlias
 
-from .. import Tensor
+if TYPE_CHECKING:
+    from .. import Tensor
 
 DynamicShapeLike: TypeAlias = Union[
-    Tuple[Union[int, Tensor, None], ...],
-    List[Union[int, Tensor, None]],
-    Tensor,
+    Tuple[Union[int, "Tensor", None], ...],
+    List[Union[int, "Tensor", None]],
+    "Tensor",
 ]
 
 
 ShapeLike: TypeAlias = Union[
     Tuple[int, ...],
     List[int],
-    Tensor,
+    "Tensor",
 ]
 
 
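The shape.py change is the standard recipe for breaking an import cycle: the `Tensor` import is executed only by type checkers (for which `TYPE_CHECKING` is `True`), and the annotations refer to it through quoted forward references that are never evaluated at runtime. A generic sketch of the pattern (module and package names are illustrative, not Paddle's):

    # typing_helpers.py -- hypothetical low-level module
    from typing import TYPE_CHECKING, List, Tuple, Union

    from typing_extensions import TypeAlias

    if TYPE_CHECKING:
        # Seen only by the type checker; never imported at runtime, so a
        # package that itself imports typing_helpers cannot cycle.
        from mypackage import Tensor

    # Quoted "Tensor" is a forward reference resolved by the checker.
    ShapeLike: TypeAlias = Union[Tuple[int, ...], List[int], "Tensor"]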
129 changes: 123 additions & 6 deletions python/paddle/jit/api.py
@@ -23,10 +23,21 @@
 import types
 import warnings
 from collections import OrderedDict
+from collections.abc import Callable, Sequence
 from contextlib import contextmanager
-from typing import Any
+from types import ModuleType
+from typing import (
+    Any,
+    Protocol,
+    TypedDict,
+    TypeVar,
+    overload,
+)
+
+from typing_extensions import Literal, NotRequired, ParamSpec, TypeAlias, Unpack
 
 import paddle
+from paddle._typing import NestedSequence
 from paddle.base import core, dygraph
 from paddle.base.compiler import (
     BuildStrategy,
@@ -45,6 +56,7 @@
 from paddle.base.wrapped_decorator import wrap_decorator
 from paddle.framework import use_pir_api
 from paddle.nn import Layer
+from paddle.static import InputSpec
 from paddle.static.io import save_inference_model
 from paddle.utils.environments import (
     BooleanEnvironmentVariable,
@@ -71,6 +83,11 @@
 
 ENV_ENABLE_SOT = BooleanEnvironmentVariable("ENABLE_FALL_BACK", True)
 
+_LayerT = TypeVar("_LayerT", bound=Layer)
+_RetT = TypeVar("_RetT")
+_InputT = ParamSpec("_InputT")
+Backends: TypeAlias = Literal["CINN"]
+
 
 @contextmanager
 def sot_mode_guard(value: bool):
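`_InputT` (a `ParamSpec`) and `_RetT` are what let the overloads further down promise that a compiled function keeps the exact parameter list and return type of the original. A standalone sketch of the mechanism with an ordinary pass-through decorator (`traced` and `scale` are made up for illustration):

    from typing import Callable, TypeVar

    from typing_extensions import ParamSpec

    _InputT = ParamSpec("_InputT")
    _RetT = TypeVar("_RetT")

    def traced(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
        # The ParamSpec ties the wrapper's parameters to func's, so callers
        # of the wrapped function are checked against the original signature.
        def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
            print(f"calling {func.__name__}")
            return func(*args, **kwargs)

        return wrapper

    @traced
    def scale(x: float, factor: float = 2.0) -> float:
        return x * factor

    scale(3.0)       # OK -> 6.0
    # scale("a")     # rejected by the type checker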
@@ -98,13 +115,13 @@ def copy_decorator_attrs(original_func, decorated_obj):
     return decorated_obj
 
 
-def ignore_module(modules: list[Any]):
+def ignore_module(modules: list[ModuleType]) -> None:
     """
     Adds modules that ignore transcription.
     Builtin modules that have been ignored are collections, pdb, copy, inspect, re, numpy, logging, six
 
     Args:
-        modules (List[Any]): Ignored modules that you want to add
+        modules (list[ModuleType]): Ignored modules that you want to add
 
     Examples:
         .. code-block:: python
Expand Down Expand Up @@ -133,6 +150,67 @@ def _check_and_set_backend(backend, build_strategy):
build_strategy.build_cinn_pass = True


class ToStaticOptions(TypedDict):
property: NotRequired[bool]
full_graph: NotRequired[bool]


class ToStaticDecorator(Protocol):
@overload
def __call__(self, function: _LayerT) -> _LayerT:
...

@overload
def __call__(
self, function: Callable[_InputT, _RetT]
) -> StaticFunction[_InputT, _RetT]:
...


@overload
def to_static(
function: _LayerT,
input_spec: NestedSequence[InputSpec] | None = ...,
build_strategy: BuildStrategy | None = ...,
backend: Backends | None = ...,
**kwargs: Unpack[ToStaticOptions],
) -> _LayerT:
...


@overload
def to_static(
function: Callable[_InputT, _RetT],
input_spec: NestedSequence[InputSpec] | None = ...,
build_strategy: BuildStrategy | None = ...,
backend: Backends | None = ...,
**kwargs: Unpack[ToStaticOptions],
) -> StaticFunction[_InputT, _RetT]:
...


@overload
def to_static(
function: Any,
input_spec: NestedSequence[InputSpec] | None = ...,
build_strategy: BuildStrategy | None = ...,
backend: Backends | None = ...,
**kwargs: Unpack[ToStaticOptions],
) -> Any:
...


@overload
def to_static(
function: None = ...,
input_spec: NestedSequence[InputSpec] | None = ...,
build_strategy: BuildStrategy | None = ...,
backend: Backends | None = ...,
**kwargs: Unpack[ToStaticOptions],
) -> ToStaticDecorator:
...


def to_static(
function=None,
input_spec=None,
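Taken together, the four overloads say: a `Layer` comes back as the same `Layer` subtype, a plain callable becomes a `StaticFunction` carrying the original parameters and return type, and calling `to_static` with no function yields a `ToStaticDecorator` for use with `@`-syntax. A hedged usage sketch of how a checker sees each form (assumes a working Paddle install; the function bodies are illustrative):

    import paddle
    from paddle.static import InputSpec

    def forward(x: paddle.Tensor) -> paddle.Tensor:
        return paddle.nn.functional.relu(x)

    # Callable overload: result is typed as a StaticFunction that still
    # takes one Tensor and returns one Tensor.
    static_forward = paddle.jit.to_static(
        forward, input_spec=[InputSpec(shape=[None, 8], dtype="float32")]
    )

    # function=None overload: returns a ToStaticDecorator, so @-syntax
    # with options also type-checks.
    @paddle.jit.to_static(full_graph=True)
    def double(x: paddle.Tensor) -> paddle.Tensor:
        return x * 2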
@@ -254,6 +332,28 @@ def decorated(python_func):
     return decorated
 
 
+class NotToStaticDecorator(Protocol):
+    @overload
+    def __call__(
+        self, func: Callable[_InputT, _RetT]
+    ) -> Callable[_InputT, _RetT]:
+        ...
+
+    @overload
+    def __call__(self, func: None = ...) -> NotToStaticDecorator:
+        ...
+
+
+@overload
+def not_to_static(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
+    ...
+
+
+@overload
+def not_to_static(func: None = ...) -> NotToStaticDecorator:
+    ...
+
+
 def not_to_static(func=None):
     """
     A decorator to suppress the conversion of a function.
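`NotToStaticDecorator` is a `Protocol` whose `__call__` is itself overloaded; combined with the module-level overloads, this types a decorator that works both bare (`@not_to_static`) and with parentheses (`@not_to_static()`). A minimal standalone sketch of that dual-form pattern (`skip` and `__skipped__` are made up for illustration):

    from typing import Callable, Optional, TypeVar, overload

    _F = TypeVar("_F", bound=Callable[..., object])

    @overload
    def skip(func: _F) -> _F: ...
    @overload
    def skip(func: None = ...) -> Callable[[_F], _F]: ...

    def skip(func: Optional[_F] = None):
        if func is not None:
            # Bare form: @skip -- func is the decorated callable.
            func.__skipped__ = True  # type: ignore[attr-defined]
            return func

        # Parenthesized form: @skip() -- return the real decorator.
        def decorator(f: _F) -> _F:
            return skip(f)

        return decorator

    @skip
    def a() -> None: ...

    @skip()
    def b() -> None: ...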
@@ -393,7 +493,17 @@ def keep_name_table(self, value):
         self._keep_name_table = value
 
 
-def _parse_save_configs(configs):
+class _SaveLoadOptions(TypedDict):
+    output_spec: NotRequired[Sequence[InputSpec]]
+    with_hook: NotRequired[bool]
+    combine_params: NotRequired[bool]
+    clip_extra: NotRequired[bool]
+    skip_forward: NotRequired[bool]
+    input_names_after_prune: NotRequired[list[str]]
+    skip_prune_program: NotRequired[bool]
+
+
+def _parse_save_configs(configs: _SaveLoadOptions):
     supported_configs = [
         "output_spec",
         "with_hook",

@@ -795,7 +905,12 @@ def set_property(meta, key, val):
 
 @_run_save_pre_hooks
 @switch_to_static_graph
-def save(layer, path, input_spec=None, **configs):
+def save(
+    layer: Callable[_InputT, _RetT],
+    path: str,
+    input_spec: InputSpec | None = None,
+    **configs: Unpack[_SaveLoadOptions],
+) -> None:
     """
     Saves input Layer or function as ``paddle.jit.TranslatedLayer``
     format model, which can be used for inference or fine-tuning after loading.
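`**configs: Unpack[_SaveLoadOptions]` gives `save` (and `load` below) checked keyword arguments: only the listed keys are accepted, each with its declared type, and `NotRequired` keeps every key optional. A minimal sketch of the pattern (`SaveOptions` and this `save` are illustrative, not Paddle's):

    from typing_extensions import NotRequired, TypedDict, Unpack

    class SaveOptions(TypedDict):
        with_hook: NotRequired[bool]
        clip_extra: NotRequired[bool]

    def save(path: str, **configs: Unpack[SaveOptions]) -> None:
        # configs behaves like a normal dict at runtime.
        print(path, configs)

    save("model", with_hook=True)      # OK
    # save("model", with_hok=True)     # checker: unexpected keyword
    # save("model", clip_extra="yes")  # checker: wrong value type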
@@ -1362,7 +1477,9 @@ def save(layer, path, input_spec=None, **configs):
 
 
 @dygraph_only
-def load(path, **configs):
+def load(
+    path: str, **configs: Unpack[_SaveLoadOptions]
+) -> TranslatedLayer | PirTranslatedLayer:
     """
     :api_attr: imperative
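With the annotated signatures, a save/load round trip looks like the following hedged sketch (standard `paddle.jit.save`/`paddle.jit.load` usage; the layer and shapes are illustrative):

    import paddle
    from paddle import nn
    from paddle.static import InputSpec

    class Net(nn.Layer):
        def __init__(self) -> None:
            super().__init__()
            self.fc = nn.Linear(8, 4)

        def forward(self, x: paddle.Tensor) -> paddle.Tensor:
            return self.fc(x)

    net = Net()
    # input_spec pins the exported graph's input shape and dtype.
    paddle.jit.save(net, "./net", input_spec=[InputSpec([None, 8], "float32")])

    loaded = paddle.jit.load("./net")  # TranslatedLayer / PirTranslatedLayer
    out = loaded(paddle.randn([2, 8]))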
8 changes: 5 additions & 3 deletions python/paddle/jit/dy2static/logging_utils.py
@@ -180,7 +180,7 @@ def _output_to_stdout(self, msg, *args):
 _TRANSLATOR_LOGGER = TranslatorLogger()
 
 
-def set_verbosity(level=0, also_to_stdout=False):
+def set_verbosity(level: int = 0, also_to_stdout: bool = False) -> None:
     """
     Sets the verbosity level of log for dygraph to static graph. Logs can be output to stdout by setting `also_to_stdout`.
 
@@ -215,11 +215,13 @@ def set_verbosity(level=0, also_to_stdout=False):
     _TRANSLATOR_LOGGER.need_to_echo_log_to_stdout = also_to_stdout
 
 
-def get_verbosity():
+def get_verbosity() -> int:
     return _TRANSLATOR_LOGGER.verbosity_level
 
 
-def set_code_level(level=LOG_AllTransformer, also_to_stdout=False):
+def set_code_level(
+    level: int = LOG_AllTransformer, also_to_stdout: bool = False
+) -> None:
     """
     Sets the level to print code from specific level Ast Transformer. Code can be output to stdout by setting `also_to_stdout`.
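These functions back the public `paddle.jit.set_verbosity` and `paddle.jit.set_code_level` APIs, which now advertise their `int` levels and `None` returns. A small usage sketch (what gets logged depends on what dy2static actually transforms):

    import paddle

    # Higher levels print more of the dygraph-to-static translation
    # process; mirror the logs to stdout as well.
    paddle.jit.set_verbosity(1, also_to_stdout=True)

    # Print the transformed code after every AST transformer pass
    # (LOG_AllTransformer, i.e. level 100, per the module above).
    paddle.jit.set_code_level(100, also_to_stdout=True)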
(Diffs for the remaining five changed files were not loaded in this view.)