diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py
index c6b6abb56e..cfea1a501c 100644
--- a/onnxscript/backend/onnx_export.py
+++ b/onnxscript/backend/onnx_export.py
@@ -13,6 +13,8 @@
 
 _SINGLE_INDENT = "    "
 
+_SMALL_TENSOR_SIZE = 4
+
 kwlist = {
     "False",
     "None",
@@ -119,7 +121,7 @@ def renamer(name):
 
 def _translate_type(onnx_type):
     """Converts a onnx type into a type defined by *onnxscript*."""
-    return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type)
+    return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type, reversible=False)
 
 
 def _translate_signature(inputs, outputs):
@@ -350,25 +352,33 @@ def _translate_graph_body(self, graph, opsets, indent=0):
         if hasattr(graph, "initializer"):
             for init in graph.initializer:
                 if self.skip_initializers:
-                    init_py_name = self._translate_onnx_var(init.name)
-                    if init_py_name in self.skipped_initializers:
-                        raise RuntimeError(
-                            f"Initializer {init.name!r} is already present in skipped_initializers."
-                        )
-                    self.skipped_initializers[init_py_name] = init
-                    continue
+                    size = 1
+                    for d in init.dims:
+                        size *= d
+                    if size > _SMALL_TENSOR_SIZE:
+                        init_py_name = self._translate_onnx_var(init.name)
+                        if init_py_name in self.skipped_initializers:
+                            raise RuntimeError(
+                                f"Initializer {init.name!r} is already present in skipped_initializers."
+                            )
+                        self.skipped_initializers[init_py_name] = init
+                        continue
                 node = onnx.helper.make_node(  # noqa: TID251
                     "Constant",
                     [],
                     [self._translate_onnx_var(init.name)],  # type: ignore[list-item]
                     value=init,
                 )
-                code.append(self._translate_node(node, opsets, indent=indent))
+                pyinit = self._translate_node(node, opsets, indent=indent)
+                if pyinit:
+                    code.append(pyinit)
         if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0:
             raise NotImplementedError("Unable to convert sparse_initilizer into python.")
         for node in graph.node:
             pynode = self._translate_node(node, opsets, indent=indent)
             if pynode:
+                if node.name:
+                    pynode += f"  # {node.name}"
                 code.append(pynode)
         final = "\n".join(code)
@@ -418,7 +428,8 @@ def _translate_attributes(self, node):
     def _translate_if(self, node, opsets, indent=0):
         """Translates a node If into python."""
         sindent = _SINGLE_INDENT * indent
-        code = [f"{sindent}if {node.input[0]}:"]
+        cond = self._translate_onnx_var_ref(node.input[0])
+        code = [f"{sindent}if {cond}:"]
         if len(node.attribute) != 2:
             raise RuntimeError(
                 f"Node {node.op_type!r} expected two attributes not {len(node.attribute)}."
             )
@@ -502,17 +513,21 @@ def _translate_loop(self, node, opsets, indent=0):
 
         rows.extend(self._emit_assign(formal_ins, actual_ins, indent))
 
+        if node.name:
+            node_name = "  # " + node.name
+        else:
+            node_name = ""
         if use_iter_var and not use_loop_cond:
-            rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
+            rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
             # The following is a hacky way to suppress the generation of
             # "cond_out = cond_in", which ONNX forces for a FOR loop.
             # TODO: a cleaner solution for this.
             self._name_remappings[-1][cond_out] = self._translate_onnx_var(cond_in)
         elif not use_iter_var and use_loop_cond:
-            rows.append(f"{sindent}while {py_cond}:")
+            rows.append(f"{sindent}while {py_cond}:{node_name}")
         elif use_iter_var and use_loop_cond:
             # TODO: This needs fixing
-            rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
+            rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
             rows.append(f"{sindent}{_SINGLE_INDENT}if not {py_cond}:")
             rows.append(f"{sindent}{_SINGLE_INDENT * 2}break")
         else:
@@ -734,11 +749,13 @@ def _substitute_initializers(
 
         def generate_rand(name: str, value: TensorProto) -> str:
             shape = ",".join(str(d) for d in value.dims)
-            if value.data_type != TensorProto.FLOAT:
-                raise NotImplementedError(
-                    f"Unable to generate random initializer for data type {value.data_type}."
-                )
-            return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
+            if value.data_type == TensorProto.FLOAT:
+                return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
+            if value.data_type == TensorProto.INT8:
+                return f"{__}{name} = np.random.randint(-128, 128, size=({shape},), dtype=np.int8)"
+            raise NotImplementedError(
+                f"Unable to generate random initializer for data type {value.data_type}."
+            )
 
         random_initializer_values = "\n".join(
             generate_rand(key, value) for key, value in self.skipped_initializers.items()
diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py
index 2c1655024c..edbed36a37 100644
--- a/onnxscript/onnx_types.py
+++ b/onnxscript/onnx_types.py
@@ -196,11 +196,13 @@ class FLOAT4E2M1(TensorType, dtype=ir.DataType.FLOAT4E2M1):
     pass
 
 
-def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
+def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto, *, reversible: bool = True) -> str:
     """Converts an onnx type into the string representation of the type in *onnxscript*.
 
     Args:
         onnx_type: an instance of onnx TypeProto
+        reversible: if True, the conversion produces only types that are
+            recognized by the onnxscript converter.
 
     Returns:
        The string representation of the type in onnxscript
@@ -224,6 +226,10 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
                 return name
             return f"{name}[{','.join(shape)}]"
         return f"{name}[...]"
+    if not reversible:
+        if onnx_type.HasField("sequence_type"):
+            elem_type = onnx_type.sequence_type.elem_type
+            return f"List[{onnx_type_to_onnxscript_repr(elem_type, reversible=False)}]"
     raise NotImplementedError(f"Unable to translate type {onnx_type!r} into onnxscript type.")
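
A minimal sketch of the new skip heuristic in `_translate_graph_body`, for illustration only (the helper name `should_skip` is hypothetical; in the patch the check is inline): initializers with at most `_SMALL_TENSOR_SIZE` elements are still emitted as `Constant` nodes, while larger ones go into `skipped_initializers` and are later replaced by randomly generated values.

```python
import numpy as np
from onnx import TensorProto, numpy_helper

_SMALL_TENSOR_SIZE = 4  # threshold introduced by this patch


def should_skip(init: TensorProto) -> bool:
    """Hypothetical helper mirroring the patch's inline size check."""
    size = 1
    for d in init.dims:
        size *= d  # element count = product of all dimensions
    return size > _SMALL_TENSOR_SIZE


small = numpy_helper.from_array(np.zeros((2, 2), np.float32), name="small")
large = numpy_helper.from_array(np.zeros((2, 3), np.float32), name="large")
print(should_skip(small))  # False: 4 elements, still inlined as a Constant
print(should_skip(large))  # True: 6 elements, replaced by a random input
```

Note that a scalar initializer (empty `dims`) has an element count of 1, so it is always kept inline.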
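For reference, a sketch of what the strings emitted by `generate_rand` evaluate to, assuming a (2, 3) initializer; the upper bound of `np.random.randint` is exclusive, which is why 128 is needed to cover the full int8 range.

```python
import numpy as np

# FLOAT initializer replacement: uniform floats in [0, 1).
x = np.random.rand(2, 3).astype(np.float32)

# INT8 initializer replacement: randint's high bound is exclusive,
# so 128 yields values over the full int8 range [-128, 127].
w = np.random.randint(-128, 128, size=(2, 3,), dtype=np.int8)

print(x.dtype, x.shape)  # float32 (2, 3)
print(w.dtype, w.shape)  # int8 (2, 3)
```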
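And a small usage sketch for the new `reversible` flag (the outputs in the comments are the expected forms, not verified here): sequence types render as `List[...]` only when `reversible=False`; the default still raises, since `List[...]` is not a type the onnxscript converter can parse back.

```python
import onnx
from onnxscript.onnx_types import onnx_type_to_onnxscript_repr

tensor_t = onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [2, 3])
seq_t = onnx.helper.make_sequence_type_proto(tensor_t)

print(onnx_type_to_onnxscript_repr(tensor_t))                 # expected: FLOAT[2,3]
print(onnx_type_to_onnxscript_repr(seq_t, reversible=False))  # expected: List[FLOAT[2,3]]
# onnx_type_to_onnxscript_repr(seq_t)  # default reversible=True -> NotImplementedError
```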