
Commit d975528

Revert "Arm backend: Move rescales from ADD & SUB visitors to pass (pytorch#15378)"
This reverts commit 4efd79c.
1 parent 4250b49 commit d975528

File tree

7 files changed: +570 -82 lines changed


backends/arm/_passes/insert_rescales_pass.py

Lines changed: 6 additions & 37 deletions
@@ -76,12 +76,13 @@ def call(self, graph_module: GraphModule) -> PassResult:


 class InsertRescaleInt32Pass(ArmPass):
-    """Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
+    """
+    Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
     quantized implementations. This pass treats such operator nodes by
-    inserting rescale ops before and after them if needed. Note that extra
-    logic that handles the scales and zero points are in place here because the
-    affected TOSA ops have naive implementations that do not account for the
-    quantization parameters.
+    inserting rescale ops before and after them if needed. Note that extra logic
+    that handles the scales and zero points must be in place because the affected
+    TOSA have naive implementations that do not account for the quantization
+    parameters.
     """

     # SUM must be decomposed after this pass to prevent insertion of RESCALE
@@ -92,7 +93,6 @@ class InsertRescaleInt32Pass(ArmPass):

     included_targets = [
         exir_ops.edge.aten.abs.default,
-        exir_ops.edge.aten.add.Tensor,
         exir_ops.edge.aten.eq.Tensor,
         exir_ops.edge.aten.ge.Tensor,
         exir_ops.edge.aten.gt.Tensor,
@@ -101,7 +101,6 @@ class InsertRescaleInt32Pass(ArmPass):
         exir_ops.edge.aten.maximum.default,
         exir_ops.edge.aten.minimum.default,
         exir_ops.edge.aten.mul.Tensor,
-        exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
     ]

@@ -143,34 +142,6 @@ def _get_inputs_rescaled_qparams(
             qparams = {
                 i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
             }
-        elif target in [
-            exir_ops.edge.aten.add.Tensor,
-            exir_ops.edge.aten.sub.Tensor,
-        ]:
-            if input_qparams[0].dtype != input_qparams[1].dtype:
-                raise ValueError(
-                    "Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}"
-                )
-
-            # We are handling two INT8 or two INT16 numbers. For INT8, if the
-            # zero point is non-null, the result will be in the range [-255;
-            # 255], therefore we need 9 bits for the result. We have a 32-bit
-            # accumulator, so we can divide the scale by (1 << 20) which is
-            # equivalent to shifting the INT8 operands 20 bits to the left
-            # before rescaling them both to 2 * max(lhs, rhs).
-            #
-            # For INT16, similary logic can be applied, but we instead end up
-            # with a left shift of 12.
-            lhs_scale, rhs_scale = (
-                qp.get_scale_per_tensor() for qp in input_qparams.values()
-            )
-            max_scale_2x = 2 * max(lhs_scale, rhs_scale)
-
-            # Select shift based on input dtype.
-            shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20
-
-            scale = max_scale_2x / (1 << shift_bits)
-            qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))}
         elif target in [
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten.sum.dim_IntList,
@@ -197,8 +168,6 @@ def _get_output_qparams(
             exir_ops.edge.aten.maximum.default,
             exir_ops.edge.aten.minimum.default,
             exir_ops.edge.aten.sum.dim_IntList,
-            exir_ops.edge.aten.add.Tensor,
-            exir_ops.edge.aten.sub.Tensor,
-        ]:
+        ]:
             # The op has not altered the scale; the output scale is equal to
             # the operands' scales.
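
The block removed above from _get_inputs_rescaled_qparams carried the reasoning for how ADD/SUB operands get a shared INT32 scale. As a minimal standalone sketch of that scale selection (shared_int32_scale is an illustrative name, not part of the ExecuTorch API):

def shared_int32_scale(lhs_scale: float, rhs_scale: float, is_int16: bool = False) -> float:
    # An int8 sum needs at most 9 bits, so the operands can be shifted 20 bits
    # left and still fit in the 32-bit accumulator; int16 inputs only leave
    # room for a 12-bit shift. Dividing the scale by (1 << shift) is
    # equivalent to that left shift.
    shift_bits = 12 if is_int16 else 20
    max_scale_2x = 2 * max(lhs_scale, rhs_scale)
    return max_scale_2x / (1 << shift_bits)

# Two int8 inputs with scales 0.05 and 0.1 share an int32 scale of
# 2 * 0.1 / 2**20, roughly 1.9e-07, i.e. both operands are effectively
# shifted left by about 20 bits before the integer ADD/SUB.
print(shared_int32_scale(0.05, 0.1))

With this revert, that computation moves back into the ADD and SUB node visitors below, via insert_rescale_ops_to_int32_maxscale and its INT16 variant.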

backends/arm/operators/op_add.py

Lines changed: 103 additions & 7 deletions
@@ -6,6 +6,8 @@

 from typing import Any, List

+import executorch.backends.arm.tosa.quant_utils as tqutils
+import executorch.backends.arm.tosa.utils as tutils
 import tosa_serializer as ts

 from executorch.backends.arm.operators.node_visitor import (
@@ -17,20 +19,22 @@
     validate_same_dtype,
     validate_valid_dtype,
 )
+from executorch.backends.arm.tosa import TosaSpecification
 from executorch.backends.arm.tosa.mapping import TosaArg
-from executorch.backends.arm.tosa.specification import TosaSpecification
 from torch.fx import Node


 @register_node_visitor
-class AddVisitor(NodeVisitor):
+class AddVisitor_INT(NodeVisitor):
     target = "aten.add.Tensor"

     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
-        TosaSpecification.create_from_string("TOSA-1.0+FP"),
     ]

+    def __init__(self, *args):
+        super().__init__(*args)
+
     def define_node(
         self,
         node: Node,
@@ -40,21 +44,113 @@ def define_node(
     ) -> None:
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [*inputs, output], ts)
+        valid_dtypes = []
+        if self.tosa_spec.support_integer():
+            valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
+        if self.tosa_spec.support_float():
+            valid_dtypes.extend([ts.DType.INT32])
+
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT32, ts.DType.FP32],
+            valid_dtypes,
             output.tosa_spec,
         )
+        scale_back = 1.0
+        if inputs[0].dtype == ts.DType.INT8:
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
+                tosa_graph, inputs, node, self.tosa_spec
+            )
+        elif inputs[0].dtype == ts.DType.INT16:
+            rescaled_inputs, scale_back = (
+                tqutils.insert_rescale_ops_int16_to_int32_maxscale(
+                    tosa_graph, inputs, node, self.tosa_spec
+                )
+            )
+        else:
+            # input[0].dtype == ts.DType.INT16 or ts.DType.INT32
+            # Non quantized input, natively support by TOSA.ADD
+            rescaled_inputs = inputs
+
+        if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
+            broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
+            add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
+        else:
+            # output.dtype == ts.DType.INT16 or ts.DType.INT32
+            add_output = output

+        input1, input2 = rescaled_inputs
         attr = ts.TosaSerializerAttribute()
         attr.AddAttribute()
-
+        # Do the INT32 Add
         self._serialize_operator(
             node,
             tosa_graph,
             ts.Op.ADD,
-            [inputs[0].name, inputs[1].name],
-            [output.name],
+            [input1.name, input2.name],
+            [add_output.name],
             attr,
         )
+
+        if output.dtype == ts.DType.INT8:
+            # Scale output back to 8 bit
+            # pyre-ignore
+            tqutils.insert_rescale_op_to_int8(
+                tosa_graph,
+                add_output,
+                scale_back,
+                node,
+                compute_rescale=False,
+                tosa_spec=self.tosa_spec,
+            )  # type: ignore[possibly-undefined]
+        elif output.dtype == ts.DType.INT16:
+            tqutils.insert_rescale_op_to_int16(
+                tosa_graph,
+                add_output,
+                scale_back,
+                node,
+                compute_rescale=False,
+                tosa_spec=self.tosa_spec,
+            )  # type: ignore[possibly-undefined]
+
+
+@register_node_visitor
+class AddVisitor_FP(AddVisitor_INT):
+    # inheriting 'target' from INT class
+
+    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        validate_num_inputs(self.target, inputs, 2)
+        validate_same_dtype(self.target, [*inputs, output], ts)
+
+        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]:
+            # Call the inherited define_node for handling integers
+            super().define_node(node, tosa_graph, inputs, output)
+        else:
+            # FP32 Add lowering
+            validate_valid_dtype(
+                self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
+            )
+
+            input1, input2 = inputs
+            attr = ts.TosaSerializerAttribute()
+            attr.AddAttribute()
+            # FP lowering
+            self._serialize_operator(
+                node,
+                tosa_graph,
+                ts.Op.ADD,
+                [input1.name, input2.name],
+                [output.name],
+                attr,
+            )
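
Putting the integer path above into plain arithmetic: the visitor rescales both quantized inputs to the shared INT32 scale, performs a single INT32 ADD, and rescales the accumulator to the output quantization parameters. A rough pure-Python illustration (not the tqutils helpers; saturation/clamping omitted), using real_value = scale * (q - zero_point):

def quantized_add_int8(qa, za, sa, qb, zb, sb, zo, so):
    s32 = 2 * max(sa, sb) / (1 << 20)    # shared int32 scale, as above
    a32 = round((qa - za) * (sa / s32))  # RESCALE input A to int32
    b32 = round((qb - zb) * (sb / s32))  # RESCALE input B to int32
    acc = a32 + b32                      # the INT32 ADD
    return round(acc * (s32 / so)) + zo  # RESCALE the sum back to int8

# 2.0 (q=40, zp=0, scale=0.05) + 2.0 (q=25, zp=5, scale=0.1), output scale 0.1:
print(quantized_add_int8(40, 0, 0.05, 25, 5, 0.1, 0, 0.1))  # -> 40, i.e. 4.0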

backends/arm/operators/op_sub.py

Lines changed: 97 additions & 8 deletions
@@ -6,6 +6,8 @@

 from typing import Any, List

+import executorch.backends.arm.tosa.quant_utils as tqutils
+import executorch.backends.arm.tosa.utils as tutils
 import tosa_serializer as ts

 from executorch.backends.arm.operators.node_visitor import (
@@ -17,20 +19,22 @@
     validate_same_dtype,
     validate_valid_dtype,
 )
+from executorch.backends.arm.tosa import TosaSpecification
 from executorch.backends.arm.tosa.mapping import TosaArg
-from executorch.backends.arm.tosa.specification import TosaSpecification
 from torch.fx import Node


 @register_node_visitor
-class SubVisitor(NodeVisitor):
+class SubVisitor_INT(NodeVisitor):
     target = "aten.sub.Tensor"

     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
-        TosaSpecification.create_from_string("TOSA-1.0+FP"),
     ]

+    def __init__(self, *args):
+        super().__init__(*args)
+
     def define_node(
         self,
         node: Node,
@@ -43,21 +47,106 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT32, ts.DType.FP32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
             output.tosa_spec,
         )

+        scale_back = 1.0
+        if inputs[0].dtype == ts.DType.INT8:
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
+                tosa_graph, inputs, node, self.tosa_spec
+            )
+        elif inputs[0].dtype == ts.DType.INT16:
+            rescaled_inputs, scale_back = (
+                tqutils.insert_rescale_ops_int16_to_int32_maxscale(
+                    tosa_graph, inputs, node, self.tosa_spec
+                )
+            )
+        else:
+            # input[0].dtype == ts.DType.INT32
+            # Non quantized input, natively support by TOSA.SUB
+            rescaled_inputs = inputs
+
+        if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
+            broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
+            sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
+        else:
+            # output.dtype == ts.DType.INT32
+            sub_output = output
+
+        # Do the INT32 Sub
         attr = ts.TosaSerializerAttribute()
         attr.SubAttribute()
-
         self._serialize_operator(
             node,
             tosa_graph,
             ts.Op.SUB,
             [
-                inputs[0].name,
-                inputs[1].name,
+                rescaled_inputs[0].name,
+                rescaled_inputs[1].name,
             ],
-            [output.name],
+            [sub_output.name],
             attr,
         )
+
+        if output.dtype == ts.DType.INT8:
+            # Scale output back to 8 bit
+            # pyre-ignore
+            tqutils.insert_rescale_op_to_int8(
+                tosa_graph,
+                sub_output,
+                scale_back,
+                node,
+                compute_rescale=False,
+                tosa_spec=self.tosa_spec,
+            )  # type: ignore[possibly-undefined]
+        elif output.dtype == ts.DType.INT16:
+            tqutils.insert_rescale_op_to_int16(
+                tosa_graph,
+                sub_output,
+                scale_back,
+                node,
+                compute_rescale=False,
+                tosa_spec=self.tosa_spec,
+            )  # type: ignore[possibly-undefined]
+
+
+@register_node_visitor
+class SubVisitor_FP(SubVisitor_INT):
+    # inheriting 'target' from INT class
+
+    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        validate_num_inputs(self.target, inputs, 2)
+        validate_same_dtype(self.target, [*inputs, output], ts)
+
+        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
+            # Call the inherited define_node for handling integers
+            super().define_node(node, tosa_graph, inputs, output)
+        else:
+            # FP32 Sub lowering
+            validate_valid_dtype(
+                self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
+            )
+
+            # MI lowering
+            attr = ts.TosaSerializerAttribute()
+            attr.SubAttribute()
+            self._serialize_operator(
+                node,
+                tosa_graph,
+                ts.Op.SUB,
+                [inputs[0].name, inputs[1].name],
+                [output.name],
+                attr,
+            )

backends/arm/test/misc/test_conv_relu_residual_add.py

Lines changed: 0 additions & 7 deletions
@@ -76,13 +76,6 @@ def test_tosa_INT(per_channel_quantization):
     pipeline.run()


-# TODO: Xfail until the Ethos-U Vela compiler ships commit
-# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that
-# causes this test to fail.
-@pytest.mark.xfail(
-    reason=("Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f"),
-    strict=True,
-)
 @pytest.mark.slow
 @common.XfailIfNoCorstone300
 @common.parametrize("per_channel_quantization", quant_test_data)
