Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 54 additions & 24 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
# --------------------------------------------------------------------------
from __future__ import annotations

from typing import Any, Optional
from typing import Any, Optional, Sequence

import numpy
import onnx
from onnx import FunctionProto, ModelProto, TensorProto, ValueInfoProto
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
from onnx.helper import make_node

import onnxscript.onnx_types
Expand All @@ -23,11 +23,9 @@
{% if unique_types %}
from onnxscript.onnx_types import {{ ", ".join(unique_types) }}
{%- endif %}
from onnxscript.onnx_opset import opset{{ opsets[''] }}
{% for domain, version in unique_function_domain_version: %}
{{ domain }}{{ version }} = Opset("{{ domain }}", {{ version }}){% endfor %}
{{translate_opset_imports_of(main_model)}}
{% for domain, name, fct in functions: %}
@script({{ domain }}1)
@script({{make_opset_name(domain, 1)}})
def {{ python_make_node_name(fct['proto'].domain, 1, fct['proto'].name) }}{{
translate_function_signature(fct['proto'])}}
{% if fct['proto'].doc_string %}"""
Expand Down Expand Up @@ -201,22 +199,6 @@ def _attribute_value(attr: onnx.AttributeProto):
raise NotImplementedError(f"Unable to return a value for attribute {attr!r}.")


def _python_make_node_name(domain, version, name, node=False):
name = _rename_variable(name)
if node:
if version is None:
version = 1
if not isinstance(version, int):
raise TypeError(
f"version must be an integer not {version!r} for domain={domain!r} "
f"and name={name!r}."
)
if domain == "":
return f"opset{version}.{name}"
return f"{domain.replace('.', '_')}{version}.{name}"
return name


class Exporter:
"""Class used for recursive traversal of Proto structures."""

Expand All @@ -230,6 +212,28 @@ def _rename_variable_s(self, name):
"""Renames all names equal to a python keyword."""
return str(self._rename_variable(name))

def _rename_domain(self, domain: str) -> str:
if domain == "":
return "opset"
return domain.replace(".", "_")

def make_opset_name(self, domain, version):
return f"{self._rename_domain(domain)}{version}"

def _python_make_node_name(self, domain, version, name, node=False):
name = _rename_variable(name)
if node:
if version is None:
version = 1
if not isinstance(version, int):
raise TypeError(
f"version must be an integer not {version!r} for domain={domain!r} "
f"and name={name!r}."
)
opset = self.make_opset_name(domain, version)
return f"{opset}.{name}"
return name

def _python_make_node_graph(self, graph, opsets, indent=0, output_names=None):
"""Translates a GraphProto into python."""
code = []
Expand Down Expand Up @@ -403,7 +407,7 @@ def _python_make_node(self, onnx_node, opsets, indent=0):
f"{sindent}{self._rename_variable(node.output[0])} = "
f"{(f' {ops[node.op_type]} ').join(map(self.lookup, node.input))}"
)
name = _python_make_node_name(
name = self._python_make_node_name(
node.domain, opsets[node.domain], node.op_type, node=True
)
attributes_str = self._python_make_node_make_attribute_str(node)
Expand All @@ -428,6 +432,29 @@ def _python_make_node(self, onnx_node, opsets, indent=0):
]
return "".join(text)

def translate_opset_import(self, domain: str, version: int) -> str:
if domain in {"", "ai.onnx"}:
return f"from onnxscript.onnx_opset import opset{version}\n"
else:
varname = self.make_opset_name(domain, version)
return f"{varname} = Opset('{domain}', {version})\n"

def translate_opset_imports(self, opset_imports: Sequence[onnx.OperatorSetIdProto]) -> str:
return "".join(
[self.translate_opset_import(x.domain, x.version) for x in opset_imports]
)

def translate_opset_imports_of(
self, proto: ModelProto | FunctionProto | GraphProto
) -> str:
if hasattr(proto, "opset_import"):
text = self.translate_opset_imports(proto.opset_import)
if isinstance(proto, FunctionProto):
if not any(x.domain == proto.domain for x in proto.opset_import):
text += self.translate_opset_import(proto.domain, 1)
return text
return ""

def translate_function_signature(self, funproto: onnx.FunctionProto) -> str:
"""Generate signature for FunctionProto."""
type_map = _attribute_param_types(funproto)
Expand Down Expand Up @@ -522,10 +549,13 @@ def rename_variable(name):
"main_model": model_onnx,
"python_make_node": exporter._python_make_node, # pylint: disable=protected-access
"python_make_node_graph": exporter._python_make_node_graph, # pylint: disable=protected-access
"python_make_node_name": _python_make_node_name, # pylint: disable=protected-access
"python_make_node_name": exporter._python_make_node_name, # pylint: disable=protected-access
"rename": rename_variable,
"translate_sig": _translate_signature,
"translate_function_signature": exporter.translate_function_signature,
"translate_opset_imports_of": exporter.translate_opset_imports_of,
"hasattr": hasattr,
"make_opset_name": exporter.make_opset_name,
}

# opset
Expand Down
18 changes: 15 additions & 3 deletions onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import onnxscript
import onnxscript.testing
import onnxscript.values
from onnxscript.backend import onnx_backend, onnx_export
from onnxscript.tests.models import type_double

Expand Down Expand Up @@ -136,7 +137,8 @@ class TestOnnxBackEnd(unittest.TestCase):
test_folder = root_folder / "tests" / "onnx_backend_test_code"
temp_folder = root_folder / "tests" / "export"

def _round_trip_check(self, proto, **export_options):
def _round_trip_check(self, script_function, **export_options):
proto = script_function.to_function_proto()
code = onnx_export.export2python(proto, **export_options)
map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder)
result_proto = map[proto.name]
Expand All @@ -150,8 +152,18 @@ def test_attr_ref(self):
def fun_with_attr_param(X, dtype: int):
return op.Cast(X, to=dtype)

fun_proto = fun_with_attr_param.to_function_proto()
self._round_trip_check(fun_proto)
self._round_trip_check(fun_with_attr_param)

def test_qualified_domain(self):
"""Test use of qualified domain name."""
op = onnxscript.opset17
custom_opset = onnxscript.values.Opset("my.domain.com", 1)

@onnxscript.script(custom_opset)
def twice(X):
return op.Add(X, X)

self._round_trip_check(twice)

def test_export2python(self):
proto = type_double.double_abs_subgraph.to_model_proto()
Expand Down