diff --git a/backends/aoti/passes/TARGETS b/backends/aoti/passes/TARGETS new file mode 100644 index 00000000000..82f3b40dc54 --- /dev/null +++ b/backends/aoti/passes/TARGETS @@ -0,0 +1,17 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "passes", + srcs = [ + "replace_view_copy_with_view.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + ], +) diff --git a/backends/apple/metal/replace_slice_copy_with_slice.py b/backends/aoti/passes/replace_view_copy_with_view.py similarity index 84% rename from backends/apple/metal/replace_slice_copy_with_slice.py rename to backends/aoti/passes/replace_view_copy_with_view.py index 4f16759af35..c2be14f96e5 100644 --- a/backends/apple/metal/replace_slice_copy_with_slice.py +++ b/backends/aoti/passes/replace_view_copy_with_view.py @@ -4,9 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-strict +# This pass replaces view_copy ops with view ops. 
This is different than +# exir/passes/replace_view_copy_with_view.py and exir/passes/reinplace.py +# because this should only be used in the AOTInductor backend, as it +# has fewer restrictions on whether the tensor memory is densely packed. -from typing import Dict, Iterable, Tuple +from typing import Dict, Iterable import torch from executorch.exir.dialects._ops import ops @@ -15,33 +18,30 @@ from torch import fx -_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = ( - torch.ops.aten.slice_copy.Tensor, - ops.edge.aten.slice_copy.Tensor, -) - -_SLICE_TARGETS: Dict[ +_VIEW_TARGETS: Dict[ torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload ] = { torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor, + torch.ops.aten.select_copy.int: torch.ops.aten.select.int, + ops.edge.aten.select_copy.int: ops.edge.aten.select.int, } -class ReplaceSliceCopyWithSlicePass(ExportPass): - """Replace non-mutated ``slice_copy`` results with ``slice`` views.""" +class ReplaceViewCopyWithViewPass(ExportPass): + """Replace non-mutated ``view_copy`` type of ops with ``view`` ops.""" def call(self, graph_module: fx.GraphModule) -> PassResult: graph_changed = False for node in graph_module.graph.nodes: - if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS: + if node.op != "call_function" or node.target not in _VIEW_TARGETS: continue if self._has_blocking_user(node, node.users.keys()): continue - node.target = _SLICE_TARGETS[node.target] + node.target = _VIEW_TARGETS[node.target] graph_changed = True if graph_changed: diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index db57bed8fc7..13a3534004b 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -12,8 +12,8 @@ from typing import Any, Dict, final, List, Optional, Set import torch -from
executorch.backends.apple.metal.replace_slice_copy_with_slice import ( - ReplaceSliceCopyWithSlicePass, +from executorch.backends.aoti.passes.replace_view_copy_with_view import ( + ReplaceViewCopyWithViewPass, ) from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir._warnings import experimental @@ -93,7 +93,7 @@ def preprocess( mps_edge_program = move_to_device_pass(edge_program, "mps") # replace slice_copy with slice - ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module) + ReplaceViewCopyWithViewPass()(mps_edge_program.graph_module) edge_program_module = mps_edge_program.module() diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS index fe57f7f1b63..22987a728ca 100644 --- a/backends/cuda/TARGETS +++ b/backends/cuda/TARGETS @@ -6,13 +6,13 @@ runtime.python_library( name = "cuda_backend", srcs = [ "cuda_backend.py", - "replace_slice_copy_with_slice.py", ], visibility = [ "//executorch/...", ], deps = [ "//caffe2:torch", + "//executorch/backends/aoti/passes:passes", "//executorch/exir/_serialize:lib", "//executorch/exir/backend:backend_details", "//executorch/exir/backend:compile_spec_schema", diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 8047639a30c..e1c3cf9719f 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -12,8 +12,8 @@ from typing import Any, Dict, final, List, Optional, Set import torch -from executorch.backends.cuda.replace_slice_copy_with_slice import ( - ReplaceSliceCopyWithSlicePass, +from executorch.backends.aoti.passes.replace_view_copy_with_view import ( + ReplaceViewCopyWithViewPass, ) from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir._warnings import experimental @@ -123,8 +123,8 @@ def preprocess( # Move the edge_program from CPU to CUDA for aoti compile cuda_edge_program = move_to_device_pass(edge_program, "cuda") - # replace slice_copy with slice - 
ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module) + # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int + ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module) cuda_edge_program = cuda_edge_program.run_decompositions( cuda_decomposition_table diff --git a/backends/cuda/replace_slice_copy_with_slice.py b/backends/cuda/replace_slice_copy_with_slice.py deleted file mode 100644 index 4f16759af35..00000000000 --- a/backends/cuda/replace_slice_copy_with_slice.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from typing import Dict, Iterable, Tuple - -import torch -from executorch.exir.dialects._ops import ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ExportPass, PassResult -from torch import fx - - -_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = ( - torch.ops.aten.slice_copy.Tensor, - ops.edge.aten.slice_copy.Tensor, -) - -_SLICE_TARGETS: Dict[ - torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload -] = { - torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, - ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor, -} - - -class ReplaceSliceCopyWithSlicePass(ExportPass): - """Replace non-mutated ``slice_copy`` results with ``slice`` views.""" - - def call(self, graph_module: fx.GraphModule) -> PassResult: - graph_changed = False - - for node in graph_module.graph.nodes: - if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS: - continue - - if self._has_blocking_user(node, node.users.keys()): - continue - - node.target = _SLICE_TARGETS[node.target] - graph_changed = True - - if graph_changed: - graph_module.graph.lint() - graph_module.recompile() - - return 
PassResult(graph_module, graph_changed) - - def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool: - for user in users: - if self._is_mutating_user(node, user) or self._is_view_user(node, user): - return True - return False - - def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool: - if user.op == "call_method": - # Treat in-place tensor methods conservatively as mutations only when the - # method name ends with ``_`` which is the PyTorch convention for mutation. - return isinstance(user.target, str) and user.target.endswith("_") - - if user.op != "call_function": - return False - - target = user.target - if not hasattr(target, "_schema"): - return False - - schema = target._schema # pyre-ignore[16] - # Positional arguments - for index, arg in enumerate(user.args): - if arg is node and self._argument_mutates(schema, index): - return True - - # Keyword arguments - for name, arg in user.kwargs.items(): - if arg is node and self._argument_mutates(schema, name): - return True - - return False - - def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool: - if user.op == "call_method": - # Treat tensor methods conservatively and assume they may be view-producing. - return True - - if user.op != "call_function": - return False - - target = user.target - if getattr(target, "is_view", False): - for arg in user.args: - if arg is node: - return True - for arg in user.kwargs.values(): - if arg is node: - return True - - return False - - def _argument_mutates( - self, schema: torch._C.FunctionSchema, key: int | str - ) -> bool: - arguments = schema.arguments - if isinstance(key, int): - if key >= len(arguments): - return False - argument = arguments[key] - else: - argument = next((arg for arg in arguments if arg.name == key), None) - if argument is None: - return False - - alias_info = argument.alias_info - return bool(alias_info and alias_info.is_write)