Skip to content

Commit 9075855

Browse files
authored
Arm backend: support mean.default (#15363)
Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent f7ca57e commit 9075855

File tree

4 files changed

+57
-8
lines changed

4 files changed

+57
-8
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919

2020

2121
def get_meandim_decomposition(op) -> tuple:
22-
if op == exir_ops.edge.aten.mean.dim:
22+
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
2323
return (
2424
exir_ops.edge.aten.sum.dim_IntList,
2525
exir_ops.edge.aten.full.default,
2626
exir_ops.edge.aten.mul.Tensor,
2727
)
28-
if op == torch.ops.aten.mean.dim:
28+
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
2929
return (
3030
torch.ops.aten.sum.dim_IntList,
3131
torch.ops.aten.full.default,
@@ -35,17 +35,17 @@ def get_meandim_decomposition(op) -> tuple:
3535

3636

3737
def get_avgpool(op):
38-
if op == exir_ops.edge.aten.mean.dim:
38+
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
3939
return exir_ops.edge.aten.avg_pool2d.default
40-
if op == torch.ops.aten.mean.dim:
40+
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
4141
return torch.ops.aten.avg_pool2d.default
4242
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
4343

4444

4545
def get_view(op):
46-
if op == exir_ops.edge.aten.mean.dim:
46+
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
4747
return exir_ops.edge.aten.view_copy.default
48-
if op == torch.ops.aten.mean.dim:
48+
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
4949
return torch.ops.aten.view_copy.default
5050
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5151

@@ -87,13 +87,18 @@ def __init__(self, graph_module, tosa_spec):
8787
)
8888

8989
def call_operator(self, op, args, kwargs, meta):
90-
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
90+
if op not in (
91+
exir_ops.edge.aten.mean.dim,
92+
torch.ops.aten.mean.dim,
93+
exir_ops.edge.aten.mean.default,
94+
torch.ops.aten.mean.default,
95+
):
9196
return super().call_operator(op, args, kwargs, meta)
9297

9398
x = get_node_arg(args, 0)
9499
input_shape = list(x.data.shape)
95100
output_shape = list(meta["val"].shape)
96-
dims_to_reduce = get_node_arg(args, 1)
101+
dims_to_reduce = get_node_arg(args, 1, range(len(input_shape)))
97102
if dims_to_reduce is None:
98103
dims_to_reduce = range(len(input_shape))
99104
dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@
178178
exir_ops.edge.aten.native_group_norm.default,
179179
exir_ops.edge.aten.sigmoid.default,
180180
exir_ops.edge.aten.mean.dim,
181+
exir_ops.edge.aten.mean.default,
181182
exir_ops.edge.aten.mm.default,
182183
exir_ops.edge.aten.minimum.default,
183184
exir_ops.edge.aten.maximum.default,

backends/arm/scripts/parse_test_names.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"hardswish.default",
1515
"linear.default",
1616
"maximum.default",
17+
"mean.default",
1718
"multihead_attention.default",
1819
"adaptive_avg_pool2d.default",
1920
"bitwise_right_shift.Tensor",

backends/arm/test/ops/test_mean_dim.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
7+
from typing import Callable
8+
79
import torch
810
from executorch.backends.arm.test import common
911
from executorch.backends.arm.test.tester.test_pipeline import (
@@ -344,3 +346,43 @@ def test_mean_dim_vgf_INT(test_data):
344346
tosa_version="TOSA-1.0+INT",
345347
)
346348
pipeline.run()
349+
350+
351+
mean_input_t = tuple[torch.Tensor, bool]
352+
353+
354+
class MeanDefault(torch.nn.Module):
355+
def forward(self, tensor: torch.Tensor, keepdim: bool):
356+
return tensor.mean()
357+
358+
test_data_suite: dict[str, Callable[[], mean_input_t]] = {
359+
"rank1": lambda: (
360+
torch.rand(
361+
1,
362+
),
363+
False,
364+
),
365+
"rank2": lambda: (torch.rand(5, 5), True),
366+
"rank4": lambda: (torch.rand(5, 1, 10, 1), False),
367+
}
368+
369+
370+
@common.parametrize("test_data", MeanDefault.test_data_suite)
371+
def test_mean_tosa_FP(test_data):
372+
pipeline = TosaPipelineFP[mean_input_t](
373+
MeanDefault(),
374+
test_data(),
375+
[], # Might be sum, avgpool, or both
376+
)
377+
pipeline.run()
378+
379+
380+
@common.parametrize("test_data", MeanDefault.test_data_suite)
381+
def test_mean_tosa_INT(test_data):
382+
pipeline = TosaPipelineINT[mean_input_t](
383+
MeanDefault(),
384+
test_data(),
385+
[], # Might be sum, avgpool, or both
386+
symmetric_io_quantization=True,
387+
)
388+
pipeline.run()

0 commit comments

Comments
 (0)