Skip to content
12 changes: 10 additions & 2 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
DecomposeSelectPass,
)
from executorch.backends.arm._passes.decompose_softmax_pass import DecomposeSoftmaxPass
from executorch.backends.arm._passes.decompose_softmax_unstable_pass import (
from executorch.backends.arm._passes.decompose_softmax_pass import ( # type: ignore[import-not-found]
DecomposeSoftmaxPass,
)
from executorch.backends.arm._passes.decompose_softmax_unstable_pass import ( # type: ignore[import-not-found]
DecomposeSoftmaxUnstablePass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
Expand Down Expand Up @@ -85,6 +87,10 @@
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform

from executorch.backends.transforms.replace_scalar_tensor_with_full import ( # type: ignore[import-not-found]
ReplaceScalarTensorWithFullPass,
)

from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
Expand Down Expand Up @@ -143,6 +149,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
return self._transform(exported_program.graph_module)

def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ReplaceScalarTensorWithFullPass())
self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
Expand Down Expand Up @@ -213,4 +220,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeSoftmaxPass())

self.add_pass(ConvertMinMaxPass())
self.add_pass(ReplaceScalarTensorWithFullPass())
return self._transform(graph_module)
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def is_node_supported(
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
torch.ops.aten.scalar_tensor.default,
]

return supported
Expand Down
1 change: 0 additions & 1 deletion backends/arm/test/models/test_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class TestConformer(unittest.TestCase):
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
"torch.ops.aten._assert_scalar.default": 10,
"torch.ops.aten._local_scalar_dense.default": 1,
"torch.ops.aten.scalar_tensor.default": 2,
"torch.ops.higher_order.executorch_call_delegate": 6,
}

Expand Down
137 changes: 137 additions & 0 deletions backends/arm/test/ops/test_scalar_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.xnnpack.test.tester.tester import Quantize
from parameterized import parameterized


float_test_data_suite = [
# (test_name, scalar input, scalar input type,)
(
"scalar_tensor_float_1",
3.7,
torch.float32,
),
(
"scalar_tensor_float_2",
66,
torch.float32,
),
]

int_test_data_suite = [
# (test_name, scalar input, scalar input type,)
(
"scalar_tensor_int32",
33,
torch.int32,
),
(
"scalar_tensor_int8",
8,
torch.int8,
),
(
"scalar_tensor_int16",
16 * 16 * 16,
torch.int16,
),
]


class ScalarTensor(torch.nn.Module):
def __init__(self, scalar, dtype=torch.float32):
super().__init__()
self.scalar = scalar
self.dtype = dtype

def forward(self):
return torch.scalar_tensor(self.scalar, dtype=self.dtype)


class TestScalarTensor(unittest.TestCase):

def _test_scalar_tensor_tosa_MI_pipeline(
self, module: torch.nn.Module, expected_output
):
test_outputs = []
in_data = ()

(
ArmTester(
module,
example_inputs=in_data,
compile_spec=common.get_tosa_compile_spec(
"TOSA-0.80+MI",
),
)
.export()
.check_count({"torch.ops.aten.scalar_tensor.default": 1})
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_get_output(test_outputs, inputs=in_data)
)
self._verify_output(test_outputs, expected_output)

def _test_scalar_tensor_tosa_BI_pipeline(
self, module: torch.nn.Module, expected_output
):
test_outputs = []
in_data = ()
tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")
compile_spec = common.get_tosa_compile_spec(tosa_spec)
quantizer = TOSAQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())

(
ArmTester(
module,
example_inputs=in_data,
compile_spec=compile_spec,
)
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
.check_count({"torch.ops.aten.full.default": 1}) # Already replaced
.to_edge_transform_and_lower()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_get_output(test_outputs, inputs=in_data)
)
self._verify_output(test_outputs, expected_output)

def _verify_output(self, test_outputs, expected_output):
out_data = torch.squeeze(test_outputs[0][0])
assert out_data == expected_output
assert out_data.dtype == expected_output.dtype

