Skip to content

Commit c67e327

Browse files
authored
Merge branch 'main' into tag_submodule
2 parents 6917f85 + f7ca57e commit c67e327

14 files changed

+110
-94
lines changed

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,14 @@ def call_operator(self, op, args, kwargs, meta):
105105

106106
conv_output = super().call_operator(
107107
exir_ops.backend.tosa.RESCALE.default,
108-
(convolution, torch.int32, conv_rescale_factor, 0, 0),
108+
(convolution, torch.int32, [conv_rescale_factor], 0, 0),
109109
{},
110110
new_meta,
111111
)
112112

113113
bias_rescaled = super().call_operator(
114114
exir_ops.backend.tosa.RESCALE.default,
115-
(channel_bias, torch.int32, bias_rescale_factor, 0, 0),
115+
(channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
116116
{},
117117
new_meta,
118118
)
@@ -129,7 +129,7 @@ def call_operator(self, op, args, kwargs, meta):
129129
(
130130
add,
131131
output_dtype,
132-
(common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
132+
[(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
133133
0,
134134
0,
135135
),

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule
4545
(
4646
node.all_input_nodes[0],
4747
q_args.dtype,
48-
new_scale,
48+
[new_scale],
4949
dq_args.zp,
5050
q_args.zp,
5151
),
@@ -228,10 +228,10 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b
228228
(
229229
arg_node,
230230
torch.int32,
231-
qp.get_scale_per_tensor()
232-
/ rescale_qargs[
233-
i
234-
].get_scale_per_tensor(), # Old scale / new scale
231+
[
232+
qp.get_scale_per_tensor()
233+
/ rescale_qargs[i].get_scale_per_tensor()
234+
], # [Old scale / new scale]
235235
qp.get_zp_per_tensor(), # Old zero point
236236
rescale_qargs[i].get_zp_per_tensor(), # New zero point
237237
),
@@ -264,8 +264,10 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
264264
(
265265
node,
266266
qarg.dtype,
267-
rescale_qargs.get_scale_per_tensor()
268-
/ qarg.get_scale_per_tensor(), # Old scale / new scale
267+
[
268+
rescale_qargs.get_scale_per_tensor()
269+
/ qarg.get_scale_per_tensor()
270+
], # [Old scale / new scale]
269271
rescale_qargs.get_zp_per_tensor(), # Old zero point
270272
qarg.get_zp_per_tensor(), # New zero point
271273
),

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
286286
rescale_node = create_node(
287287
graph=graph_module.graph,
288288
op_target=exir_ops.backend.tosa.RESCALE.default,
289-
args=(table_op_node, output_qparams[0].dtype, scale, 0, 0),
289+
args=(table_op_node, output_qparams[0].dtype, [scale], 0, 0),
290290
)
291291
output_node = rescale_node
292292

backends/arm/_passes/rewrite_conv2d_pass.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7+
import itertools
78
from typing import Set, Type
89

910
import torch
@@ -16,6 +17,10 @@
1617
is_buffer,
1718
is_param,
1819
)
20+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
21+
get_input_qparams,
22+
get_output_qparams,
23+
)
1924
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
2025
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2126
from executorch.backends.transforms.utils import create_constant_placeholder
@@ -156,6 +161,40 @@ def _add_bias(
156161
node.update_arg(2, bias_node)
157162
return bias_node
158163

164+
def insert_output_rescale(self, graph_module, node):
165+
input_qparams = get_input_qparams(node)
166+
output_qparams = get_output_qparams(node)[0]
167+
weight_qparams = input_qparams[1]
168+
input_qparams = input_qparams[0]
169+
is_per_channel = weight_qparams.per_channel
170+
if is_per_channel:
171+
weight_scale = weight_qparams.get_scale_per_channel()
172+
else:
173+
weight_scale = [weight_qparams.get_scale_per_tensor()]
174+
input_scale = input_qparams.get_scale_per_tensor()
175+
post_conv2d_scale = [
176+
(inp * w) / out
177+
for inp, w, out in zip(
178+
itertools.cycle([input_scale]),
179+
weight_scale,
180+
itertools.cycle([output_qparams.get_scale_per_tensor()]),
181+
)
182+
]
183+
with graph_module.graph.inserting_after(node):
184+
rescale_node = create_node(
185+
graph=graph_module.graph,
186+
op_target=exir_ops.backend.tosa.RESCALE.default,
187+
args=(
188+
node,
189+
output_qparams.dtype,
190+
post_conv2d_scale,
191+
0,
192+
output_qparams.get_zp_per_tensor(),
193+
),
194+
from_node=node,
195+
)
196+
return rescale_node
197+
159198
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
160199
modified = False
161200
for node in graph_module.graph.nodes:
@@ -180,20 +219,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
180219
) = node.args
181220

182221
pad = [val for val in pad for _ in (0, 1)]
183-
input_shape = get_first_fake_tensor(x).shape
184-
weight_shape = get_first_fake_tensor(weight).shape
222+
input_fake_tensor = get_first_fake_tensor(x)
223+
weight_fake_tensor = get_first_fake_tensor(weight)
185224
# Adjust the pad value if needed to meet the
186225
# strict convolution output shape calculation.
187226
pad[1] = self._adjust_pad_if_needed(
188-
input_shape[2],
189-
weight_shape[2],
227+
input_fake_tensor.shape[2],
228+
weight_fake_tensor.shape[2],
190229
stride[0],
191230
pad[1],
192231
dilation[0],
193232
)
194233
pad[3] = self._adjust_pad_if_needed(
195-
input_shape[3],
196-
weight_shape[3],
234+
input_fake_tensor.shape[3],
235+
weight_fake_tensor.shape[3],
197236
stride[1],
198237
pad[3],
199238
dilation[1],
@@ -204,7 +243,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
204243

205244
if self._is_depthwise_conv2d(node):
206245
target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
207-
self._reshape_weights(weight, input_shape[1])
246+
self._reshape_weights(weight, input_fake_tensor.shape[1])
247+
weight_fake_tensor = get_first_fake_tensor(weight)
208248
else:
209249
target_op = exir_ops.backend.tosa.CONV2D.default
210250

@@ -227,9 +267,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
227267
args=conv2d_args,
228268
from_node=node,
229269
)
270+
bias_fake_tensor = get_first_fake_tensor(bias) if bias else None
271+
tosa_node_fake_tensor = target_op(
272+
input_fake_tensor,
273+
weight_fake_tensor,
274+
bias_fake_tensor,
275+
*conv2d_args[3:],
276+
)
230277

278+
if (
279+
tosa_node_fake_tensor.dtype == torch.int32
280+
and input_fake_tensor.dtype == torch.int8
281+
) or (
282+
tosa_node_fake_tensor.dtype == torch.int32
283+
and input_fake_tensor.dtype == torch.int16
284+
):
285+
output_rescale = self.insert_output_rescale(graph_module, tosa_op)
286+
node.replace_all_uses_with(output_rescale)
287+
if input_fake_tensor.dtype == torch.int16:
288+
tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
289+
else:
231290
node.replace_all_uses_with(tosa_op)
232-
graph_module.graph.erase_node(node)
291+
292+
graph_module.graph.erase_node(node)
233293

234294
if modified:
235295
graph_module.recompile()

backends/arm/_passes/rewrite_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype):
4444
rescale_node.args = (
4545
tosa_matmul_node,
4646
dtype,
47-
scale,
47+
[scale],
4848
0,
4949
output_qparams.get_zp_per_tensor(),
5050
)

backends/arm/_passes/rewrite_upsample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def call(self, graph_module):
7474
rescale_node.args = (
7575
tosa_resize_node,
7676
output_dtype,
77-
output_scale,
77+
[output_scale],
7878
0, # zero point
7979
0, # zero point
8080
)

backends/arm/operators/op_tosa_conv2d.py

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88

99
"""Provide a visitor for lowering 2D convolution to TOSA (INT/FP)."""
1010

11-
import itertools
1211
from typing import Any, List
1312

1413
import torch
1514

1615
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1716
get_input_qparams,
18-
get_output_qparams,
1917
)
2018
from executorch.backends.arm.operators.node_visitor import (
2119
NodeVisitor,
@@ -26,9 +24,7 @@
2624
validate_valid_dtype,
2725
)
2826
from executorch.backends.arm.tosa.mapping import TosaArg
29-
from executorch.backends.arm.tosa.quant_utils import build_rescale
3027
from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification
31-
from executorch.backends.arm.tosa.utils import tosa_shape
3228

