Skip to content

Commit f5b58e0

Browse files
authored
Minor fixes to onnx to onnxscript converter (#2510)
Minor fixes to onnx to onnxscript converter: * Embed node name as a comment in generated onnxscript * Handle sequence types (just for readable representation) * Minor tweak to handling of initializers, distinguishing small and large tensors (when replacing them by random values) * Handle INT8 type initializers, which show up in quantized models. --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent fce51b6 commit f5b58e0

File tree

2 files changed

+42
-19
lines changed

2 files changed

+42
-19
lines changed

onnxscript/backend/onnx_export.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
_SINGLE_INDENT = " "
1515

16+
_SMALL_TENSOR_SIZE = 4
17+
1618
kwlist = {
1719
"False",
1820
"None",
@@ -119,7 +121,7 @@ def renamer(name):
119121

120122
def _translate_type(onnx_type):
121123
"""Converts a onnx type into a type defined by *onnxscript*."""
122-
return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type)
124+
return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type, reversible=False)
123125

124126

125127
def _translate_signature(inputs, outputs):
@@ -350,25 +352,33 @@ def _translate_graph_body(self, graph, opsets, indent=0):
350352
if hasattr(graph, "initializer"):
351353
for init in graph.initializer:
352354
if self.skip_initializers:
353-
init_py_name = self._translate_onnx_var(init.name)
354-
if init_py_name in self.skipped_initializers:
355-
raise RuntimeError(
356-
f"Initializer {init.name!r} is already present in skipped_initializers."
357-
)
358-
self.skipped_initializers[init_py_name] = init
359-
continue
355+
size = 1
356+
for d in init.dims:
357+
size *= d
358+
if size > _SMALL_TENSOR_SIZE:
359+
init_py_name = self._translate_onnx_var(init.name)
360+
if init_py_name in self.skipped_initializers:
361+
raise RuntimeError(
362+
f"Initializer {init.name!r} is already present in skipped_initializers."
363+
)
364+
self.skipped_initializers[init_py_name] = init
365+
continue
360366
node = onnx.helper.make_node( # noqa: TID251
361367
"Constant",
362368
[],
363369
[self._translate_onnx_var(init.name)], # type: ignore[list-item]
364370
value=init,
365371
)
366-
code.append(self._translate_node(node, opsets, indent=indent))
372+
pyinit = self._translate_node(node, opsets, indent=indent)
373+
if pyinit:
374+
code.append(pyinit)
367375
if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0:
368376
raise NotImplementedError("Unable to convert sparse_initilizer into python.")
369377
for node in graph.node:
370378
pynode = self._translate_node(node, opsets, indent=indent)
371379
if pynode:
380+
if node.name:
381+
pynode += f" # {node.name}"
372382
code.append(pynode)
373383

374384
final = "\n".join(code)
@@ -418,7 +428,8 @@ def _translate_attributes(self, node):
418428
def _translate_if(self, node, opsets, indent=0):
419429
"""Translates a node If into python."""
420430
sindent = _SINGLE_INDENT * indent
421-
code = [f"{sindent}if {node.input[0]}:"]
431+
cond = self._translate_onnx_var_ref(node.input[0])
432+
code = [f"{sindent}if {cond}:"]
422433
if len(node.attribute) != 2:
423434
raise RuntimeError(
424435
f"Node {node.op_type!r} expected two attributes not {len(node.attribute)}."
@@ -502,17 +513,21 @@ def _translate_loop(self, node, opsets, indent=0):
502513

503514
rows.extend(self._emit_assign(formal_ins, actual_ins, indent))
504515

516+
if node.name:
517+
node_name = " # " + node.name
518+
else:
519+
node_name = ""
505520
if use_iter_var and not use_loop_cond:
506-
rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
521+
rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
507522
# The following is a hacky way to suppress the generation of
508523
# "cond_out = cond_in", which ONNX forces for a FOR loop.
509524
# TODO: a cleaner solution for this.
510525
self._name_remappings[-1][cond_out] = self._translate_onnx_var(cond_in)
511526
elif not use_iter_var and use_loop_cond:
512-
rows.append(f"{sindent}while {py_cond}:")
527+
rows.append(f"{sindent}while {py_cond}:{node_name}")
513528
elif use_iter_var and use_loop_cond:
514529
# TODO: This needs fixing
515-
rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
530+
rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
516531
rows.append(f"{sindent}{_SINGLE_INDENT}if not {py_cond}:")
517532
rows.append(f"{sindent}{_SINGLE_INDENT * 2}break")
518533
else:
@@ -734,11 +749,13 @@ def _substitute_initializers(
734749

735750
def generate_rand(name: str, value: TensorProto) -> str:
736751
shape = ",".join(str(d) for d in value.dims)
737-
if value.data_type != TensorProto.FLOAT:
738-
raise NotImplementedError(
739-
f"Unable to generate random initializer for data type {value.data_type}."
740-
)
741-
return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
752+
if value.data_type == TensorProto.FLOAT:
753+
return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
754+
if value.data_type == TensorProto.INT8:
755+
return f"{__}{name} = np.random.randint(-128, 127, size=({shape},), dtype=np.int8)"
756+
raise NotImplementedError(
757+
f"Unable to generate random initializer for data type {value.data_type}."
758+
)
742759

743760
random_initializer_values = "\n".join(
744761
generate_rand(key, value) for key, value in self.skipped_initializers.items()

onnxscript/onnx_types.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,13 @@ class FLOAT4E2M1(TensorType, dtype=ir.DataType.FLOAT4E2M1):
196196
pass
197197

198198

199-
def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
199+
def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto, *, reversible: bool = True) -> str:
200200
"""Converts an onnx type into the string representation of the type in *onnxscript*.
201201
202202
Args:
203203
onnx_type: an instance of onnx TypeProto
204+
reversible: if True, the conversion produces only types that are
205+
recognized by the onnxscript converter.
204206
205207
Returns:
206208
The string representation of the type in onnxscript
@@ -224,6 +226,10 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
224226
return name
225227
return f"{name}[{','.join(shape)}]"
226228
return f"{name}[...]"
229+
if not reversible:
230+
if onnx_type.HasField("sequence_type"):
231+
elem_type = onnx_type.sequence_type.elem_type
232+
return f"List[{onnx_type_to_onnxscript_repr(elem_type)}]"
227233
raise NotImplementedError(f"Unable to translate type {onnx_type!r} into onnxscript type.")
228234

229235

0 commit comments

Comments
 (0)