Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions onnxscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
import ast
import inspect
import sys
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Optional, Sequence, TypeVar

import onnx.helper
from typing_extensions import ParamSpec

import onnxscript
from onnxscript import converter, irbuilder, values
from onnxscript._internal import ast_utils

_R = TypeVar("_R")
_P = ParamSpec("_P")


def script_check(
f: ast.FunctionDef,
Expand All @@ -39,7 +43,7 @@ def script(
opset: Optional[values.Opset] = None,
default_opset: Optional[values.Opset] = None,
**kwargs: Any,
) -> Callable[[Callable], onnxscript.OnnxFunction]:
) -> Callable[[Callable[_P, _R]], onnxscript.OnnxFunction[_P, _R]]:
"""Main decorator. Declares a function as an onnx function.

Args:
Expand Down Expand Up @@ -75,7 +79,7 @@ def log2(x):
"Script parameter must be an opset. Did you use @script instead of @script()?"
)

def transform(f: Callable) -> onnxscript.OnnxFunction:
def transform(f: Callable[_P, _R]) -> onnxscript.OnnxFunction[_P, _R]:
if not inspect.isfunction(f):
raise TypeError("The ONNXScript decorator should be applied to functions only.")

Expand Down
13 changes: 10 additions & 3 deletions onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,27 @@
Any,
Callable,
ClassVar,
Generic,
Optional,
Protocol,
Sequence,
TypeVar,
_GenericAlias,
)

import onnx
import onnx.defs
from typing_extensions import ParamSpec

from onnxscript import converter as converter_module
from onnxscript import irbuilder, sourceinfo, type_annotation
from onnxscript._internal import ast_utils, deprecation
from onnxscript.ir import _schemas

_R = TypeVar("_R")
_P = ParamSpec("_P")


_ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
onnx.defs.OpSchema.AttrType.FLOAT: float,
onnx.defs.OpSchema.AttrType.INT: int,
Expand Down Expand Up @@ -464,7 +471,7 @@ def _op_schema_from_function_ir(
)


class OnnxFunction(Op):
class OnnxFunction(Op, Generic[_P, _R]):
"""Represents an ONNX op for which a function-body has been defined in onnxscript.

Attributes:
Expand Down Expand Up @@ -566,12 +573,12 @@ def fun(*args, **kwargs):

return fun

def __call__(self, *args, **kwargs):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""Implements an eager-mode execution of an onnxscript function."""
# FIXME(after #225): Move import to the top of the file.
from onnxscript import evaluator # pylint: disable=import-outside-toplevel

return evaluator.default().eval_function(self, args, kwargs)
return evaluator.default().eval_function(self, args, kwargs) # type: ignore[arg-type, return-value]

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.function!r})"
Expand Down
Loading