3329

3430
@register_node_visitor
@@ -58,7 +54,8 @@ def define_node(
5854
inputs: List[TosaArg],
5955
output: TosaArg,
6056
) -> None:
61-
"""Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale."""
57+
"""Define the TOSA CONV2D/DEPTHWISE_CONV2D operator."""
58+
6259
input, weight, bias, stride, pad, dilation, _, _, group = inputs
6360
validate_num_inputs(self.target, inputs, 9)
6461

@@ -105,23 +102,8 @@ def define_node(
105102
input_qparams = get_input_qparams(node)
106103
weight_zp = input_qparams[1].zp # type: ignore[assignment]
107104

108-
# The output type is int32 when input type is int8.
109-
if inputs[0].dtype == ts.DType.INT8:
110-
conv2d_res = tosa_graph.addIntermediate(
111-
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
112-
)
113-
conv2d_output_name = conv2d_res.name
114-
acc_type = ts.DType.INT32
115-
elif inputs[0].dtype == ts.DType.INT16:
116-
conv2d_res = tosa_graph.addIntermediate(
117-
tosa_shape(output.shape, output.dim_order), ts.DType.INT48
118-
)
119-
conv2d_output_name = conv2d_res.name
120-
acc_type = ts.DType.INT48
121-
else:
122-
conv2d_output_name = output.name
123-
conv2d_res = output
124-
acc_type = ts.DType.FP32
105+
conv2d_output_name = output.name
106+
acc_type = output.dtype
125107

126108
tosa_graph.addConst(
127109
[1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
@@ -158,36 +140,3 @@ def define_node(
158140
[conv2d_output_name],
159141
attr,
160142
)
161-
162-
# For quantized convolution, rescale the output value back to the same
163-
# integer value domain of the next op. Otherwise return float32 output.
164-
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
165-
# Get scale_factor from input, weight, and output.
166-
input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61]
167-
per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61]
168-
if per_channel_quant:
169-
weight_scale = input_qparams[1].get_scale_per_channel()
170-
else:
171-
weight_scale = [
172-
input_qparams[1].get_scale_per_tensor()
173-
] # pyre-ignore [61]
174-
output_qargs = get_output_qparams(node)
175-
post_conv2d_scale = [
176-
(inp * w) / out
177-
for inp, w, out in zip(
178-
itertools.cycle([input_scale]),
179-
weight_scale,
180-
itertools.cycle([output_qargs[0].get_scale_per_tensor()]),
181-
)
182-
]
183-
build_rescale(
184-
tosa_fb=tosa_graph,
185-
scale=post_conv2d_scale,
186-
input_node=conv2d_res, # type: ignore[possibly-undefined]
187-
output_name=output.name,
188-
output_type=output.dtype,
189-
input_zp=[0],
190-
output_zp=[output_qargs[0].get_zp_per_tensor()],
191-
per_channel=per_channel_quant,
192-
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
193-
)

backends/arm/operators/op_tosa_depthwise_conv2d.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
8+
"""Provide a visitor for lowering 2D depthwise convolution to TOSA (INT/FP)."""
9+
710
import tosa_serializer as ts
11+
812
from executorch.backends.arm.operators.node_visitor import register_node_visitor
913
from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor
1014
from executorch.backends.arm.tosa import TosaSpecification

backends/arm/operators/op_tosa_rescale.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def define_node(
4141

4242
input_dtype = inputs[0].dtype
4343
output_dtype = cast(torch.dtype, node.args[1])
44-
scale = cast(float, node.args[2])
44+
scales = cast(list[float], node.args[2])
4545
input_zp = cast(int, node.args[3])
4646
output_zp = cast(int, node.args[4])
4747

@@ -63,12 +63,12 @@ def define_node(
6363

6464
build_rescale(
6565
tosa_graph,
66-
scale=[scale],
66+
scale=scales,
6767
input_node=inputs[0],
6868
output_name=output.name,
6969
output_type=output.dtype,
7070
input_zp=[input_zp],
7171
output_zp=[output_zp],
7272
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
73-
per_channel=False,
73+
per_channel=len(scales) > 1,
7474
)

backends/arm/test/misc/test_tosa_dialect_conv2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_conv2d_tosa_INT():
3131
4,
3232
),
3333
(1, 8, 20, 20),
34-
torch.int8,
34+
torch.int32,
3535
),
3636
(
3737
(
@@ -46,7 +46,7 @@ def test_conv2d_tosa_INT():
4646
4,
4747
),
4848
(1, 4, 10, 10),
49-
torch.int8,
49+
torch.int32,
5050
),
5151
]
5252

0 commit comments

Comments
 (0)