Skip to content

Commit e38734e

Browse files
authored
Arm backend: Add support for int32 clamp (#15977)
TOSA does not support the int32 dtype for the clamp operator, so int32 clamp is instead implemented with Min/Max. Change-Id: Iea442901b2227610ebb5e5a0f1bca8d236e70d9d cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai --------- Signed-off-by: Ryan O'Shea <ryan.oshea3@arm.com>
1 parent fe1bc8a commit e38734e

File tree

6 files changed

+160
-7
lines changed

6 files changed

+160
-7
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from .decompose_int16_activation_conv2d_pass import ( # noqa
5353
DecomposeConv2dWithInt16ActivationPass,
5454
)
55+
from .decompose_int32_clamp_pass import DecomposeInt32ClampPass # noqa
5556
from .decompose_int_pow_pass import DecomposeIntPowPass # noqa
5657
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
5758
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
DecomposeGluPass,
5656
DecomposeGroupedConvPass,
5757
DecomposeGroupNormPass,
58+
DecomposeInt32ClampPass,
5859
DecomposeIntPowPass,
5960
DecomposeLayerNormPass,
6061
DecomposeLeakyReLUPass,
@@ -122,7 +123,6 @@
122123

123124

124125
class ArmPassManager(PassManager):
125-
126126
def __init__(self, tosa_spec: TosaSpecification) -> None:
127127
self.tosa_spec = tosa_spec
128128
super().__init__()
@@ -174,6 +174,7 @@ def _tosa_pipeline(
174174
FuseQuantizedActivationPass(),
175175
RemoveGetItemPass(),
176176
ConvertToClampPass(),
177+
DecomposeInt32ClampPass(),
177178
DecomposeGroupNormPass(),
178179
DecomposeLayerNormPass(),
179180
DecomposeBatchNormNoStatsPass(),
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass
12+
13+
14+
class DecomposeInt32ClampPass(ArmPass):
    """Rewrite int32 clamp into a max/min chain since TOSA lacks int32 clamp support.

    A clamp whose output dtype is int32 is replaced by
    ``minimum(maximum(x, min), max)``. A bound that is ``None`` is skipped, so
    a one-sided clamp becomes a single ``maximum`` or ``minimum``. Every other
    operator, and clamp on any other dtype, is forwarded unchanged.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()
    _supported_ops = {
        exir_ops.edge.aten.clamp.default,
        torch.ops.aten.clamp.default,
    }

    def _ensure_tensor(
        self,
        value,
        ref_tensor,
        dtype,
        rank,
        meta,
    ):
        """Materialize a scalar clamp bound as a ``full`` node.

        Args:
            value: The scalar bound, or ``None`` when the bound is absent.
            ref_tensor: The clamp input. Currently unused; kept so existing
                call sites keep working.
            dtype: Dtype passed to ``full``.
            rank: Number of dimensions of the clamp output.
            meta: Node metadata forwarded to the created ``full`` node.

        Returns:
            ``None`` if ``value`` is ``None``, otherwise the result of a
            ``full`` call with shape ``(1,) * rank`` filled with ``value``.
        """
        if value is None:
            return None
        # All-ones shape with the same rank as the clamp output; presumably
        # chosen so the bound broadcasts against the clamped tensor.
        return super().call_operator(
            exir_ops.edge.aten.full.default,
            ((1,) * rank, value),
            {"dtype": dtype},
            meta,
            updated=True,
        )

    def call_operator(self, op, args, kwargs, meta):
        """Replace an int32 clamp by maximum/minimum; forward everything else.

        Args:
            op: The operator being visited.
            args: Positional arguments of the call; for clamp these are
                ``(input, min?, max?)``.
            kwargs: Keyword arguments of the call.
            meta: Node metadata; ``meta["val"]`` supplies output dtype and shape.

        Returns:
            The unchanged call for unsupported ops or non-int32 outputs,
            otherwise the last node of the max/min chain (or the input itself
            if neither bound is given).
        """
        # Check the op first so metadata is only inspected for clamp nodes.
        if op not in self._supported_ops:
            return super().call_operator(op, args, kwargs, meta)

        val = meta["val"]
        if val.dtype != torch.int32:
            return super().call_operator(op, args, kwargs, meta)

        input_tensor = args[0]
        min_arg = args[1] if len(args) > 1 else None
        max_arg = args[2] if len(args) > 2 else None
        dtype = val.dtype
        rank = len(val.shape)

        min_arg = self._ensure_tensor(min_arg, input_tensor, dtype, rank, meta)
        max_arg = self._ensure_tensor(max_arg, input_tensor, dtype, rank, meta)

        # torch.clamp is defined as min(max(x, min), max): the lower bound is
        # applied first so that, when min > max, every element becomes max.
        # Applying the upper bound first would yield min in that case.
        current = input_tensor
        if min_arg is not None:
            current = super().call_operator(
                exir_ops.edge.aten.maximum.default,
                (current, min_arg),
                {},
                meta,
                updated=True,
            )
        if max_arg is not None:
            current = super().call_operator(
                exir_ops.edge.aten.minimum.default,
                (current, max_arg),
                {},
                meta,
                updated=True,
            )
        return current

backends/arm/operators/op_clamp.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(self, *args):
4040
def _get_min_max_arguments(
4141
self, node: Node, dtype: torch.dtype
4242
) -> Tuple[int | float, int | float]:
43-
4443
def cast_type(value: Any) -> int | float:
4544
if isinstance(value, int):
4645
return value
@@ -91,7 +90,12 @@ def define_node(
9190
validate_valid_dtype(
9291
self.target,
9392
[inputs[0], output],
94-
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP16, ts.DType.FP32],
93+
[
94+
ts.DType.INT8,
95+
ts.DType.INT16,
96+
ts.DType.FP16,
97+
ts.DType.FP32,
98+
],
9599
output.tosa_spec,
96100
)
97101

backends/arm/test/ops/test_clamp.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,25 @@
3535
"rank_4_no_max": lambda: (torch.rand(1, 10, 10, 1) - 3, -3.3, None),
3636
}
3737

38+
# Int32 clamp cases. Each entry is a factory returning
# (input tensor, min bound, max bound); None marks a missing bound
# (the *_no_min / *_no_max cases).
test_data_suite_int32 = dict(
    int32_rank2=lambda: (
        torch.randint(low=-50, high=50, size=(2, 3), dtype=torch.int32),
        -10,
        10,
    ),
    int32_rank3_no_min=lambda: (
        torch.randint(low=-100, high=100, size=(1, 3, 3), dtype=torch.int32),
        None,
        25,
    ),
    int32_rank3_no_max=lambda: (
        torch.randint(low=-100, high=100, size=(1, 3, 3), dtype=torch.int32),
        -25,
        None,
    ),
    # Bounds span the full int32 range.
    int32_rank4_large_range=lambda: (
        torch.randint(low=-200, high=200, size=(1, 2, 4, 4), dtype=torch.int32),
        torch.iinfo(torch.int32).min,
        torch.iinfo(torch.int32).max,
    ),
)
56+
3857

3958
class Clamp(torch.nn.Module):
4059
def __init__(
@@ -53,7 +72,6 @@ def forward(self, x):
5372

5473
@common.parametrize("test_data", test_data_suite)
5574
def test_clamp_tosa_FP(test_data):
56-
5775
input_tensor, min_val, max_val = test_data()
5876
model = Clamp(min_val, max_val)
5977

@@ -69,7 +87,6 @@ def test_clamp_tosa_FP(test_data):
6987

7088
@common.parametrize("test_data", test_data_suite)
7189
def test_clamp_tosa_INT(test_data):
72-
7390
input_tensor, min_val, max_val = test_data()
7491
model = Clamp(min_val, max_val)
7592

@@ -84,6 +101,22 @@ def test_clamp_tosa_INT(test_data):
84101
pipeline.run()
85102

86103

104+
@common.parametrize("test_data", test_data_suite_int32)
def test_clamp_tosa_INT_int32_inputs(test_data):
    """Run the int32 clamp cases through the TOSA INT pipeline, unquantized."""
    x, lower, upper = test_data()
    clamp_module = Clamp(lower, upper)

    tosa_int = TosaPipelineINT[input_t](clamp_module, (x,), aten_op, exir_op)
    tosa_int.change_args("run_method_and_compare_outputs", qtol=1)
    # The inputs are already int32, so the "quantize" stage is dropped
    # before the pipeline runs.
    tosa_int.pop_stage("quantize")
    tosa_int.run()
118+
119+
87120
@common.parametrize("test_data", test_data_suite)
88121
def test_clamp_tosa_INT_a16w8(test_data):
89122
"""Test clamp operation with int16 I/O quantization for TOSA INT."""
@@ -103,7 +136,6 @@ def test_clamp_tosa_INT_a16w8(test_data):
103136
@common.parametrize("test_data", test_data_suite)
104137
@common.XfailIfNoCorstone300
105138
def test_clamp_u55_INT(test_data):
106-
107139
input_tensor, min_val, max_val = test_data()
108140
model = Clamp(min_val, max_val)
109141

@@ -140,7 +172,6 @@ def test_clamp_16a8w_u55_INT16(test_data):
140172
@common.parametrize("test_data", test_data_suite)
141173
@common.XfailIfNoCorstone320
142174
def test_clamp_u85_INT(test_data):
143-
144175
input_tensor, min_val, max_val = test_data()
145176
model = Clamp(min_val, max_val)
146177

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes.decompose_int32_clamp_pass import (
10+
DecomposeInt32ClampPass,
11+
)
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
14+
15+
input_t = Tuple[torch.Tensor]
16+
17+
18+
class ClampInt32(torch.nn.Module):
    """Module that clamps its input to the closed range [-10, 5]."""

    # Single parametrized case: a 2x3 int32 tensor drawn from [-50, 50).
    test_data = dict(
        rand=(torch.randint(low=-50, high=50, size=(2, 3), dtype=torch.int32),),
    )

    def forward(self, x: torch.Tensor):
        lower_bound, upper_bound = -10, 5
        return torch.clamp(x, lower_bound, upper_bound)
23+
24+
25+
@common.parametrize("test_data", ClampInt32.test_data)
def test_decompose_int32_clamp_pass(test_data: input_t):
    """Check that the pass swaps one edge clamp for one minimum and one maximum."""
    clamp_op = "executorch_exir_dialects_edge__ops_aten_clamp_default"
    expected_before = {clamp_op: 1}
    expected_after = {
        "executorch_exir_dialects_edge__ops_aten_minimum_default": 1,
        "executorch_exir_dialects_edge__ops_aten_maximum_default": 1,
    }

    pass_pipeline = PassPipeline[input_t](
        ClampInt32(),
        test_data,
        quantize=False,
        ops_before_pass=expected_before,
        ops_after_pass=expected_after,
        # No clamp op may remain once the pass has run.
        ops_not_after_pass=[clamp_op],
        pass_list=[DecomposeInt32ClampPass],
    )
    pass_pipeline.run()

0 commit comments

Comments
 (0)