Skip to content

Commit 97b86a0

Browse files
ssjia (SS-JIA)
authored and committed
[ET-VK][ez] Fuse update_cache + custom_sdpa into sdpa_with_kv_cache
Pull Request resolved: #15618 SDPA used to be handled by a custom op `sdpa_with_kv_cache`, but it was eventually split (D62301837) into update_cache and custom_sdpa ops. However, having a single fused op is useful for Vulkan since it allows more control over how the cache tensors are stored and represented. Essentially, it makes it easier to manage the cache tensors and opens up opportunities for future optimizations. This diff introduces a fusion pass that does 2 things: 1. Combine update_cache and custom_sdpa back into sdpa_with_kv_cache 2. Ensure all references to the cache_pos symint use the same node - this prevents the select_at_dim_as_symint op from being called every time it is used. ghstack-source-id: 321258710 @exported-using-ghexport Differential Revision: [D86340339](https://our.internmc.facebook.com/intern/diff/D86340339/)
1 parent e938fea commit 97b86a0

File tree

4 files changed

+170
-0
lines changed

4 files changed

+170
-0
lines changed

backends/vulkan/patterns/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ runtime.python_library(
1212
"quantized_linear.py",
1313
"quantized_convolution.py",
1414
"quantized_binary.py",
15+
"sdpa.py",
1516
"select_as_symint.py",
1617
],
1718
visibility = [

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import executorch.backends.vulkan.patterns.rope # noqa
1616

17+
import executorch.backends.vulkan.patterns.sdpa # noqa
18+
1719
import executorch.backends.vulkan.patterns.select_as_symint # noqa
1820

1921
import torch

backends/vulkan/patterns/sdpa.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Optional
8+
9+
import executorch.backends.vulkan.utils as utils
10+
11+
import torch
12+
13+
from executorch.backends.vulkan.patterns.pattern_registry import (
14+
PatternMatch,
15+
register_pattern_detector,
16+
register_pattern_replacement,
17+
)
18+
19+
from executorch.exir import ExportedProgram
20+
21+
22+
def is_update_cache_node(node: Any) -> bool:
    """Return True when ``node`` is a call to the ``llama::update_cache`` op."""
    op_name = "llama::update_cache"
    return utils.node_has_target(node, op_name)
24+
25+
26+
def is_custom_sdpa_node(node: Any) -> bool:
    """Return True when ``node`` is a call to the ``llama::custom_sdpa`` op."""
    op_name = "llama::custom_sdpa"
    return utils.node_has_target(node, op_name)
28+
29+
30+
def is_sdpa_with_kv_cache_node(node: Any) -> bool:
    """Return True when ``node`` is a call to the ``llama::sdpa_with_kv_cache`` op."""
    op_name = "llama::sdpa_with_kv_cache"
    return utils.node_has_target(node, op_name)
32+
33+
34+
class CausalSDPAMatch(PatternMatch):
    """One occurrence of the causal-SDPA pattern.

    The anchor is a ``llama::custom_sdpa`` call node; the match also records
    the ``llama::update_cache`` calls that write the key/value caches the SDPA
    reads, plus the key/value projection nodes feeding those updates, so the
    replacement pass can fuse all three ops into ``llama::sdpa_with_kv_cache``.

    ``match_found`` is False when the anchor's argument list is too short to
    be a well-formed custom_sdpa call.
    """

    def __init__(self, custom_sdpa_node: torch.fx.Node) -> None:
        self.anchor_node = custom_sdpa_node
        self.match_found = False
        self.all_nodes = [self.anchor_node]

        # llama.custom_sdpa has signature:
        # custom_sdpa(query, key_cache, value_cache, start_pos, attn_mask,
        #             dropout_p, is_causal, scale) -> output
        # Indices 0..6 are read unconditionally below, so require at least 7
        # positional args. (A `< 4` guard would raise IndexError on
        # args[4:7] for shorter calls instead of reporting a non-match.)
        if len(custom_sdpa_node.args) < 7:
            return

        self.query_node = custom_sdpa_node.args[0]
        self.key_cache_node = custom_sdpa_node.args[1]
        self.value_cache_node = custom_sdpa_node.args[2]
        self.start_pos_node = custom_sdpa_node.args[3]
        self.attn_mask_node = custom_sdpa_node.args[4]
        self.dropout_p_node = custom_sdpa_node.args[5]
        self.is_causal_node = custom_sdpa_node.args[6]
        # scale is optional; treat a missing 8th arg as "use the default".
        self.scale_node = (
            custom_sdpa_node.args[7] if len(custom_sdpa_node.args) > 7 else None
        )

        # Find the update_cache node that writes the key cache this SDPA reads.
        self.update_key_cache_node = None
        for user in self.key_cache_node.users:
            if is_update_cache_node(user):
                self.update_key_cache_node = user
                break

        # update_cache's first argument is the freshly projected tensor.
        self.key_projection_node = None
        if self.update_key_cache_node is not None:
            self.key_projection_node = self.update_key_cache_node.args[0]

        # Likewise for the value cache.
        self.update_value_cache_node = None
        for user in self.value_cache_node.users:
            if is_update_cache_node(user):
                self.update_value_cache_node = user
                break

        self.value_projection_node = None
        if self.update_value_cache_node is not None:
            self.value_projection_node = self.update_value_cache_node.args[0]

        # We have additional optional arguments but we don't need to capture them
        # since the new op doesn't use them.
        # NOTE(review): match_found is set even when the update_cache nodes were
        # not located; the replacement pass asserts on them instead of skipping.
        # Confirm a hard failure there is the intended behavior.
        self.match_found = True
83+
84+
85+
@register_pattern_detector("causal_sdpa")
def find_causal_sdpa_patterns(
    node: torch.fx.Node,
) -> Optional[CausalSDPAMatch]:
    """Return a CausalSDPAMatch anchored at ``node``, or None.

    A match is produced only when ``node`` is a ``llama::custom_sdpa`` call
    whose arguments could be captured successfully.
    """
    if not is_custom_sdpa_node(node):
        return None

    candidate = CausalSDPAMatch(node)
    return candidate if candidate.match_found else None
97+
98+
99+
##
100+
## Pattern Replacement
101+
##
102+
103+
104+
def find_singleton_start_pos_node(graph_module: torch.fx.GraphModule):
    """Return the canonical start_pos node shared by all cache-update ops.

    Scans the graph in order and returns the start_pos argument of the first
    ``llama::update_cache`` (arg index 2) or ``llama::sdpa_with_kv_cache``
    (arg index 5) call found. Reusing one node for every fused SDPA prevents
    the symint-producing op from being re-emitted at each use site.

    Raises:
        RuntimeError: if neither op is present in the graph.
    """
    for node in graph_module.graph.nodes:
        if is_update_cache_node(node):
            # update_cache(projected, cache, start_pos, ...): start_pos is arg 2
            return node.args[2]

        if is_sdpa_with_kv_cache_node(node):
            # sdpa_with_kv_cache(q, k, v, k_cache, v_cache, start_pos, ...):
            # start_pos is arg 5
            return node.args[5]

    # Raise RuntimeError rather than bare Exception so callers can catch
    # this failure narrowly; RuntimeError remains an Exception subclass.
    raise RuntimeError(
        "Could not find an instance of llama::update_cache or sdpa_with_kv_cache"
    )
115+
116+
117+
@register_pattern_replacement("causal_sdpa")
def replace_custom_sdpa_with_causal_sdpa(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: CausalSDPAMatch,
):
    """Fuse a matched (update_cache x2 + custom_sdpa) triple into a single
    llama::sdpa_with_kv_cache node.

    The fused node receives the raw key/value projections plus the cache
    tensors, letting the runtime op perform both the cache writes and the
    attention. All fused nodes share one start_pos node (see
    find_singleton_start_pos_node) so the symint op is emitted only once.
    """
    # The match must have located both cache-update nodes and their projection
    # inputs; otherwise there is nothing to fuse with.
    assert match.update_key_cache_node is not None
    assert match.key_projection_node is not None
    assert match.update_value_cache_node is not None
    assert match.value_projection_node is not None

    singleton_start_pos_node = find_singleton_start_pos_node(graph_module)

    with graph_module.graph.inserting_before(match.anchor_node):
        new_node = graph_module.graph.create_node(
            "call_function",
            torch.ops.llama.sdpa_with_kv_cache.default,
            # NOTE(review): arg order assumed to follow the
            # llama::sdpa_with_kv_cache schema (query, key, value, key_cache,
            # value_cache, start_pos, seq_len, attn_mask, dropout_p, is_causal,
            # scale) — confirm against the op registration.
            args=(
                match.query_node,
                match.key_projection_node,
                match.value_projection_node,
                match.key_cache_node,
                match.value_cache_node,
                singleton_start_pos_node,
                1,  # presumably seq_len of 1 (single decode step) — TODO confirm
                match.attn_mask_node,
                match.dropout_p_node,
                match.is_causal_node,
                match.scale_node,
            ),
        )

    # The fused node produces the same value as the original custom_sdpa node,
    # so its meta can be copied over and all uses redirected to it.
    new_node.meta["val"] = match.anchor_node.meta["val"]
    match.anchor_node.replace_all_uses_with(new_node)

    # Manually erase update_cache nodes since DCE will not remove them since they
    # modify inputs (specifically, the cache args are modified)
    graph_module.graph.erase_node(match.update_key_cache_node)
    graph_module.graph.erase_node(match.update_value_cache_node)

backends/vulkan/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,18 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
373373
return None
374374

375375

376+
def node_has_target(node: Any, target: str):
    """Return whether ``node``'s target matches ``target``.

    Handles both string targets (compared directly) and op-overload targets
    that expose a ``name()`` method. Objects with no ``target`` attribute,
    or whose target is neither form, never match.
    """
    if not hasattr(node, "target"):
        return False

    node_target = node.target
    if isinstance(node_target, str):
        return node_target == target

    if hasattr(node_target, "name"):
        return node_target.name() == target

    return False
386+
387+
376388
##
377389
## Memory Layout, Storage Type Determination
378390
##

0 commit comments

Comments
 (0)