53 changes: 35 additions & 18 deletions onnxscript/backend/onnx_export.py
@@ -13,6 +13,8 @@

 _SINGLE_INDENT = "    "

+_SMALL_TENSOR_SIZE = 4
+
 kwlist = {
     "False",
     "None",
@@ -119,7 +121,7 @@ def renamer(name):

 def _translate_type(onnx_type):
     """Converts an onnx type into a type defined by *onnxscript*."""
-    return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type)
+    return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type, reversible=False)


 def _translate_signature(inputs, outputs):
@@ -350,25 +352,33 @@ def _translate_graph_body(self, graph, opsets, indent=0):
         if hasattr(graph, "initializer"):
             for init in graph.initializer:
                 if self.skip_initializers:
-                    init_py_name = self._translate_onnx_var(init.name)
-                    if init_py_name in self.skipped_initializers:
-                        raise RuntimeError(
-                            f"Initializer {init.name!r} is already present in skipped_initializers."
-                        )
-                    self.skipped_initializers[init_py_name] = init
-                    continue
+                    size = 1
+                    for d in init.dims:
+                        size *= d
+                    if size > _SMALL_TENSOR_SIZE:
+                        init_py_name = self._translate_onnx_var(init.name)
+                        if init_py_name in self.skipped_initializers:
+                            raise RuntimeError(
+                                f"Initializer {init.name!r} is already present in skipped_initializers."
+                            )
+                        self.skipped_initializers[init_py_name] = init
+                        continue
                 node = onnx.helper.make_node(  # noqa: TID251
                     "Constant",
                     [],
                     [self._translate_onnx_var(init.name)],  # type: ignore[list-item]
                     value=init,
                 )
-                code.append(self._translate_node(node, opsets, indent=indent))
+                pyinit = self._translate_node(node, opsets, indent=indent)
+                if pyinit:
+                    code.append(pyinit)
         if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0:
             raise NotImplementedError("Unable to convert sparse_initializer into python.")
         for node in graph.node:
             pynode = self._translate_node(node, opsets, indent=indent)
             if pynode:
+                if node.name:
+                    pynode += f"  # {node.name}"
                 code.append(pynode)

         final = "\n".join(code)
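Review note: the new gate means that, even with skip_initializers set, tiny tensors stay inline as Constant nodes. A minimal standalone sketch of the size check (the helper name should_inline is illustrative, not part of the diff):

    _SMALL_TENSOR_SIZE = 4  # threshold introduced above

    def should_inline(dims) -> bool:
        # Element count of the initializer; a scalar (no dims) has size 1.
        size = 1
        for d in dims:
            size *= d
        return size <= _SMALL_TENSOR_SIZE

    assert should_inline([])          # scalar -> inline Constant node
    assert should_inline([2, 2])      # 4 elements -> inline Constant node
    assert not should_inline([3, 3])  # 9 elements -> goes to skipped_initializers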
@@ -418,7 +428,8 @@ def _translate_attributes(self, node):
     def _translate_if(self, node, opsets, indent=0):
         """Translates a node If into python."""
         sindent = _SINGLE_INDENT * indent
-        code = [f"{sindent}if {node.input[0]}:"]
+        cond = self._translate_onnx_var_ref(node.input[0])
+        code = [f"{sindent}if {cond}:"]
         if len(node.attribute) != 2:
             raise RuntimeError(
                 f"Node {node.op_type!r} expected two attributes not {len(node.attribute)}."
@@ -502,17 +513,21 @@ def _translate_loop(self, node, opsets, indent=0):

         rows.extend(self._emit_assign(formal_ins, actual_ins, indent))

+        if node.name:
+            node_name = "  # " + node.name
+        else:
+            node_name = ""
         if use_iter_var and not use_loop_cond:
-            rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
+            rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
             # The following is a hacky way to suppress the generation of
             # "cond_out = cond_in", which ONNX forces for a FOR loop.
             # TODO: a cleaner solution for this.
             self._name_remappings[-1][cond_out] = self._translate_onnx_var(cond_in)
         elif not use_iter_var and use_loop_cond:
-            rows.append(f"{sindent}while {py_cond}:")
+            rows.append(f"{sindent}while {py_cond}:{node_name}")
         elif use_iter_var and use_loop_cond:
             # TODO: This needs fixing
-            rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
+            rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
             rows.append(f"{sindent}{_SINGLE_INDENT}if not {py_cond}:")
             rows.append(f"{sindent}{_SINGLE_INDENT * 2}break")
         else:
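Review note: with the node_name suffix, ONNX node names now survive into the generated script as trailing comments, which makes it easier to map emitted Python back to graph nodes. Illustrative output for a Loop node named "loop_body_0" (name and bound are hypothetical):

    for i in range(max_iter):  # loop_body_0
        ...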
@@ -734,11 +749,13 @@ def _substitute_initializers(

         def generate_rand(name: str, value: TensorProto) -> str:
             shape = ",".join(str(d) for d in value.dims)
-            if value.data_type != TensorProto.FLOAT:
-                raise NotImplementedError(
-                    f"Unable to generate random initializer for data type {value.data_type}."
-                )
-            return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
+            if value.data_type == TensorProto.FLOAT:
+                return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
+            if value.data_type == TensorProto.INT8:
+                return f"{__}{name} = np.random.randint(-128, 127, size=({shape},), dtype=np.int8)"
+            raise NotImplementedError(
+                f"Unable to generate random initializer for data type {value.data_type}."
+            )

         random_initializer_values = "\n".join(
             generate_rand(key, value) for key, value in self.skipped_initializers.items()
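Review note: a sketch of the substituted lines this now produces for the two supported dtypes (tensor names and shapes are hypothetical):

    import numpy as np

    # TensorProto.FLOAT path: uniform floats in [0, 1)
    W = np.random.rand(2, 3).astype(np.float32)
    # TensorProto.INT8 path: np.random.randint's upper bound is exclusive,
    # so the generated values fall in [-128, 126]
    B = np.random.randint(-128, 127, size=(4,), dtype=np.int8)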
8 changes: 7 additions & 1 deletion onnxscript/onnx_types.py
@@ -196,11 +196,13 @@ class FLOAT4E2M1(TensorType, dtype=ir.DataType.FLOAT4E2M1):
     pass


-def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
+def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto, *, reversible: bool = True) -> str:
     """Converts an onnx type into the string representation of the type in *onnxscript*.

     Args:
         onnx_type: an instance of onnx TypeProto
+        reversible: if True, the conversion produces only types that are
+            recognized by the onnxscript converter.

     Returns:
         The string representation of the type in onnxscript
@@ -224,6 +226,10 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
                 return name
             return f"{name}[{','.join(shape)}]"
         return f"{name}[...]"
+    if not reversible:
+        if onnx_type.HasField("sequence_type"):
+            elem_type = onnx_type.sequence_type.elem_type
+            return f"List[{onnx_type_to_onnxscript_repr(elem_type)}]"
     raise NotImplementedError(f"Unable to translate type {onnx_type!r} into onnxscript type.")

