
Commit e938fea

[ET-VK] Implement select_at_dim_as_symint (#15644)
## Context

The SDPA custom op accepts the `input_pos` (i.e. cache position) argument as a symbolic integer. The value of the symbolic integer is obtained by selecting the first element of a cache position input tensor and converting it to a symint via `local_scalar_dense`. Currently, ET-VK handles this in a hacky manner:

1. The select + local_scalar_dense op pattern is removed, and the cache position tensor is passed directly into the custom SDPA ops.
2. Single-element tensors whose users are all select + local_scalar_dense chains are interpreted as symints instead of tensors.

Unfortunately, this technique will not work for the Hugging Face implementation of transformer models, since the cache position input tensor contains not just a single element, but is expected to be a vector of integer cache positions corresponding to all cache positions that will be updated.

## Changes

Introduce a custom op to capture the select + local_scalar_dense op pattern, which is the proper way to handle it. Note that a custom op is needed because this op needs to access the staging buffer data of the input tensor, whereas `select` would typically be executed via a compute shader. The reason is that the `input_pos` value is needed to configure the sizes of the attention weight tensors participating in the custom SDPA op, so the value must be set before any command buffers are dispatched.

As a consequence of this change, the previous handling of select + local_scalar_dense can also be removed.

Differential Revision: [D86340340](https://our.internmc.facebook.com/intern/diff/D86340340/)
1 parent b8bdfa2 commit e938fea
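
For context, the pattern this commit targets arises from ordinary tensor indexing plus `.item()` in model code. A minimal sketch of how the chain appears in an exported graph (the toy model and names below are illustrative, not part of this commit):

```python
import torch


class Toy(torch.nn.Module):
    def forward(self, cache_pos: torch.Tensor) -> torch.Tensor:
        # `cache_pos[0].item()` traces to select.int followed by
        # _local_scalar_dense (select_copy.int in the edge dialect),
        # producing an unbacked SymInt.
        input_pos = cache_pos[0].item()
        torch._check_is_size(input_pos)  # hint so export can guard on it
        return torch.ones(input_pos)


ep = torch.export.export(Toy(), (torch.tensor([4], dtype=torch.int64),))
print(ep.graph)  # contains the select / _local_scalar_dense chain
```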

File tree

18 files changed: +224 additions, -178 deletions


backends/vulkan/_passes/TARGETS

Lines changed: 0 additions & 14 deletions
@@ -63,19 +63,6 @@ runtime.python_library(
     ],
 )
 
-runtime.python_library(
-    name = "remove_local_scalar_dense",
-    srcs = ["remove_local_scalar_dense_ops.py"],
-    visibility = [
-        "//executorch/backends/...",
-    ],
-    deps = [
-        "//caffe2:torch",
-        "//executorch/exir:pass_base",
-        "//executorch/exir/dialects:lib",
-    ],
-)
-
 runtime.python_library(
     name = "remove_redundant_ops",
     srcs = ["remove_redundant_ops.py"],
@@ -161,7 +148,6 @@ runtime.python_library(
         ":fuse_quantized_ops",
         ":insert_prepack_nodes",
         ":remove_asserts",
-        ":remove_local_scalar_dense",
         ":remove_redundant_ops",
         ":replace_qdq",
         ":squeeze_unsqueeze_inputs",

backends/vulkan/_passes/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -16,9 +16,6 @@
     remove_asserts,
     RemoveAssertsTransform,
 )
-from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
-    RemoveLocalScalarDenseOpsTransform,
-)
 from executorch.backends.vulkan._passes.remove_redundant_ops import (
     RemoveRedundantOpsTransform,
 )
@@ -35,7 +32,6 @@
     "insert_prepack_nodes",
     "remove_asserts",
     "RemoveAssertsTransform",
-    "RemoveLocalScalarDenseOpsTransform",
     "RemoveRedundantOpsTransform",
     "ReplaceQDQPass",
     "SqueezeUnsqueezeInputs",

backends/vulkan/_passes/remove_local_scalar_dense_ops.py

Lines changed: 0 additions & 110 deletions
This file was deleted.

backends/vulkan/custom_ops_lib.py

Lines changed: 17 additions & 0 deletions
@@ -9,6 +9,8 @@
 import executorch.backends.vulkan.patterns as vk_patterns
 import torch.library
 
+from torch._subclasses.fake_tensor import FakeTensor
+
 namespace = "et_vk"
 lib = torch.library.Library(namespace, "DEF")
 
@@ -614,3 +616,18 @@ def add_q8ta_q8ta_q8to_impl(
 )
 lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd")
 add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name)
+
+#############################
+##    select_as_symint     ##
+#############################
+
+
+def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
+    assert isinstance(x, FakeTensor)
+    return x.fake_mode.shape_env.create_unbacked_symint()
+
+
+name = "select_as_symint"
+lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt")
+lib.impl(name, select_as_symint_impl, "Meta")
+select_as_symint_op = getattr(getattr(torch.ops, namespace), name)
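
For illustration, note that the Meta kernel above never reads the tensor's contents; it just allocates a fresh unbacked SymInt, and the actual value is resolved at runtime from the staging buffer on the Vulkan side. A hedged sketch of exercising the op under fake tensors (the setup code is illustrative, not part of this commit):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Importing the module registers the et_vk ops, including select_as_symint.
import executorch.backends.vulkan.custom_ops_lib  # noqa

with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
    cache_pos = fake_mode.from_tensor(torch.tensor([7, 8, 9]))
    # Returns an unbacked SymInt (e.g. u0) instead of reading data.
    pos = torch.ops.et_vk.select_as_symint(cache_pos, 0, 0)
    print(type(pos))  # <class 'torch.SymInt'>
```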

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 0 additions & 40 deletions
@@ -184,36 +184,6 @@ def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]:
 
         return False, False
 
-    def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]:
-        """
-        Scalar tensors are usually converted to scalar values in the graph via
-        `scalar_tensor[0].item()` in Python, which translates to a chain of
-        `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph.
-        This function marks the entire chain as supported by the Vulkan delegate.
-
-        Later, within vulkan_preprocess there will be a graph transform which replaces
-        the chain with passing in the scalar tensor directly.
-
-        Similar to the `is_linear_permute` function, this function has 2 return values.
-        """
-        if node.target == exir_ops.edge.aten.select_copy.int:
-            if len(node.users) != 1:
-                return False, False
-            # pyre-ignore
-            if node.args[0].meta["val"].numel() != 1:
-                return False, False
-
-            local_scalar_dense = list(node.users.keys())[0]
-            if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default:
-                return False, False
-
-            return self.is_in_local_scalar_dense_chain(local_scalar_dense)
-
-        if node.target == torch.ops.aten._local_scalar_dense.default:
-            return True, all(self.node_is_compatible(user)[0] for user in node.users)
-
-        return False, False
-
     def log_skip(self, node: torch.fx.Node, reason: str) -> None:
         if node.op == "call_function":
             logger.info(
@@ -261,16 +231,6 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
             self.log_skip(node, "permute node of non compatible linear node")
             return False
 
-        (
-            is_in_local_scalar_dense_chain,
-            dst_node_is_compatible,
-        ) = self.is_in_local_scalar_dense_chain(node)
-        if is_in_local_scalar_dense_chain and dst_node_is_compatible:
-            return True
-        elif is_in_local_scalar_dense_chain:
-            self.log_skip(node, "local scalar dense of incompatible op node")
-            return False
-
         features = None
         if target not in vulkan_supported_ops:
             # For some ops, i.e. custom ops the name is registered instead of the

backends/vulkan/patterns/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ runtime.python_library(
         "quantized_linear.py",
         "quantized_convolution.py",
         "quantized_binary.py",
+        "select_as_symint.py",
     ],
     visibility = [
         "//executorch/backends/...",

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 import executorch.backends.vulkan.patterns.rope  # noqa
 
+import executorch.backends.vulkan.patterns.select_as_symint  # noqa
+
 import torch
 
 from executorch.backends.vulkan.patterns.pattern_registry import (
backends/vulkan/patterns/select_as_symint.py

Lines changed: 104 additions & 0 deletions (new file)

@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class SelectAsSymIntMatch(PatternMatch):
+    def __init__(self, local_scalar_dense_node: torch.fx.Node) -> None:
+        self.anchor_node = local_scalar_dense_node
+        self.match_found = False
+
+        # Check if the input to local_scalar_dense is a select_copy node
+        if len(local_scalar_dense_node.args) < 1:
+            return
+
+        select_node = local_scalar_dense_node.args[0]
+        if not isinstance(select_node, torch.fx.Node):
+            return
+
+        if (
+            select_node.op != "call_function"
+            or select_node.target != exir_ops.edge.aten.select_copy.int
+        ):
+            return
+
+        # select_copy.int has signature: select_copy(Tensor self, int dim, int index)
+        if len(select_node.args) < 3:
+            return
+
+        self.select_node = select_node
+
+        self.tensor_node = select_node.args[0]
+        self.dim_node = select_node.args[1]
+        self.index_node = select_node.args[2]
+
+        self.all_nodes = [
+            self.anchor_node,
+            self.select_node,
+            self.tensor_node,
+            self.dim_node,
+            self.index_node,
+        ]
+
+        self.match_found = True
+
+
+@register_pattern_detector("select_as_symint")
+def find_select_as_symint_patterns(
+    node: torch.fx.Node,
+) -> Optional[SelectAsSymIntMatch]:
+    if node.target != torch.ops.aten._local_scalar_dense.default:
+        return None
+
+    matched_pattern = SelectAsSymIntMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+@register_pattern_replacement("select_as_symint")
+def replace_select_local_scalar_dense_with_select_as_symint(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: SelectAsSymIntMatch,
+):
+    with graph_module.graph.inserting_before(match.anchor_node):
+        new_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.select_as_symint.default,
+            args=(
+                match.tensor_node,
+                match.dim_node,
+                match.index_node,
+            ),
+        )
+
+    new_node.meta["val"] = match.anchor_node.meta["val"]
+    match.anchor_node.replace_all_uses_with(new_node)
+
+    # # Remove both the local_scalar_dense and select_copy nodes
+    # graph_module.graph.erase_node(match.anchor_node)
+    # # Only erase select_node if it has no other users
+    # if len(match.select_node.users) == 0:
+    #     graph_module.graph.erase_node(match.select_node)
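
To visualize the rewrite, here is a before/after sketch of the relevant edge-dialect nodes (node names are illustrative; the now-dead select_copy/_local_scalar_dense nodes are presumably removed by a later dead-code elimination pass, which would explain why the erase calls above are left commented out):

```python
# Before the replacement:
#   select = exir_ops.edge.aten.select_copy.int(cache_pos, 0, 0)
#   pos    = torch.ops.aten._local_scalar_dense(select)             # SymInt u0
#
# After replace_all_uses_with(new_node):
#   pos    = exir_ops.edge.et_vk.select_as_symint(cache_pos, 0, 0)  # SymInt u0
```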

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 1 addition & 1 deletion
@@ -649,7 +649,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
       }
     }
 
-    if (should_propagate_resize) {
+    if (should_propagate_resize || compute_graph->has_data_dependent_shapes()) {
       compute_graph->propagate_resize();
     }

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 14 additions & 0 deletions
@@ -112,6 +112,20 @@ class StagingBuffer final {
   inline void set_staging_zeros() {
     memset(data(), 0, nbytes());
   }
+
+  template <typename T>
+  T select_element_at_dim(
+      const std::vector<int64_t>& sizes,
+      const int64_t dim,
+      const int64_t index) {
+    int64_t stride = 1;
+    for (size_t i = dim + 1; i < sizes.size(); ++i) {
+      stride *= sizes[i];
+    }
+    const int64_t offset = index * stride;
+    const T* typed_data = reinterpret_cast<const T*>(data());
+    return typed_data[offset];
+  }
 };
 
 } // namespace api
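
The new helper assumes a contiguous, row-major tensor in the staging buffer: the stride of `dim` is the product of the trailing dimension sizes, and all leading indices are implicitly zero, so `index * stride` addresses the selected element. A small Python sketch of the same arithmetic (names illustrative):

```python
from math import prod

def select_element_at_dim(flat_data, sizes, dim, index):
    # Contiguous row-major stride of `dim` = product of trailing sizes.
    stride = prod(sizes[dim + 1:])
    return flat_data[index * stride]

# A 2x3 tensor stored flat as [0, 1, 2, 3, 4, 5]:
data = list(range(6))
assert select_element_at_dim(data, [2, 3], 0, 1) == 3  # element [1, 0]
assert select_element_at_dim(data, [2, 3], 1, 2) == 2  # element [0, 2]
```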
