Skip to content

Commit 36f25ce

Browse files
committed
Revert "ArBackend: Enable Pybindings for tosa_serialization lib (pytorch#15356)"
This reverts commit fdfeaa4.
1 parent 53bb98b commit 36f25ce

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+345
-489
lines changed

.mypy.ini

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,6 @@ ignore_missing_imports = True
8383
[mypy-tosa_tools.*]
8484
ignore_missing_imports = True
8585

86-
[mypy-tosa_serializer]
87-
ignore_missing_imports = True
88-
89-
[mypy-tosa_serializer.*]
90-
ignore_missing_imports = True
91-
9286
[mypy-setuptools.*]
9387
ignore_missing_imports = True
9488

backends/arm/common/debug.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
import os
88
from typing import Optional
99

10+
import serializer.tosa_serializer as ts
1011
import torch
11-
12-
import tosa_serializer as ts
1312
from executorch.exir.print_program import inspect_node
1413

1514
logger = logging.getLogger(__name__)
@@ -51,20 +50,29 @@ def get_node_debug_info(
5150
return output
5251

5352

54-
# Output TOSA flatbuffer for debugging
55-
def debug_tosa_dump(tosa_graph: bytes, path: str, suffix: str = ""):
53+
# Output TOSA flatbuffer and test harness file
54+
def debug_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
5655
filename = f"output{suffix}.tosa"
5756

5857
logger.info(f"Emitting debug output to: {path=}, {suffix=}")
5958

6059
os.makedirs(path, exist_ok=True)
6160

61+
fb = tosa_graph.serialize()
62+
js = tosa_graph.writeJson(filename)
63+
6264
filepath_tosa_fb = os.path.join(path, filename)
6365
with open(filepath_tosa_fb, "wb") as f:
64-
f.write(tosa_graph)
66+
f.write(fb)
6567
if not os.path.exists(filepath_tosa_fb):
6668
raise IOError("Failed to write TOSA flatbuffer")
6769

70+
filepath_desc_json = os.path.join(path, f"desc{suffix}.json")
71+
with open(filepath_desc_json, "w") as f:
72+
f.write(js)
73+
if not os.path.exists(filepath_desc_json):
74+
raise IOError("Failed to write TOSA JSON")
75+
6876

6977
def debug_fail(
7078
node,
@@ -73,7 +81,7 @@ def debug_fail(
7381
path: Optional[str] = None,
7482
):
7583
logger.warning("Internal error due to poorly handled node:")
76-
if tosa_graph is not None and path:
77-
debug_tosa_dump(tosa_graph.serialize(), path)
84+
if tosa_graph is not None and path is not None:
85+
debug_tosa_dump(tosa_graph, path)
7886
logger.warning(f"Debug output captured in '{path}'.")
7987
debug_node(node, graph_module)

backends/arm/debug/schema.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from dataclasses import asdict, dataclass
1111
from typing import Any, Optional
1212

13+
import serializer.tosa_serializer as ts
1314
import torch
14-
import tosa_serializer as ts
1515

1616
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
1717

@@ -114,18 +114,23 @@ def to_dict(self) -> dict[str, Any]:
114114
class DebugHook:
115115
def __init__(self, debug_mode: ArmCompileSpec.DebugMode) -> None:
116116
self._debug_events: list[DebugSchema] = []
117+
self.__op_id_to_name = {}
117118
self.mode = debug_mode
118119

119-
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: ts.Op) -> DebugSchema:
120+
# Build up a mapping from TOSA 1.0 operator IDs to their names
121+
for name, val in vars(ts.Op).items():
122+
self.__op_id_to_name[val] = name
123+
124+
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema:
120125
tosa_debug_info = None
121126

122127
# If the debug data is being embedded into the TOSA flatbuffer
123128
# do not collect TOSADebugSchema data, it's redundent
124129
if self.mode != ArmCompileSpec.DebugMode.TOSA:
125130
tosa_debug_info = TosaDebugSchema(
126131
node_name=str(tosa_op),
127-
operator_name=str(tosa_op_id),
128-
operator_id=int(tosa_op_id),
132+
operator_name=self.__op_id_to_name[tosa_op_id],
133+
operator_id=tosa_op_id,
129134
)
130135

131136
aten_debug_info = ATenDebugSchema.from_node(node)

backends/arm/ethosu/backend.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,6 @@ def _compile_tosa_flatbuffer(
5151
"compile_flags are required in the CompileSpec list for EthosUBackend"
5252
)
5353

54-
# Vela tooling only supports flatbuffers up to 2 GiB.
55-
max_flatbuffer_size = 2 * 1024 * 1024 * 1024
56-
flatbuffer_size = len(tosa_flatbuffer)
57-
if flatbuffer_size > max_flatbuffer_size:
58-
raise RuntimeError(
59-
"TOSA flatbuffer is too large for Vela "
60-
f"({flatbuffer_size} bytes > {max_flatbuffer_size} bytes limit)."
61-
)
62-
6354
# Pass on the TOSA flatbuffer to the vela compiler.
6455
binary = vela_compile(
6556
tosa_flatbuffer,

backends/arm/operators/node_visitor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from typing import Any, Dict, List, Optional
1010

1111
import torch
12-
import tosa_serializer as ts
1312

1413
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
1514
from executorch.backends.arm.debug.schema import DebugHook
@@ -47,12 +46,12 @@ def _serialize_operator(
4746
self,
4847
node: torch.fx.Node,
4948
tosa_graph: Any,
50-
tosa_op: ts.Op,
49+
tosa_op: Any,
5150
inputs: List[str],
5251
outputs: List[str],
5352
attributes: Optional[Any] = None,
5453
) -> None:
55-
op_location = ts.TosaOpLocation()
54+
op_location = ""
5655
if self.debug_hook:
5756
debug_info = self.debug_hook.add(
5857
node,
@@ -61,7 +60,7 @@ def _serialize_operator(
6160
)
6261

6362
if self.debug_hook.mode == ArmCompileSpec.DebugMode.TOSA:
64-
op_location.text = json.dumps(debug_info.to_dict())
63+
op_location = json.dumps(debug_info.to_dict())
6564

6665
tosa_graph.addOperator(
6766
tosa_op,

backends/arm/operators/op_abs.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# pyre-unsafe
77
from typing import Any, List
88

9-
import tosa_serializer as ts
9+
import serializer.tosa_serializer as ts
1010

1111
from executorch.backends.arm.operators.node_visitor import (
1212
NodeVisitor,
@@ -48,13 +48,11 @@ def define_node(
4848
output.tosa_spec,
4949
)
5050

51-
attr = ts.TosaSerializerAttribute()
52-
attr.AbsAttribute()
53-
self._serialize_operator(
54-
node,
55-
tosa_graph,
56-
ts.Op.ABS,
57-
[inputs[0].name],
51+
tosa_graph.addOperator(
52+
ts.TosaOp.Op().ABS,
53+
[
54+
inputs[0].name,
55+
],
5856
[output.name],
59-
attr,
57+
None,
6058
)

backends/arm/operators/op_add.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import executorch.backends.arm.tosa.quant_utils as tqutils
1111
import executorch.backends.arm.tosa.utils as tutils
12-
import tosa_serializer as ts
12+
import serializer.tosa_serializer as ts
1313

1414
from executorch.backends.arm.operators.node_visitor import (
1515
NodeVisitor,
@@ -81,16 +81,15 @@ def define_node(
8181
add_output = output
8282

8383
input1, input2 = rescaled_inputs
84-
attr = ts.TosaSerializerAttribute()
85-
attr.AddAttribute()
84+
8685
# Do the INT32 Add
8786
self._serialize_operator(
8887
node,
8988
tosa_graph,
90-
ts.Op.ADD,
89+
ts.TosaOp.Op().ADD,
9190
[input1.name, input2.name],
9291
[add_output.name],
93-
attr,
92+
None,
9493
)
9594

9695
if output.dtype == ts.DType.INT8:
@@ -144,14 +143,13 @@ def define_node(
144143
)
145144

146145
input1, input2 = inputs
147-
attr = ts.TosaSerializerAttribute()
148-
attr.AddAttribute()
146+
149147
# FP lowering
150148
self._serialize_operator(
151149
node,
152150
tosa_graph,
153-
ts.Op.ADD,
151+
ts.TosaOp.Op().ADD,
154152
[input1.name, input2.name],
155153
[output.name],
156-
attr,
154+
None,
157155
)

backends/arm/operators/op_amax.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import Any, List
66

7-
import tosa_serializer as ts
7+
import serializer.tosa_serializer as ts
88

99
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1010
from executorch.backends.arm.operators.node_visitor import (
@@ -60,12 +60,11 @@ def define_node(
6060
)
6161

6262
attr = ts.TosaSerializerAttribute()
63-
nan_mode = ts.NanPropagationMode.PROPAGATE
64-
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=nan_mode)
63+
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1)
6564
self._serialize_operator(
6665
node,
6766
tosa_graph,
68-
ts.Op.REDUCE_MAX,
67+
ts.TosaOp.Op().REDUCE_MAX,
6968
[input.name],
7069
[output.name],
7170
attr,

backends/arm/operators/op_amin.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import Any, List
66

7-
import tosa_serializer as ts
7+
import serializer.tosa_serializer as ts
88

99
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1010
from executorch.backends.arm.operators.node_visitor import (
@@ -60,13 +60,11 @@ def define_node(
6060
)
6161

6262
attr = ts.TosaSerializerAttribute()
63-
attr.ReduceMinAttribute(
64-
axis=input.dim_order.index(dim), nan_mode=ts.NanPropagationMode.PROPAGATE
65-
)
63+
attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1)
6664
self._serialize_operator(
6765
node,
6866
tosa_graph,
69-
ts.Op.REDUCE_MIN,
67+
ts.TosaOp.Op().REDUCE_MIN,
7068
[input.name],
7169
[output.name],
7270
attr,

backends/arm/operators/op_any.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# pyre-unsafe
77
from typing import Any, cast, List
88

9-
import tosa_serializer as ts
9+
import serializer.tosa_serializer as ts
1010

1111
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
1212
NodeVisitor,
@@ -55,7 +55,7 @@ def define_node(
5555
self._serialize_operator(
5656
node,
5757
tosa_graph,
58-
ts.Op.REDUCE_ANY,
58+
ts.TosaOp.Op().REDUCE_ANY,
5959
[inputs[0].name],
6060
[output.name],
6161
attr,

0 commit comments

Comments
 (0)