diff --git a/onnxscript/main.py b/onnxscript/main.py index bfcbf0bc4b..7407baedd1 100644 --- a/onnxscript/main.py +++ b/onnxscript/main.py @@ -6,14 +6,18 @@ import ast import inspect import sys -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence, TypeVar import onnx.helper +from typing_extensions import ParamSpec import onnxscript from onnxscript import converter, irbuilder, values from onnxscript._internal import ast_utils +_R = TypeVar("_R") +_P = ParamSpec("_P") + def script_check( f: ast.FunctionDef, @@ -39,7 +43,7 @@ def script( opset: Optional[values.Opset] = None, default_opset: Optional[values.Opset] = None, **kwargs: Any, -) -> Callable[[Callable], onnxscript.OnnxFunction]: +) -> Callable[[Callable[_P, _R]], onnxscript.OnnxFunction[_P, _R]]: """Main decorator. Declares a function as an onnx function. Args: @@ -75,7 +79,7 @@ def log2(x): "Script parameter must be an opset. Did you use @script instead of @script()?" ) - def transform(f: Callable) -> onnxscript.OnnxFunction: + def transform(f: Callable[_P, _R]) -> onnxscript.OnnxFunction[_P, _R]: if not inspect.isfunction(f): raise TypeError("The ONNXScript decorator should be applied to functions only.") diff --git a/onnxscript/values.py b/onnxscript/values.py index 9907b16ee4..d748dc6e64 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -13,20 +13,27 @@ Any, Callable, ClassVar, + Generic, Optional, Protocol, Sequence, + TypeVar, _GenericAlias, ) import onnx import onnx.defs +from typing_extensions import ParamSpec from onnxscript import converter as converter_module from onnxscript import irbuilder, sourceinfo, type_annotation from onnxscript._internal import ast_utils, deprecation from onnxscript.ir import _schemas +_R = TypeVar("_R") +_P = ParamSpec("_P") + + _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { onnx.defs.OpSchema.AttrType.FLOAT: float, onnx.defs.OpSchema.AttrType.INT: int, @@ -464,7 +471,7 @@ def _op_schema_from_function_ir( ) -class OnnxFunction(Op): +class OnnxFunction(Op, Generic[_P, _R]): """Represents an ONNX op for which a function-body has been defined in onnxscript. Attributes: @@ -566,12 +573,12 @@ def fun(*args, **kwargs): return fun - def __call__(self, *args, **kwargs): + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: """Implements an eager-mode execution of an onnxscript function.""" # FIXME(after #225): Move import to the top of the file. from onnxscript import evaluator # pylint: disable=import-outside-toplevel - return evaluator.default().eval_function(self, args, kwargs) + return evaluator.default().eval_function(self, args, kwargs) # type: ignore[arg-type, return-value] def __repr__(self) -> str: return f"{self.__class__.__name__}({self.function!r})"