@parameterized.expand(int_test_data_suite + float_test_data_suite)
def test_scalar_tensor_tosa_MI( # Note TOSA MI supports all types
self, test_name: str, scalar_value, scalar_type
):
scalar = scalar_value
dtype = scalar_type
self._test_scalar_tensor_tosa_MI_pipeline(
ScalarTensor(scalar, dtype), torch.scalar_tensor(scalar, dtype=dtype)
)

@parameterized.expand(float_test_data_suite)
def test_scalar_tensor_tosa_BI(self, test_name: str, scalar_value, scalar_type):
scalar = scalar_value
dtype = scalar_type
self._test_scalar_tensor_tosa_BI_pipeline(
ScalarTensor(scalar, dtype), torch.scalar_tensor(scalar, dtype=dtype)
)
54 changes: 54 additions & 0 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,60 @@ def serialize(
def is_quantized(self) -> bool:
return self.stages[self.stage_name(tester.Quantize)] is not None

def run_method_and_get_output(
self,
test_outputs: List,
inputs: Optional[Tuple[torch.Tensor]] = None,
stage: Optional[str] = None,
num_runs=1,
):
"""
Returns the run_artifact output of 'stage'. This output is returned as parameter of type List.
Returns self to allow the function to be run in a test chain.

Args:
stage: (Optional[str]): The name of the stage to compare.
The default is the latest run stage.
test_output: All output results.
inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data.
The default is random data.
"""
edge_stage = self.stages[self.stage_name(tester.ToEdge)]
if edge_stage is None:
edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]
assert (
edge_stage is not None
), "To get outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."

stage = stage or self.cur
test_stage = self.stages[stage]

exported_program = self.stages[self.stage_name(tester.Export)].artifact
output_nodes = get_output_nodes(exported_program)
output_qparams = get_output_quantization_params(output_nodes)

quantization_scales = []
for node in output_qparams:
quantization_scales.append(getattr(output_qparams[node], "scale", None))

# Loop inputs and get outputs of the test stage.
for run_iteration in range(num_runs):
reference_input = inputs if inputs else next(self.generate_random_inputs())

input_shapes = [
generated_input.shape if hasattr(generated_input, "shape") else (1,)
for generated_input in reference_input
]
input_shape_str = ", ".join([str(list(i)) for i in input_shapes])
logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")

test_output, _ = pytree.tree_flatten(
test_stage.run_artifact(reference_input)
)
test_outputs.append(test_output)

return self

def run_method_and_compare_outputs(
self,
inputs: Optional[Tuple[torch.Tensor]] = None,
Expand Down
35 changes: 6 additions & 29 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
)
from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.replace_scalar_tensor_with_full import (
ReplaceScalarTensorWithFullPass,
)
from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
Expand Down Expand Up @@ -1723,35 +1726,9 @@ def call_operator(self, op, args, kwargs, meta):
register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceScalarTensorWithFullPass(ExportPass):
"""
aten.scalar_tensor can be replaced by aten.full with a shape of [1].
scalar_tensor is not supported, so this is an opt_level=0 pass.
"""

def call_operator(
self,
op,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op not in {
exir_ops.edge.aten.scalar_tensor.default,
torch.ops.aten.scalar_tensor.default,
}:
return super().call_operator(op, args, kwargs, meta)

return super().call_operator(
exir_ops.edge.aten.full.default,
(
[1],
args[0],
),
{"dtype": torch.float32},
meta,
)
register_cadence_pass(CadencePassAttribute(opt_level=0))(
ReplaceScalarTensorWithFullPass
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
Expand Down
42 changes: 42 additions & 0 deletions backends/transforms/replace_scalar_tensor_with_full.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you this is the right thing to do.

# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Tuple

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from torch.fx.node import Argument


class ReplaceScalarTensorWithFullPass(ExportPass):
"""
aten.scalar_tensor can be replaced by aten.full with a shape of [1].
"""

def call_operator(
self,
op,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op not in {
exir_ops.edge.aten.scalar_tensor.default,
torch.ops.aten.scalar_tensor.default,
}:
return super().call_operator(op, args, kwargs, meta)

return super().call_operator(
exir_ops.edge.aten.full.default,
(
[1],
args[0],
),
{"dtype": kwargs["dtype"]},
meta,
)
Loading