Skip to content

Commit

Permalink
[Typing][B-52] Add type annotations for python/paddle/onnx/export.py (
Browse files Browse the repository at this point in the history
#66862)


---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
  • Loading branch information
enkilee and SigureMo authored Aug 1, 2024
1 parent 2167fa5 commit 753329a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
9 changes: 5 additions & 4 deletions python/paddle/jit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@
)

if TYPE_CHECKING:
from paddle import Tensor
from paddle._typing import NestedStructure
from paddle.static import InputSpec

class _SaveOptions(TypedDict):
output_spec: NotRequired[Sequence[InputSpec]]
output_spec: NotRequired[Sequence[Tensor | int]]
with_hook: NotRequired[bool]
combine_params: NotRequired[bool]
clip_extra: NotRequired[bool]
Expand Down Expand Up @@ -880,7 +881,7 @@ def __call__(
self,
layer: Layer | Callable[..., Any],
path: str,
input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = ...,
input_spec: Sequence[InputSpec | Tensor | object] | None = ...,
**configs: Unpack[_SaveOptions],
) -> None:
...
Expand All @@ -891,7 +892,7 @@ def _run_save_pre_hooks(func: _SaveFunction) -> _SaveFunction:
def wrapper(
layer: Layer | Callable[..., Any],
path: str,
input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = None,
input_spec: Sequence[InputSpec | Tensor | object] | None = None,
**configs: Unpack[_SaveOptions],
) -> None:
global _save_pre_hooks
Expand Down Expand Up @@ -953,7 +954,7 @@ def _get_function_names_from_layer(layer: Layer) -> list[str]:
def save(
layer: Layer | Callable[..., Any],
path: str,
input_spec: Sequence[InputSpec | paddle.Tensor | object] | None = None,
input_spec: Sequence[InputSpec | Tensor | object] | None = None,
**configs: Unpack[_SaveOptions],
) -> None:
"""
Expand Down
25 changes: 21 additions & 4 deletions python/paddle/onnx/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Sequence

from paddle.utils import try_import

if TYPE_CHECKING:
from typing_extensions import Unpack

from paddle import Tensor
from paddle.jit.api import _SaveOptions
from paddle.nn import Layer
from paddle.static import InputSpec


__all__ = []


def export(layer, path, input_spec=None, opset_version=9, **configs):
def export(
layer: Layer,
path: str,
input_spec: Sequence[InputSpec | Tensor | object] | None = None,
opset_version: int = 9,
**configs: Unpack[_SaveOptions],
) -> None:
"""
Export Layer to ONNX format, which can use for inference via onnxruntime or other backends.
For more details, Please refer to `paddle2onnx <https://github.com/PaddlePaddle/paddle2onnx>`_ .
Expand All @@ -28,7 +45,7 @@ def export(layer, path, input_spec=None, opset_version=9, **configs):
layer (Layer): The Layer to be exported.
path (str): The path prefix to export model. The format is ``dirname/file_prefix`` or ``file_prefix`` ,
and the exported ONNX file suffix is ``.onnx`` .
input_spec (list[InputSpec|Tensor], optional): Describes the input of the exported model's forward
input_spec (list[InputSpec|Tensor]|None, optional): Describes the input of the exported model's forward
method, which can be described by InputSpec or example Tensor. If None, all input variables of
the original Layer's forward method would be the inputs of the exported ``ONNX`` model. Default: None.
opset_version(int, optional): Opset version of exported ONNX model.
Expand Down Expand Up @@ -62,7 +79,7 @@ def export(layer, path, input_spec=None, opset_version=9, **configs):
... x_spec = paddle.static.InputSpec(shape=[None, 128], dtype='float32')
... paddle.onnx.export(model, 'linear_net', input_spec=[x_spec])
...
>>> # doctest: +SKIP('raise ImportError')
>>> # doctest: +SKIP('Need install Paddle2ONNX')
>>> export_linear_net()
>>> class Logic(paddle.nn.Layer):
Expand All @@ -83,7 +100,7 @@ def export(layer, path, input_spec=None, opset_version=9, **configs):
... # Static and run model.
... paddle.jit.to_static(model)
... out = model(x, y, z=True)
... paddle.onnx.export(model, 'pruned', input_spec=[x, y, True], output_spec=[out], input_names_after_prune=[x])
... paddle.onnx.export(model, 'pruned', input_spec=[x, y, True], output_spec=[out], input_names_after_prune=[x.name])
...
>>> export_logic()
"""
Expand Down

0 comments on commit 753329a

Please sign in to comment.