Skip to content

Commit 31ce12b

Browse files
Merge branch 'main' into mean-default
2 parents cf2dbe8 + f7ca57e commit 31ce12b

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

83 files changed

+602
-447
lines changed

.mypy.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ ignore_missing_imports = True
8383
[mypy-tosa_tools.*]
8484
ignore_missing_imports = True
8585

86+
[mypy-tosa_serializer]
87+
ignore_missing_imports = True
88+
89+
[mypy-tosa_serializer.*]
90+
ignore_missing_imports = True
91+
8692
[mypy-setuptools.*]
8793
ignore_missing_imports = True
8894

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,14 @@ def call_operator(self, op, args, kwargs, meta):
105105

106106
conv_output = super().call_operator(
107107
exir_ops.backend.tosa.RESCALE.default,
108-
(convolution, torch.int32, conv_rescale_factor, 0, 0),
108+
(convolution, torch.int32, [conv_rescale_factor], 0, 0),
109109
{},
110110
new_meta,
111111
)
112112

113113
bias_rescaled = super().call_operator(
114114
exir_ops.backend.tosa.RESCALE.default,
115-
(channel_bias, torch.int32, bias_rescale_factor, 0, 0),
115+
(channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
116116
{},
117117
new_meta,
118118
)
@@ -129,7 +129,7 @@ def call_operator(self, op, args, kwargs, meta):
129129
(
130130
add,
131131
output_dtype,
132-
(common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
132+
[(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
133133
0,
134134
0,
135135
),

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule
4545
(
4646
node.all_input_nodes[0],
4747
q_args.dtype,
48-
new_scale,
48+
[new_scale],
4949
dq_args.zp,
5050
q_args.zp,
5151
),
@@ -228,10 +228,10 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b
228228
(
229229
arg_node,
230230
torch.int32,
231-
qp.get_scale_per_tensor()
232-
/ rescale_qargs[
233-
i
234-
].get_scale_per_tensor(), # Old scale / new scale
231+
[
232+
qp.get_scale_per_tensor()
233+
/ rescale_qargs[i].get_scale_per_tensor()
234+
], # [Old scale / new scale]
235235
qp.get_zp_per_tensor(), # Old zero point
236236
rescale_qargs[i].get_zp_per_tensor(), # New zero point
237237
),
@@ -264,8 +264,10 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
264264
(
265265
node,
266266
qarg.dtype,
267-
rescale_qargs.get_scale_per_tensor()
268-
/ qarg.get_scale_per_tensor(), # Old scale / new scale
267+
[
268+
rescale_qargs.get_scale_per_tensor()
269+
/ qarg.get_scale_per_tensor()
270+
], # [Old scale / new scale]
269271
rescale_qargs.get_zp_per_tensor(), # Old zero point
270272
qarg.get_zp_per_tensor(), # New zero point
271273
),

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
286286
rescale_node = create_node(
287287
graph=graph_module.graph,
288288
op_target=exir_ops.backend.tosa.RESCALE.default,
289-
args=(table_op_node, output_qparams[0].dtype, scale, 0, 0),
289+
args=(table_op_node, output_qparams[0].dtype, [scale], 0, 0),
290290
)
291291
output_node = rescale_node
292292

backends/arm/_passes/rewrite_conv2d_pass.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7+
import itertools
78
from typing import Set, Type
89

910
import torch
@@ -16,6 +17,10 @@
1617
is_buffer,
1718
is_param,
1819
)
20+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
21+
get_input_qparams,
22+
get_output_qparams,
23+
)
1924
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
2025
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2126
from executorch.backends.transforms.utils import create_constant_placeholder
@@ -156,6 +161,40 @@ def _add_bias(
156161
node.update_arg(2, bias_node)
157162
return bias_node
158163

164+
def insert_output_rescale(self, graph_module, node):
165+
input_qparams = get_input_qparams(node)
166+
output_qparams = get_output_qparams(node)[0]
167+
weight_qparams = input_qparams[1]
168+
input_qparams = input_qparams[0]
169+
is_per_channel = weight_qparams.per_channel
170+
if is_per_channel:
171+
weight_scale = weight_qparams.get_scale_per_channel()
172+
else:
173+
weight_scale = [weight_qparams.get_scale_per_tensor()]
174+
input_scale = input_qparams.get_scale_per_tensor()
175+
post_conv2d_scale = [
176+
(inp * w) / out
177+
for inp, w, out in zip(
178+
itertools.cycle([input_scale]),
179+
weight_scale,
180+
itertools.cycle([output_qparams.get_scale_per_tensor()]),
181+
)
182+
]
183+
with graph_module.graph.inserting_after(node):
184+
rescale_node = create_node(
185+
graph=graph_module.graph,
186+
op_target=exir_ops.backend.tosa.RESCALE.default,
187+
args=(
188+
node,
189+
output_qparams.dtype,
190+
post_conv2d_scale,
191+
0,
192+
output_qparams.get_zp_per_tensor(),
193+
),
194+
from_node=node,
195+
)
196+
return rescale_node
197+
159198
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
160199
modified = False
161200
for node in graph_module.graph.nodes:
@@ -180,20 +219,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
180219
) = node.args
181220

182221
pad = [val for val in pad for _ in (0, 1)]
183-
input_shape = get_first_fake_tensor(x).shape
184-
weight_shape = get_first_fake_tensor(weight).shape
222+
input_fake_tensor = get_first_fake_tensor(x)
223+
weight_fake_tensor = get_first_fake_tensor(weight)
185224
# Adjust the pad value if needed to meet the
186225
# strict convolution output shape calculation.
187226
pad[1] = self._adjust_pad_if_needed(
188-
input_shape[2],
189-
weight_shape[2],
227+
input_fake_tensor.shape[2],
228+
weight_fake_tensor.shape[2],
190229
stride[0],
191230
pad[1],
192231
dilation[0],
193232
)
194233
pad[3] = self._adjust_pad_if_needed(
195-
input_shape[3],
196-
weight_shape[3],
234+
input_fake_tensor.shape[3],
235+
weight_fake_tensor.shape[3],
197236
stride[1],
198237
pad[3],
199238
dilation[1],
@@ -204,7 +243,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
204243

205244
if self._is_depthwise_conv2d(node):
206245
target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
207-
self._reshape_weights(weight, input_shape[1])
246+
self._reshape_weights(weight, input_fake_tensor.shape[1])
247+
weight_fake_tensor = get_first_fake_tensor(weight)
208248
else:
209249
target_op = exir_ops.backend.tosa.CONV2D.default
210250

@@ -227,9 +267,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
227267
args=conv2d_args,
228268
from_node=node,
229269
)
270+
bias_fake_tensor = get_first_fake_tensor(bias) if bias else None
271+
tosa_node_fake_tensor = target_op(
272+
input_fake_tensor,
273+
weight_fake_tensor,
274+
bias_fake_tensor,
275+
*conv2d_args[3:],
276+
)
230277

278+
if (
279+
tosa_node_fake_tensor.dtype == torch.int32
280+
and input_fake_tensor.dtype == torch.int8
281+
) or (
282+
tosa_node_fake_tensor.dtype == torch.int32
283+
and input_fake_tensor.dtype == torch.int16
284+
):
285+
output_rescale = self.insert_output_rescale(graph_module, tosa_op)
286+
node.replace_all_uses_with(output_rescale)
287+
if input_fake_tensor.dtype == torch.int16:
288+
tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
289+
else:
231290
node.replace_all_uses_with(tosa_op)
232-
graph_module.graph.erase_node(node)
291+
292+
graph_module.graph.erase_node(node)
233293

234294
if modified:
235295
graph_module.recompile()

backends/arm/_passes/rewrite_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype):
4444
rescale_node.args = (
4545
tosa_matmul_node,
4646
dtype,
47-
scale,
47+
[scale],
4848
0,
4949
output_qparams.get_zp_per_tensor(),
5050
)

backends/arm/_passes/rewrite_upsample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def call(self, graph_module):
7474
rescale_node.args = (
7575
tosa_resize_node,
7676
output_dtype,
77-
output_scale,
77+
[output_scale],
7878
0, # zero point
7979
0, # zero point
8080
)

backends/arm/common/debug.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import os
88
from typing import Optional
99

10-
import serializer.tosa_serializer as ts
1110
import torch
11+
12+
import tosa_serializer as ts
1213
from executorch.exir.print_program import inspect_node
1314

1415
logger = logging.getLogger(__name__)
@@ -50,29 +51,20 @@ def get_node_debug_info(
5051
return output
5152

5253

53-
# Output TOSA flatbuffer and test harness file
54-
def debug_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
54+
# Output TOSA flatbuffer for debugging
55+
def debug_tosa_dump(tosa_graph: bytes, path: str, suffix: str = ""):
5556
filename = f"output{suffix}.tosa"
5657

5758
logger.info(f"Emitting debug output to: {path=}, {suffix=}")
5859

5960
os.makedirs(path, exist_ok=True)
6061

61-
fb = tosa_graph.serialize()
62-
js = tosa_graph.writeJson(filename)
63-
6462
filepath_tosa_fb = os.path.join(path, filename)
6563
with open(filepath_tosa_fb, "wb") as f:
66-
f.write(fb)
64+
f.write(tosa_graph)
6765
if not os.path.exists(filepath_tosa_fb):
6866
raise IOError("Failed to write TOSA flatbuffer")
6967

70-
filepath_desc_json = os.path.join(path, f"desc{suffix}.json")
71-
with open(filepath_desc_json, "w") as f:
72-
f.write(js)
73-
if not os.path.exists(filepath_desc_json):
74-
raise IOError("Failed to write TOSA JSON")
75-
7668

7769
def debug_fail(
7870
node,
@@ -81,7 +73,7 @@ def debug_fail(
8173
path: Optional[str] = None,
8274
):
8375
logger.warning("Internal error due to poorly handled node:")
84-
if tosa_graph is not None and path is not None:
85-
debug_tosa_dump(tosa_graph, path)
76+
if tosa_graph is not None and path:
77+
debug_tosa_dump(tosa_graph.serialize(), path)
8678
logger.warning(f"Debug output captured in '{path}'.")
8779
debug_node(node, graph_module)

backends/arm/debug/schema.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from dataclasses import asdict, dataclass
1111
from typing import Any, Optional
1212

13-
import serializer.tosa_serializer as ts
1413
import torch
14+
import tosa_serializer as ts
1515

1616
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
1717

@@ -114,23 +114,18 @@ def to_dict(self) -> dict[str, Any]:
114114
class DebugHook:
115115
def __init__(self, debug_mode: ArmCompileSpec.DebugMode) -> None:
116116
self._debug_events: list[DebugSchema] = []
117-
self.__op_id_to_name = {}
118117
self.mode = debug_mode
119118

120-
# Build up a mapping from TOSA 1.0 operator IDs to their names
121-
for name, val in vars(ts.Op).items():
122-
self.__op_id_to_name[val] = name
123-
124-
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema:
119+
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: ts.Op) -> DebugSchema:
125120
tosa_debug_info = None
126121

127122
# If the debug data is being embedded into the TOSA flatbuffer
128123
# do not collect TOSADebugSchema data, it's redundent
129124
if self.mode != ArmCompileSpec.DebugMode.TOSA:
130125
tosa_debug_info = TosaDebugSchema(
131126
node_name=str(tosa_op),
132-
operator_name=self.__op_id_to_name[tosa_op_id],
133-
operator_id=tosa_op_id,
127+
operator_name=str(tosa_op_id),
128+
operator_id=int(tosa_op_id),
134129
)
135130

136131
aten_debug_info = ATenDebugSchema.from_node(node)

backends/arm/ethosu/backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ def _compile_tosa_flatbuffer(
5151
"compile_flags are required in the CompileSpec list for EthosUBackend"
5252
)
5353

54+
# Vela tooling only supports flatbuffers up to 2 GiB.
55+
max_flatbuffer_size = 2 * 1024 * 1024 * 1024
56+
flatbuffer_size = len(tosa_flatbuffer)
57+
if flatbuffer_size > max_flatbuffer_size:
58+
raise RuntimeError(
59+
"TOSA flatbuffer is too large for Vela "
60+
f"({flatbuffer_size} bytes > {max_flatbuffer_size} bytes limit)."
61+
)
62+
5463
# Pass on the TOSA flatbuffer to the vela compiler.
5564
binary = vela_compile(
5665
tosa_flatbuffer,

0 commit comments

Comments
 (0)