diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index 9540342968..063ded7f0d 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -4,11 +4,11 @@ # -------------------------------------------------------------------------- from __future__ import annotations -from typing import Any, Optional +from typing import Any, Optional, Sequence import numpy import onnx -from onnx import FunctionProto, ModelProto, TensorProto, ValueInfoProto +from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto from onnx.helper import make_node import onnxscript.onnx_types @@ -23,11 +23,9 @@ {% if unique_types %} from onnxscript.onnx_types import {{ ", ".join(unique_types) }} {%- endif %} -from onnxscript.onnx_opset import opset{{ opsets[''] }} -{% for domain, version in unique_function_domain_version: %} -{{ domain }}{{ version }} = Opset("{{ domain }}", {{ version }}){% endfor %} +{{translate_opset_imports_of(main_model)}} {% for domain, name, fct in functions: %} -@script({{ domain }}1) +@script({{make_opset_name(domain, 1)}}) def {{ python_make_node_name(fct['proto'].domain, 1, fct['proto'].name) }}{{ translate_function_signature(fct['proto'])}} {% if fct['proto'].doc_string %}""" @@ -201,22 +199,6 @@ def _attribute_value(attr: onnx.AttributeProto): raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.") -def _python_make_node_name(domain, version, name, node=False): - name = _rename_variable(name) - if node: - if version is None: - version = 1 - if not isinstance(version, int): - raise TypeError( - f"version must be an integer not {version!r} for domain={domain!r} " - f"and name={name!r}." - ) - if domain == "": - return f"opset{version}.{name}" - return f"{domain.replace('.', '_')}{version}.{name}" - return name - - class Exporter: """Class used for recursive traversal of Proto structures.""" @@ -230,6 +212,28 @@ def _rename_variable_s(self, name): """Renames all names equal to a python keyword.""" return str(self._rename_variable(name)) + def _rename_domain(self, domain: str) -> str: + if domain == "": + return "opset" + return domain.replace(".", "_") + + def make_opset_name(self, domain, version): + return f"{self._rename_domain(domain)}{version}" + + def _python_make_node_name(self, domain, version, name, node=False): + name = _rename_variable(name) + if node: + if version is None: + version = 1 + if not isinstance(version, int): + raise TypeError( + f"version must be an integer not {version!r} for domain={domain!r} " + f"and name={name!r}." + ) + opset = self.make_opset_name(domain, version) + return f"{opset}.{name}" + return name + def _python_make_node_graph(self, graph, opsets, indent=0, output_names=None): """Translates a GraphProto into python.""" code = [] @@ -403,7 +407,7 @@ def _python_make_node(self, onnx_node, opsets, indent=0): f"{sindent}{self._rename_variable(node.output[0])} = " f"{(f' {ops[node.op_type]} ').join(map(self.lookup, node.input))}" ) - name = _python_make_node_name( + name = self._python_make_node_name( node.domain, opsets[node.domain], node.op_type, node=True ) attributes_str = self._python_make_node_make_attribute_str(node) @@ -428,6 +432,29 @@ def _python_make_node(self, onnx_node, opsets, indent=0): ] return "".join(text) + def translate_opset_import(self, domain: str, version: int) -> str: + if domain in {"", "ai.onnx"}: + return f"from onnxscript.onnx_opset import opset{version}\n" + else: + varname = self.make_opset_name(domain, version) + return f"{varname} = Opset('{domain}', {version})\n" + + def translate_opset_imports(self, opset_imports: Sequence[onnx.OperatorSetIdProto]) -> str: + return "".join( + [self.translate_opset_import(x.domain, x.version) for x in opset_imports] + ) + + def translate_opset_imports_of( + self, proto: ModelProto | FunctionProto | GraphProto + ) -> str: + if hasattr(proto, "opset_import"): + text = self.translate_opset_imports(proto.opset_import) + if isinstance(proto, FunctionProto): + if not any(x.domain == proto.domain for x in proto.opset_import): + text += self.translate_opset_import(proto.domain, 1) + return text + return "" + def translate_function_signature(self, funproto: onnx.FunctionProto) -> str: """Generate signature for FunctionProto.""" type_map = _attribute_param_types(funproto) @@ -522,10 +549,13 @@ def rename_variable(name): "main_model": model_onnx, "python_make_node": exporter._python_make_node, # pylint: disable=protected-access "python_make_node_graph": exporter._python_make_node_graph, # pylint: disable=protected-access - "python_make_node_name": _python_make_node_name, # pylint: disable=protected-access + "python_make_node_name": exporter._python_make_node_name, # pylint: disable=protected-access "rename": rename_variable, "translate_sig": _translate_signature, "translate_function_signature": exporter.translate_function_signature, + "translate_opset_imports_of": exporter.translate_opset_imports_of, + "hasattr": hasattr, + "make_opset_name": exporter.make_opset_name, } # opset diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index 1aeb85cdc3..efaa6c62d2 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -17,6 +17,7 @@ import onnxscript import onnxscript.testing +import onnxscript.values from onnxscript.backend import onnx_backend, onnx_export from onnxscript.tests.models import type_double @@ -136,7 +137,8 @@ class TestOnnxBackEnd(unittest.TestCase): test_folder = root_folder / "tests" / "onnx_backend_test_code" temp_folder = root_folder / "tests" / "export" - def _round_trip_check(self, proto, **export_options): + def _round_trip_check(self, script_function, **export_options): + proto = script_function.to_function_proto() code = onnx_export.export2python(proto, **export_options) map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder) result_proto = map[proto.name] @@ -150,8 +152,18 @@ def test_attr_ref(self): def fun_with_attr_param(X, dtype: int): return op.Cast(X, to=dtype) - fun_proto = fun_with_attr_param.to_function_proto() - self._round_trip_check(fun_proto) + self._round_trip_check(fun_with_attr_param) + + def test_qualified_domain(self): + """Test use of qualified domain name.""" + op = onnxscript.opset17 + custom_opset = onnxscript.values.Opset("my.domain.com", 1) + + @onnxscript.script(custom_opset) + def twice(X): + return op.Add(X, X) + + self._round_trip_check(twice) def test_export2python(self): proto = type_double.double_abs_subgraph.to_model_proto()