Skip to content

Commit 75ef3cb

Browse files
authored
Fix logic for converting np array to text (#2470)
In onnx2script, nan, inf etc. were converted to plain text, which causes evaluation to fail because they don't exist in the script. I updated the logic to replace them with np. values. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 68962aa commit 75ef3cb

File tree

2 files changed

+12
-17
lines changed

2 files changed

+12
-17
lines changed

onnxscript/backend/onnx_export.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import Any, Optional, Sequence
66

7-
import numpy
7+
import numpy as np
88
import onnx
99
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
1010

@@ -384,17 +384,17 @@ def _translate_attributes(self, node):
384384
if isinstance(value, str):
385385
attributes.append((at.name, f"{value!r}"))
386386
continue
387-
if isinstance(value, numpy.ndarray):
387+
if isinstance(value, np.ndarray):
388388
onnx_dtype = at.t.data_type
389389
if len(value.shape) == 0:
390390
text = (
391391
f'make_tensor("value", {onnx_dtype}, dims=[], '
392-
f"vals=[{value.tolist()!r}])"
392+
f"vals=[{repr(value.tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')}])"
393393
)
394394
else:
395395
text = (
396396
f'make_tensor("value", {onnx_dtype}, dims={list(value.shape)!r}, '
397-
f"vals={value.ravel().tolist()!r})"
397+
f"vals={repr(value.ravel().tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')})"
398398
)
399399
attributes.append((at.name, text))
400400
continue
@@ -738,7 +738,7 @@ def generate_rand(name: str, value: TensorProto) -> str:
738738
raise NotImplementedError(
739739
f"Unable to generate random initializer for data type {value.data_type}."
740740
)
741-
return f"{__}{name} = numpy.random.rand({shape}).astype(numpy.float32)"
741+
return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
742742

743743
random_initializer_values = "\n".join(
744744
generate_rand(key, value) for key, value in self.skipped_initializers.items()
@@ -793,7 +793,7 @@ def add(line: str) -> None:
793793
result.append(line)
794794

795795
# Generic imports.
796-
add("import numpy")
796+
add("import numpy as np")
797797
add("from onnx import TensorProto")
798798
add("from onnx.helper import make_tensor")
799799
add("from onnxscript import script, external_tensor")
@@ -873,11 +873,11 @@ def export2python(
873873
.. runpython::
874874
:showcode:
875875
:process:
876-
import numpy
876+
import numpy as np
877877
from sklearn.cluster import KMeans
878878
from mlprodict.onnx_conv import to_onnx
879879
from mlprodict.onnx_tools.onnx_export import export2python
880-
X = numpy.arange(20).reshape(10, 2).astype(numpy.float32)
880+
X = np.arange(20).reshape(10, 2).astype(np.float32)
881881
tr = KMeans(n_clusters=2)
882882
tr.fit(X)
883883
onx = to_onnx(tr, X, target_opset=14)

onnxscript/backend/onnx_export_test.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,8 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True):
4545

4646

4747
SKIP_TESTS = (
48-
skip(
49-
r"^test_ai_onnx_ml_array_feature_extractor",
50-
"ImportError: cannot import name 'opset' from 'onnxscript.onnx_opset'",
51-
),
52-
skip(
53-
r"^test_ai_onnx_ml_binarizer",
54-
"ImportError: cannot import name 'opset' from 'onnxscript.onnx_opset'",
55-
),
48+
skip(r"^test_ai_onnx_ml_array_feature_extractor", "ORT doesn't support this op"),
49+
skip(r"^test_ai_onnx_ml_binarizer", "ORT doesn't support this op"),
5650
skip(r"^test_center_crop_pad_crop_negative_axes_hwc", "fixme: ORT segfaults"),
5751
skip(r"_scan_", "Operator Scan is not supported by onnxscript"),
5852
skip(r"^test_scan", "Operator Scan is not supported by onnxscript"),
@@ -89,6 +83,7 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True):
8983
"Change when the converter supports support something like 'while i < n and cond:'",
9084
),
9185
skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"),
86+
skip(r"^test_ai_onnx_ml_tree_ensemble", "Opset 23 is not supported"),
9287
)
9388

9489
if sys.platform == "win32":
@@ -160,7 +155,7 @@ class TestOnnxBackEnd(unittest.TestCase):
160155
test_folder = root_folder / "tests" / "onnx_backend_test_code"
161156
temp_folder = root_folder / "tests" / "export"
162157

163-
def _proto_to_os_and_back(self, proto: onnxscript.FunctionProto, **export_options):
158+
def _proto_to_os_and_back(self, proto: onnx.FunctionProto, **export_options):
164159
"""Convert a proto to onnxscript code and convert it back to a proto."""
165160
code = onnx_export.export2python(proto, **export_options)
166161
map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder)

0 commit comments

Comments
 (0)