13 | 13 |
14 | 14 | _SINGLE_INDENT = "    "
15 | 15 |
    | 16 | +_SMALL_TENSOR_SIZE = 4
    | 17 | +
16 | 18 | kwlist = {
17 | 19 |     "False",
18 | 20 |     "None",

@@ -119,7 +121,7 @@ def renamer(name):
119 | 121 |
120 | 122 | def _translate_type(onnx_type):
121 | 123 |     """Converts a onnx type into a type defined by *onnxscript*."""
122 |     | -    return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type)
    | 124 | +    return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type, reversible=False)
123 | 125 |
124 | 126 |
125 | 127 | def _translate_signature(inputs, outputs):

@@ -350,25 +352,33 @@ def _translate_graph_body(self, graph, opsets, indent=0):
350 | 352 |         if hasattr(graph, "initializer"):
351 | 353 |             for init in graph.initializer:
352 | 354 |                 if self.skip_initializers:
353 |     | -                    init_py_name = self._translate_onnx_var(init.name)
354 |     | -                    if init_py_name in self.skipped_initializers:
355 |     | -                        raise RuntimeError(
356 |     | -                            f"Initializer {init.name!r} is already present in skipped_initializers."
357 |     | -                        )
358 |     | -                    self.skipped_initializers[init_py_name] = init
359 |     | -                    continue
    | 355 | +                    size = 1
    | 356 | +                    for d in init.dims:
    | 357 | +                        size *= d
    | 358 | +                    if size > _SMALL_TENSOR_SIZE:
    | 359 | +                        init_py_name = self._translate_onnx_var(init.name)
    | 360 | +                        if init_py_name in self.skipped_initializers:
    | 361 | +                            raise RuntimeError(
    | 362 | +                                f"Initializer {init.name!r} is already present in skipped_initializers."
    | 363 | +                            )
    | 364 | +                        self.skipped_initializers[init_py_name] = init
    | 365 | +                        continue
360 | 366 |                 node = onnx.helper.make_node(  # noqa: TID251
361 | 367 |                     "Constant",
362 | 368 |                     [],
363 | 369 |                     [self._translate_onnx_var(init.name)],  # type: ignore[list-item]
364 | 370 |                     value=init,
365 | 371 |                 )
366 |     | -                code.append(self._translate_node(node, opsets, indent=indent))
    | 372 | +                pyinit = self._translate_node(node, opsets, indent=indent)
    | 373 | +                if pyinit:
    | 374 | +                    code.append(pyinit)
367 | 375 |         if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0:
368 | 376 |             raise NotImplementedError("Unable to convert sparse_initilizer into python.")
369 | 377 |         for node in graph.node:
370 | 378 |             pynode = self._translate_node(node, opsets, indent=indent)
371 | 379 |             if pynode:
    | 380 | +                if node.name:
    | 381 | +                    pynode += f" # {node.name}"
372 | 382 |                 code.append(pynode)
373 | 383 |
374 | 384 |         final = "\n".join(code)

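Not part of the diff, but to make the new skip rule concrete: a minimal standalone sketch of the decision the added lines implement, with math.prod standing in for the explicit multiplication loop. should_skip and the sample tensors are illustrative names only, not converter API.

import math

import numpy as np
from onnx import TensorProto, numpy_helper

_SMALL_TENSOR_SIZE = 4  # same threshold as the constant added above

def should_skip(init: TensorProto, skip_initializers: bool) -> bool:
    # Skip only when skipping is requested and the initializer holds more than
    # _SMALL_TENSOR_SIZE elements; small tensors stay inline as Constant nodes.
    if not skip_initializers:
        return False
    return math.prod(init.dims) > _SMALL_TENSOR_SIZE  # empty dims (a scalar) gives 1

small = numpy_helper.from_array(np.zeros((2, 2), dtype=np.float32), name="w_small")
large = numpy_helper.from_array(np.zeros((2, 3), dtype=np.float32), name="w_large")
assert not should_skip(small, skip_initializers=True)  # 4 elements, kept inline
assert should_skip(large, skip_initializers=True)      # 6 elements, skipped
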
@@ -418,7 +428,8 @@ def _translate_attributes(self, node):
418 | 428 |     def _translate_if(self, node, opsets, indent=0):
419 | 429 |         """Translates a node If into python."""
420 | 430 |         sindent = _SINGLE_INDENT * indent
421 |     | -        code = [f"{sindent}if {node.input[0]}:"]
    | 431 | +        cond = self._translate_onnx_var_ref(node.input[0])
    | 432 | +        code = [f"{sindent}if {cond}:"]
422 | 433 |         if len(node.attribute) != 2:
423 | 434 |             raise RuntimeError(
424 | 435 |                 f"Node {node.op_type!r} expected two attributes not {len(node.attribute)}."

@@ -502,17 +513,21 @@ def _translate_loop(self, node, opsets, indent=0):
502 | 513 |
503 | 514 |         rows.extend(self._emit_assign(formal_ins, actual_ins, indent))
504 | 515 |
    | 516 | +        if node.name:
    | 517 | +            node_name = " # " + node.name
    | 518 | +        else:
    | 519 | +            node_name = ""
505 | 520 |         if use_iter_var and not use_loop_cond:
506 |     | -            rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
    | 521 | +            rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
507 | 522 |             # The following is a hacky way to suppress the generation of
508 | 523 |             # "cond_out = cond_in", which ONNX forces for a FOR loop.
509 | 524 |             # TODO: a cleaner solution for this.
510 | 525 |             self._name_remappings[-1][cond_out] = self._translate_onnx_var(cond_in)
511 | 526 |         elif not use_iter_var and use_loop_cond:
512 |     | -            rows.append(f"{sindent}while {py_cond}:")
    | 527 | +            rows.append(f"{sindent}while {py_cond}:{node_name}")
513 | 528 |         elif use_iter_var and use_loop_cond:
514 | 529 |             # TODO: This needs fixing
515 |     | -            rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
    | 530 | +            rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
516 | 531 |             rows.append(f"{sindent}{_SINGLE_INDENT}if not {py_cond}:")
517 | 532 |             rows.append(f"{sindent}{_SINGLE_INDENT * 2}break")
518 | 533 |         else:

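Again outside the diff: a small sketch of the loop headers this change emits, showing how the optional "# <node.name>" suffix is attached. loop_header and the argument values are made up for illustration.

_SINGLE_INDENT = "    "

def loop_header(iter_var: str, n_iter: str, node_name: str, indent: int = 1) -> str:
    sindent = _SINGLE_INDENT * indent
    node_comment = " # " + node_name if node_name else ""
    return f"{sindent}for {iter_var} in range({n_iter}):{node_comment}"

print(loop_header("i", "max_trip_count", "my_loop"))  # "    for i in range(max_trip_count): # my_loop"
print(loop_header("i", "max_trip_count", ""))         # "    for i in range(max_trip_count):"
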
@@ -734,11 +749,13 @@ def _substitute_initializers(
734 | 749 |
735 | 750 |         def generate_rand(name: str, value: TensorProto) -> str:
736 | 751 |             shape = ",".join(str(d) for d in value.dims)
737 |     | -            if value.data_type != TensorProto.FLOAT:
738 |     | -                raise NotImplementedError(
739 |     | -                    f"Unable to generate random initializer for data type {value.data_type}."
740 |     | -                )
741 |     | -            return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
    | 752 | +            if value.data_type == TensorProto.FLOAT:
    | 753 | +                return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
    | 754 | +            if value.data_type == TensorProto.INT8:
    | 755 | +                return f"{__}{name} = np.random.randint(-128, 127, size=({shape},), dtype=np.int8)"
    | 756 | +            raise NotImplementedError(
    | 757 | +                f"Unable to generate random initializer for data type {value.data_type}."
    | 758 | +            )
742 | 759 |
743 | 760 |         random_initializer_values = "\n".join(
744 | 761 |             generate_rand(key, value) for key, value in self.skipped_initializers.items()

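Outside the diff, for context: the kind of stand-in lines the extended generate_rand now emits for skipped FLOAT and INT8 initializers, written out directly so they can be run. The shape (8, 3) and the variable names are arbitrary examples.

import numpy as np

# FLOAT initializer of shape (8, 3): emitted as np.random.rand(...).astype(np.float32)
bias = np.random.rand(8, 3).astype(np.float32)
assert bias.shape == (8, 3) and bias.dtype == np.float32

# INT8 initializer of shape (8, 3): emitted as np.random.randint(-128, 127, ...)
weight = np.random.randint(-128, 127, size=(8, 3,), dtype=np.int8)
assert weight.shape == (8, 3) and weight.dtype == np.int8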