
Commit 53b394f

Merge branch 'main' into dev_issue_fix

2 parents: 68654d5 + 176800e


50 files changed: +1640 additions, -566 deletions

backends/apple/mps/serialization/mps_graph_serialize.py

Lines changed: 6 additions & 2 deletions
@@ -1,14 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import importlib.resources as _resources
 import json
 import os
 import tempfile
 
-import pkg_resources
+import executorch.backends.apple.mps.serialization as serialization_package
 from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSGraph
 from executorch.exir._serialize._dataclass import _DataclassEncoder
 from executorch.exir._serialize._flatbuffer import _flatc_compile
@@ -19,7 +21,9 @@ def convert_to_flatbuffer(mps_graph: MPSGraph) -> bytes:
     with tempfile.TemporaryDirectory() as d:
         schema_path = os.path.join(d, "schema.fbs")
         with open(schema_path, "wb") as schema_file:
-            schema_file.write(pkg_resources.resource_string(__name__, "schema.fbs"))
+            schema_file.write(
+                _resources.read_binary(serialization_package, "schema.fbs")
+            )
         json_path = os.path.join(d, "schema.json")
         with open(json_path, "wb") as json_file:
             json_file.write(mps_graph_json.encode("ascii"))
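
For readers unfamiliar with the replacement API: `pkg_resources` is deprecated, and `importlib.resources` is the standard-library way to load files that ship inside a package. A minimal sketch of the pattern used above (the `files()` spelling shown in the comment is the non-deprecated equivalent on Python 3.9+):

```python
import importlib.resources as _resources

import executorch.backends.apple.mps.serialization as serialization_package

# read_binary() loads a file bundled with the package, regardless of how the
# package is installed (wheel, zip, editable install).
schema_bytes = _resources.read_binary(serialization_package, "schema.fbs")

# Equivalent, non-deprecated spelling on Python 3.9+:
schema_bytes = (
    _resources.files(serialization_package) / "schema.fbs"
).read_bytes()
```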

backends/arm/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
@@ -73,9 +73,10 @@ if(EXECUTORCH_BUILD_VGF)
   # vgf backend
   list(TRANSFORM _vgf_backend_sources PREPEND "${EXECUTORCH_ROOT}/")
   add_library(vgf_backend ${_vgf_backend_sources})
+  install(TARGETS vgf_backend EXPORT ExecuTorchTargets)
   target_include_directories(
-    vgf_backend PUBLIC ${_common_include_directories} ${VULKAN_HEADERS_PATH}
-    ${VOLK_HEADERS_PATH}
+    vgf_backend PRIVATE ${_common_include_directories} ${VULKAN_HEADERS_PATH}
+    ${VOLK_HEADERS_PATH}
   )
   target_compile_options(
     vgf_backend PRIVATE -DUSE_VULKAN_WRAPPER -DUSE_VULKAN_VOLK

backends/arm/README.md

Lines changed: 48 additions & 2 deletions
@@ -88,6 +88,19 @@ You can test to run some models with the full fvp test flow
 backends/arm/test/test_arm_baremetal.sh test_full_ethosu_fvp
 ```
 
+To run the unit test suite with VKML, use the following. Note that the Vulkan SDK needs to be installed;
+see install_vulkan_sdk() in .ci/scripts/setup-vulkan-linux-deps.sh for how to install it.
+
+```
+backends/arm/test/test_arm_baremetal.sh test_pytest_vkml
+```
+
+You can test running some models with the full VKML flow
+
+```
+backends/arm/test/test_arm_baremetal.sh test_full_vkml
+```
+
 ## Unit tests
 
 This is the structure of the test directory
@@ -102,6 +115,7 @@ test # Root test folder
 ├── tosautil # Utility functions for TOSA artifacts
 ├ common.py # Common functions and definitions used by many tests
 ├ setup_testing.sh # Script to prepare testing for using the Corstone 3x0 FVP
+├ setup_testing_vkml.sh # Script to prepare testing for using VKML
 ├ test_arm_baremetal.sh # Help script to trigger testing
 ```
 
@@ -123,7 +137,7 @@ first you need to build and prepare some used target libs
 
 ```
 examples/arm/run.sh --model_name=add --build_only
-backends/arm/test/setup_testing.sh
+backends/arm/test/setup_testing.sh and/or backends/arm/test/setup_testing_vkml.sh
 ```
 
 Then you can run the tests with
@@ -195,6 +209,38 @@ List of model specific and optional passes:
 - InsertCastForOpsWithInt64InputPass
   - Functionality:
     - For LLMs such as LLama, some operators like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
-  - Example usage: backends/arm/test/models/test_llama.py
   - Supported Ops:
     - aten.embedding.default, aten.slice_copy.Tensor
+  - Example usage:
+    - backends/arm/test/models/test_llama.py
+
+- ConvertInt64ConstOpsToInt32Pass
+  - Functionalities:
+    - Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
+  - Supported Ops:
+    - `torch.full`, `torch.arange`, `torch.eye`, `torch.linspace`, `torch.tensor`
+  - Example usage:
+    - backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+    - backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+
+- ConvertInt64OutputOpsToInt32Pass
+  - Overview:
+    - Rewrites or removes operations that produce int64 outputs, converting them to int32 where possible.
+    - Overflow checks are applied selectively; for ops without such checks, users need to ensure values fit within the int32 range.
+  - Functionalities:
+    1. Handling casts to int64:
+       - (1) int32 -> int64:
+         - Removes the cast and redirects uses of the int64 output to the original int32 value
+       - (2) other types -> int64:
+         - Rewrites the cast to produce int32 instead of int64
+       - Supported Ops:
+         - torch.ops.aten.to.\[dtype|dtype_layout\]
+         - exir_ops.edge.dim_order_ops._to_dim_order_copy.default
+    2. Post-processing argmax outputs:
+       - Inserts an int64->int32 cast after argmax operations that produce int64 outputs
+       - Supported Ops:
+         - torch.ops.aten.argmax.default
+         - exir_ops.edge.aten.argmax.default
+  - Example usage:
+    - (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+    - (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
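
To make the int64 -> int32 rewrites concrete, here is a small hedged sketch (the toy module is hypothetical, not taken from the test files listed above) of the kind of graph these two passes target:

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        # torch.arange with integer arguments produces int64 by default.
        # Since 0..7 fits in int32, ConvertInt64ConstOpsToInt32Pass can
        # retype this constant op so it becomes lowerable to TOSA.
        idx = torch.arange(x.shape[-1])
        # aten.argmax always returns int64; ConvertInt64OutputOpsToInt32Pass
        # inserts an int64 -> int32 cast after it.
        best = torch.argmax(x, dim=-1)
        return idx, best

ep = torch.export.export(Toy(), (torch.randn(2, 8),))
```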

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@
 from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass  # noqa
 from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass  # noqa
 from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass  # noqa
+from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass  # noqa
+from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass  # noqa
 from .convert_int_pow_to_mul import ConvertIntPowToMuls  # noqa
 from .convert_minmax_pass import ConvertMinMaxPass  # noqa
 from .convert_split_to_slice import ConvertSplitToSlicePass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,8 @@
     ConvertAnyDefaultDimDimsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
+    ConvertInt64ConstOpsToInt32Pass,
+    ConvertInt64OutputOpsToInt32Pass,
     ConvertIntPowToMuls,
     ConvertMinMaxPass,
     ConvertMmToBmmPass,
@@ -98,6 +100,7 @@
 from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
+from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
 from torch.fx import GraphModule
 
 
@@ -258,6 +261,11 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )
 
     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(
+            RemoveGraphAssertsPass()
+        )  # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertion nodes in the graph
+        self.add_pass(ConvertInt64ConstOpsToInt32Pass())
+        self.add_pass(ConvertInt64OutputOpsToInt32Pass())
         self.add_pass(InsertCastForOpsWithInt64InputPass())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
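
The comment above encodes an ordering constraint: exported graphs may contain assert nodes that reference int64 values, and those guards must be dropped before the const-op pass retypes the values they point at. A hedged sketch of applying the three passes in that order outside the pipeline (assuming each pass's call() method can be invoked directly on a GraphModule, as the implementations below do):

```python
import torch
from executorch.backends.arm._passes import (
    ConvertInt64ConstOpsToInt32Pass,
    ConvertInt64OutputOpsToInt32Pass,
)
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass

def run_int64_cleanup(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Order matters: remove assert nodes first so that retyping int64
    # constants does not leave guards referencing stale dtypes.
    for p in (
        RemoveGraphAssertsPass(),
        ConvertInt64ConstOpsToInt32Pass(),
        ConvertInt64OutputOpsToInt32Pass(),
    ):
        result = p.call(graph_module)
        if result is not None:
            graph_module = result.graph_module
    return graph_module
```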
backends/arm/_passes/convert_int64_const_ops_to_int32.py (new file)

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+import logging
+
+import torch
+from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+logger = logging.getLogger(__name__)
+INT32_MIN = torch.iinfo(torch.int32).min
+INT32_MAX = torch.iinfo(torch.int32).max
+
+
+class ConvertInt64ConstOpsToInt32Pass(ExportPass):
+    """
+    Rewrite constant ops that produce int64 to int32 where safe.
+
+    List of supported operators:
+    1. `torch.full`
+    2. `torch.arange`
+    3. `torch.eye`
+    4. `torch.linspace`
+    5. `torch.tensor`
+    """
+
+    torch_ops = [
+        torch.ops.aten.full.default,
+        torch.ops.aten.arange.default,
+        torch.ops.aten.arange.start,
+        torch.ops.aten.arange.start_step,
+        torch.ops.aten.eye.default,
+        torch.ops.aten.linspace.default,
+    ]
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target not in ComputeConstantOpsAOT.targeted_ops + self.torch_ops:
+                continue
+
+            data = node.target(*node.args, **node.kwargs)
+            if data.dtype is not torch.int64:
+                continue
+
+            min_val, max_val = torch.min(data), torch.max(data)
+            if INT32_MIN <= min_val and max_val <= INT32_MAX:
+                logger.warning(
+                    f"Casting {node.name} from torch.int64 to torch.int32"
+                    f" defined in {node.meta.get('stack_trace', '[no stack trace found]')}"
+                )
+                node.update_kwarg("dtype", torch.int32)
+                modified = True
+            else:
+                logger.warning(
+                    f"[{node.name}] has values: min={min_val}, max={max_val}, which exceeds int32 range "
+                    f"([{INT32_MIN}, {INT32_MAX}]); not converting dtype to int32."
+                )
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
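
A hedged usage sketch for this pass (the toy module and shapes are hypothetical; it assumes the pass can be applied to the GraphModule obtained from torch.export, which is how the annotation pipeline above uses it):

```python
import torch
from executorch.backends.arm._passes import ConvertInt64ConstOpsToInt32Pass

class Toy(torch.nn.Module):
    def forward(self, x):
        # An integer fill value makes torch.full default to int64.
        ones = torch.full((4,), 1)
        return x + ones

gm = torch.export.export(Toy(), (torch.zeros(4, dtype=torch.int32),)).module()
result = ConvertInt64ConstOpsToInt32Pass().call(gm)
# result.modified should be True: the full node's dtype kwarg is now
# torch.int32, since the value 1 is within int32 bounds.
```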
backends/arm/_passes/convert_int64_output_ops_to_int32.py (new file)

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+import logging
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+    set_node_arg,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+logger = logging.getLogger(__name__)
+
+
+class ConvertInt64OutputOpsToInt32Pass(ExportPass):
+    """
+    Rewrites or removes operations that produce int64 outputs, converting
+    them to int32 where possible.
+
+    Currently, this pass handles casting and argmax operators:
+    1. int32 -> int64:
+       removes the cast and redirects all uses to the original int32 value.
+    2. other types -> int64:
+       rewrites the cast to produce int32 instead of int64.
+    3. torch.argmax():
+       inserts an int64 -> int32 cast after the argmax node.
+
+    Future extensions may include operators that return int64 outputs by default
+    (e.g., `argmin`), rewriting them or inserting an int64 -> int32 cast to yield
+    int32 results.
+
+    Note: Overflow checks are applied selectively in this pass. For operators without
+    such checks, it is the user's responsibility to ensure that values fit within
+    the int32 range.
+    """
+
+    aten_cast_ops = (
+        torch.ops.aten.to.dtype,
+        torch.ops.aten.to.dtype_layout,
+    )
+    edge_cast_ops = (exir_ops.edge.dim_order_ops._to_dim_order_copy.default,)
+
+    aten_argmax_ops = (torch.ops.aten.argmax.default,)
+    edge_argmax_ops = (exir_ops.edge.aten.argmax.default,)
+
+    aten_ops = aten_cast_ops + aten_argmax_ops
+    edge_ops = edge_cast_ops + edge_argmax_ops
+
+    # dtype is specified in args
+    cast_ops_args = (
+        torch.ops.aten.to.dtype,  # to_2: node.args: (gt, torch.int64) node.kwargs: {}
+    )
+    # dtype is specified in kwargs
+    cast_ops_kwargs = (
+        torch.ops.aten.to.dtype_layout,  # to_1: node.args: (unsqueeze,) node.kwargs: {'dtype': torch.int64, 'layout': torch.strided, 'device': device(type='cpu')}
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,  # node.args: (aten_gt_scalar,) node.kwargs: {'dtype': torch.int64, 'dim_order': [0, 1]}
+    )
+
+    def _get_decomposition(self, op):
+        if op in self.edge_ops:
+            return exir_ops.edge.aten._to_copy.default
+
+        if op in self.aten_ops:
+            return torch.ops.aten._to_copy.default
+
+        raise RuntimeError(
+            f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
+        )
+
+    def _convert_casting_operators(self, node: torch.fx.Node):
+        input_node = node.all_input_nodes[0]
+        input_dtype = get_first_fake_tensor(input_node).dtype
+        # Case 1: int32 -> int64 - removes the cast
+        if input_dtype == torch.int32:
+            users = [user for user in node.users if node != user]
+            for user in users:
+                logger.warning(
+                    f"Removing int32->int64 casting node {node.name} defined in"
+                    f" {node.meta.get('stack_trace', '[no stack trace found]')}"
+                )
+                user.replace_input_with(node, input_node)
+        # Case 2: other types -> int64 - rewrites the cast to target int32
+        else:
+            if node.target in self.cast_ops_kwargs:
+                set_node_arg(node, "dtype", torch.int32)
+            elif node.target in self.cast_ops_args:
+                set_node_arg(node, 1, torch.int32)
+            else:
+                raise RuntimeError(f"Unexpected target {node.target} in {node.name}")
+            output_dtype = get_first_fake_tensor(node).dtype
+            logger.warning(
+                f"Converting casting node {node.name} from {input_dtype}->{output_dtype} to"
+                f" {input_dtype}->torch.int32 defined in {node.meta.get('stack_trace', '[no stack trace found]')}"
+            )
+
+    def _convert_argmax_operators(self, node: torch.fx.Node, graph: torch.fx.Graph):
+        output_tensor = node
+        to_copy_op = self._get_decomposition(node.target)
+        with graph.inserting_after(node):
+            cast_after = create_node(
+                graph,
+                to_copy_op,
+                args=(output_tensor,),
+                kwargs={
+                    "dtype": torch.int32,
+                },
+            )
+            users = [user for user in node.users if user != cast_after]
+            for user in users:
+                user.replace_input_with(output_tensor, cast_after)
+        logger.warning(
+            f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 output"
+            f" to int32 for {node.name} defined in {node.meta.get('stack_trace', '[no stack trace found]')}"
+        )
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        graph = graph_module.graph
+        for node in list(graph.nodes):
+            if node.op != "call_function":
+                continue
+            if node.target not in self.aten_ops + self.edge_ops:
+                continue
+            output_dtype = get_first_fake_tensor(node).dtype
+            if output_dtype != torch.int64:
+                continue
+
+            if node.target in self.aten_cast_ops + self.edge_cast_ops:
+                self._convert_casting_operators(node)
+            elif node.target in self.aten_argmax_ops + self.edge_argmax_ops:
+                # TODO: Add range check based on the input tensor shape before casting the output
+                self._convert_argmax_operators(node, graph)
+            else:
+                raise RuntimeError(f"Unexpected target {node.target} in {node.name}")
+
+            modified = True
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
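
And a matching hedged sketch for this pass (hypothetical module; same assumption as above about applying the pass directly to an exported GraphModule):

```python
import torch
from executorch.backends.arm._passes import ConvertInt64OutputOpsToInt32Pass

class ArgMax(torch.nn.Module):
    def forward(self, x):
        # aten.argmax always produces int64 indices.
        return torch.argmax(x, dim=-1)

gm = torch.export.export(ArgMax(), (torch.randn(2, 8),)).module()
result = ConvertInt64OutputOpsToInt32Pass().call(gm)
# The pass appends a _to_copy(dtype=torch.int32) node after the argmax,
# so downstream consumers of the graph see int32 indices.
```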
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+ethos-u-vela @ git+https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela@d37febc1715edf0d236c2ff555739a8a9aadcf9a
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+diffusers[torch] == 0.33.1
