From 5ff0919e0e80a175fb6a9a6b466c652513863283 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 14 Oct 2025 15:53:59 -0700
Subject: [PATCH 1/4] Extend reinplace pass to select_copy.int

---
 backends/cuda/TARGETS                            |  2 +-
 backends/cuda/cuda_backend.py                    |  6 +++---
 ...h_slice.py => replace_view_copy_with_view.py} | 16 ++++++++++------
 3 files changed, 14 insertions(+), 10 deletions(-)
 rename backends/cuda/{replace_slice_copy_with_slice.py => replace_view_copy_with_view.py} (88%)

diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS
index fe57f7f1b63..e36b5d60352 100644
--- a/backends/cuda/TARGETS
+++ b/backends/cuda/TARGETS
@@ -6,7 +6,7 @@ runtime.python_library(
     name = "cuda_backend",
     srcs = [
         "cuda_backend.py",
-        "replace_slice_copy_with_slice.py",
+        "replace_view_copy_with_view.py",
     ],
     visibility = [
         "//executorch/...",
diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py
index 8047639a30c..200351143ab 100644
--- a/backends/cuda/cuda_backend.py
+++ b/backends/cuda/cuda_backend.py
@@ -12,8 +12,8 @@ from typing import Any, Dict, final, List, Optional, Set
 
 import torch
 
-from executorch.backends.cuda.replace_slice_copy_with_slice import (
-    ReplaceSliceCopyWithSlicePass,
+from executorch.backends.cuda.replace_view_copy_with_view import (
+    ReplaceViewCopyWithViewPass,
 )
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir._warnings import experimental
@@ -124,7 +124,7 @@ def preprocess(
         cuda_edge_program = move_to_device_pass(edge_program, "cuda")
 
         # replace slice_copy with slice
-        ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)
+        ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)
 
         cuda_edge_program = cuda_edge_program.run_decompositions(
             cuda_decomposition_table
diff --git a/backends/cuda/replace_slice_copy_with_slice.py b/backends/cuda/replace_view_copy_with_view.py
similarity index 88%
rename from backends/cuda/replace_slice_copy_with_slice.py
rename to backends/cuda/replace_view_copy_with_view.py
index 4f16759af35..b0847d364ea 100644
--- a/backends/cuda/replace_slice_copy_with_slice.py
+++ b/backends/cuda/replace_view_copy_with_view.py
@@ -15,33 +15,37 @@
 from torch import fx
 
 
-_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
+_VIEW_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
     torch.ops.aten.slice_copy.Tensor,
     ops.edge.aten.slice_copy.Tensor,
+    torch.ops.aten.select_copy.int,
+    ops.edge.aten.select_copy.int,
 )
 
-_SLICE_TARGETS: Dict[
+_VIEW_TARGETS: Dict[
     torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
 ] = {
     torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
     ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
+    torch.ops.aten.select_copy.int: torch.ops.aten.select.int,
+    ops.edge.aten.select_copy.int: ops.edge.aten.select.int,
 }
 
 
-class ReplaceSliceCopyWithSlicePass(ExportPass):
-    """Replace non-mutated ``slice_copy`` results with ``slice`` views."""
+class ReplaceViewCopyWithViewPass(ExportPass):
+    """Replace non-mutated ``view_copy`` type of ops with ``view`` ops."""
 
     def call(self, graph_module: fx.GraphModule) -> PassResult:
         graph_changed = False
 
         for node in graph_module.graph.nodes:
-            if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
+            if node.op != "call_function" or node.target not in _VIEW_COPY_TARGETS:
                 continue
 
             if self._has_blocking_user(node, node.users.keys()):
                 continue
 
-            node.target = _SLICE_TARGETS[node.target]
+            node.target = _VIEW_TARGETS[node.target]
             graph_changed = True
 
         if graph_changed:

From d2c7c5552c2bbd60bdba8edf1eb97d5a5261fa1d Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 14 Oct 2025 21:18:07 -0700
Subject: [PATCH 2/4] Make it work

---
 backends/aoti/passes/TARGETS                   |  17 +++
 .../passes}/replace_view_copy_with_view.py     |   5 +-
 backends/apple/metal/metal_backend.py          |   6 +-
 .../metal/replace_slice_copy_with_slice.py     | 118 ------------------
 backends/cuda/TARGETS                          |   2 +-
 backends/cuda/cuda_backend.py                  |   4 +-
 6 files changed, 27 insertions(+), 125 deletions(-)
 create mode 100644 backends/aoti/passes/TARGETS
 rename backends/{cuda => aoti/passes}/replace_view_copy_with_view.py (93%)
 delete mode 100644 backends/apple/metal/replace_slice_copy_with_slice.py

diff --git a/backends/aoti/passes/TARGETS b/backends/aoti/passes/TARGETS
new file mode 100644
index 00000000000..82f3b40dc54
--- /dev/null
+++ b/backends/aoti/passes/TARGETS
@@ -0,0 +1,17 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "passes",
+    srcs = [
+        "replace_view_copy_with_view.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+    ],
+)
diff --git a/backends/cuda/replace_view_copy_with_view.py b/backends/aoti/passes/replace_view_copy_with_view.py
similarity index 93%
rename from backends/cuda/replace_view_copy_with_view.py
rename to backends/aoti/passes/replace_view_copy_with_view.py
index b0847d364ea..da2550cb25d 100644
--- a/backends/cuda/replace_view_copy_with_view.py
+++ b/backends/aoti/passes/replace_view_copy_with_view.py
@@ -4,7 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-strict
+# This pass replaces view_copy ops with view ops. This is different than
+# exir/passes/replace_view_copy_with_view.py and exir/passes/reinplace.py
+# because this should only be used in the AOTInductor backend, as it
+# has less restrictions on whether the tensor memory is densely packed,
 
 from typing import Dict, Iterable, Tuple
 
diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py
index db57bed8fc7..13a3534004b 100644
--- a/backends/apple/metal/metal_backend.py
+++ b/backends/apple/metal/metal_backend.py
@@ -12,8 +12,8 @@ from typing import Any, Dict, final, List, Optional, Set
 
 import torch
 
-from executorch.backends.apple.metal.replace_slice_copy_with_slice import (
-    ReplaceSliceCopyWithSlicePass,
+from executorch.backends.aoti.passes.replace_view_copy_with_view import (
+    ReplaceViewCopyWithViewPass,
 )
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir._warnings import experimental
@@ -93,7 +93,7 @@ def preprocess(
         mps_edge_program = move_to_device_pass(edge_program, "mps")
 
         # replace slice_copy with slice
-        ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module)
+        ReplaceViewCopyWithViewPass()(mps_edge_program.graph_module)
 
         edge_program_module = mps_edge_program.module()
 
diff --git a/backends/apple/metal/replace_slice_copy_with_slice.py b/backends/apple/metal/replace_slice_copy_with_slice.py
deleted file mode 100644
index 4f16759af35..00000000000
--- a/backends/apple/metal/replace_slice_copy_with_slice.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-from typing import Dict, Iterable, Tuple
-
-import torch
-from executorch.exir.dialects._ops import ops
-from executorch.exir.dialects.edge._ops import EdgeOpOverload
-from executorch.exir.pass_base import ExportPass, PassResult
-from torch import fx
-
-
-_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
-    torch.ops.aten.slice_copy.Tensor,
-    ops.edge.aten.slice_copy.Tensor,
-)
-
-_SLICE_TARGETS: Dict[
-    torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
-] = {
-    torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
-    ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
-}
-
-
-class ReplaceSliceCopyWithSlicePass(ExportPass):
-    """Replace non-mutated ``slice_copy`` results with ``slice`` views."""
-
-    def call(self, graph_module: fx.GraphModule) -> PassResult:
-        graph_changed = False
-
-        for node in graph_module.graph.nodes:
-            if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
-                continue
-
-            if self._has_blocking_user(node, node.users.keys()):
-                continue
-
-            node.target = _SLICE_TARGETS[node.target]
-            graph_changed = True
-
-        if graph_changed:
-            graph_module.graph.lint()
-            graph_module.recompile()
-
-        return PassResult(graph_module, graph_changed)
-
-    def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool:
-        for user in users:
-            if self._is_mutating_user(node, user) or self._is_view_user(node, user):
-                return True
-        return False
-
-    def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool:
-        if user.op == "call_method":
-            # Treat in-place tensor methods conservatively as mutations only when the
-            # method name ends with ``_`` which is the PyTorch convention for mutation.
-            return isinstance(user.target, str) and user.target.endswith("_")
-
-        if user.op != "call_function":
-            return False
-
-        target = user.target
-        if not hasattr(target, "_schema"):
-            return False
-
-        schema = target._schema  # pyre-ignore[16]
-        # Positional arguments
-        for index, arg in enumerate(user.args):
-            if arg is node and self._argument_mutates(schema, index):
-                return True
-
-        # Keyword arguments
-        for name, arg in user.kwargs.items():
-            if arg is node and self._argument_mutates(schema, name):
-                return True
-
-        return False
-
-    def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
-        if user.op == "call_method":
-            # Treat tensor methods conservatively and assume they may be view-producing.
-            return True
-
-        if user.op != "call_function":
-            return False
-
-        target = user.target
-        if getattr(target, "is_view", False):
-            for arg in user.args:
-                if arg is node:
-                    return True
-            for arg in user.kwargs.values():
-                if arg is node:
-                    return True
-
-        return False
-
-    def _argument_mutates(
-        self, schema: torch._C.FunctionSchema, key: int | str
-    ) -> bool:
-        arguments = schema.arguments
-        if isinstance(key, int):
-            if key >= len(arguments):
-                return False
-            argument = arguments[key]
-        else:
-            argument = next((arg for arg in arguments if arg.name == key), None)
-            if argument is None:
-                return False
-
-        alias_info = argument.alias_info
-        return bool(alias_info and alias_info.is_write)
diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS
index e36b5d60352..22987a728ca 100644
--- a/backends/cuda/TARGETS
+++ b/backends/cuda/TARGETS
@@ -6,13 +6,13 @@ runtime.python_library(
     name = "cuda_backend",
     srcs = [
         "cuda_backend.py",
-        "replace_view_copy_with_view.py",
     ],
     visibility = [
         "//executorch/...",
     ],
     deps = [
         "//caffe2:torch",
+        "//executorch/backends/aoti/passes:passes",
         "//executorch/exir/_serialize:lib",
         "//executorch/exir/backend:backend_details",
         "//executorch/exir/backend:compile_spec_schema",
diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py
index 200351143ab..e1c3cf9719f 100644
--- a/backends/cuda/cuda_backend.py
+++ b/backends/cuda/cuda_backend.py
@@ -12,7 +12,7 @@ from typing import Any, Dict, final, List, Optional, Set
 
 import torch
 
-from executorch.backends.cuda.replace_view_copy_with_view import (
+from executorch.backends.aoti.passes.replace_view_copy_with_view import (
     ReplaceViewCopyWithViewPass,
 )
 from executorch.exir._serialize._named_data_store import NamedDataStore
@@ -123,7 +123,7 @@ def preprocess(
         # Move the edge_program from CPU to CUDA for aoti compile
        cuda_edge_program = move_to_device_pass(edge_program, "cuda")
 
-        # replace slice_copy with slice
+        # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int
         ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)
 
         cuda_edge_program = cuda_edge_program.run_decompositions(

From b3de67aec22216c8adefc3ed72abee4cefc2723f Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 14 Oct 2025 21:39:35 -0700
Subject: [PATCH 3/4] address comments

---
 backends/aoti/passes/replace_view_copy_with_view.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/backends/aoti/passes/replace_view_copy_with_view.py b/backends/aoti/passes/replace_view_copy_with_view.py
index da2550cb25d..7622039db6e 100644
--- a/backends/aoti/passes/replace_view_copy_with_view.py
+++ b/backends/aoti/passes/replace_view_copy_with_view.py
@@ -18,13 +18,6 @@
 from torch import fx
 
 
-_VIEW_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
-    torch.ops.aten.slice_copy.Tensor,
-    ops.edge.aten.slice_copy.Tensor,
-    torch.ops.aten.select_copy.int,
-    ops.edge.aten.select_copy.int,
-)
-
 _VIEW_TARGETS: Dict[
     torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
 ] = {
@@ -42,7 +35,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:
         graph_changed = False
 
         for node in graph_module.graph.nodes:
-            if node.op != "call_function" or node.target not in _VIEW_COPY_TARGETS:
+            if node.op != "call_function" or node.target not in _VIEW_TARGETS:
                 continue
 
             if self._has_blocking_user(node, node.users.keys()):

From 30bb8799bb6cccddf1e5d31a3183ce967071c9e8 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 14 Oct 2025 21:40:04 -0700
Subject: [PATCH 4/4] lint

---
 backends/aoti/passes/replace_view_copy_with_view.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/aoti/passes/replace_view_copy_with_view.py b/backends/aoti/passes/replace_view_copy_with_view.py
index 7622039db6e..c2be14f96e5 100644
--- a/backends/aoti/passes/replace_view_copy_with_view.py
+++ b/backends/aoti/passes/replace_view_copy_with_view.py
@@ -9,7 +9,7 @@
 # because this should only be used in the AOTInductor backend, as it
 # has less restrictions on whether the tensor memory is densely packed,
 
-from typing import Dict, Iterable, Tuple
+from typing import Dict, Iterable
 
 import torch
 from executorch.exir.dialects._ops import ops