Commit d6d33eb

Leverage update_cache op to reduce overhead from cache update
Summary: This is likely a short-lived optimization; in the future we can replace the update_cache op with the index_put_ op. That is what the original StaticCache does, but it requires a cache transpose for custom_sdpa (which can also be fixed). We will leverage the custom cache for now, although in the near future it should not be needed. This option does, however, allow us to bypass any transposes if the need continues.

Test Plan: CI

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent da80c9e commit d6d33eb
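The Summary contrasts two ways of writing new key/value states into a static cache. A rough sketch of the difference (illustrative only, not code from this commit; tensor names and sizes are made up, and the custom-op path assumes ExecuTorch's custom ops are registered):

import torch

batch, n_heads, seq_len, head_dim, max_cache_len = 1, 8, 4, 64, 128
key_states = torch.randn(batch, n_heads, seq_len, head_dim)  # [B, H, S, D], as produced by attention
cache_position = torch.arange(seq_len)

# StaticCache-style update: the cache is laid out [B, H, max_cache_len, D] and
# index_put_-style indexing writes along the sequence dimension. custom_sdpa,
# however, wants a [B, S, H, D] layout, which forces an extra transpose later.
k_out_hf = torch.zeros(batch, n_heads, max_cache_len, head_dim)
k_out_hf[:, :, cache_position] = key_states

# update_cache-style update (this commit): the cache is laid out [B, max_cache_len, H, D],
# already matching custom_sdpa, and the custom op writes in place starting at start_pos.
# Requires: from executorch.extension.llm.custom_ops import custom_ops
k_out_et = torch.zeros(batch, max_cache_len, n_heads, head_dim)
start_pos = cache_position[0].item()
torch.ops.llama.update_cache(key_states.transpose(1, 2), k_out_et, start_pos)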

File tree

6 files changed: +345 -6 lines changed

optimum/commands/export/executorch.py

Lines changed: 13 additions & 1 deletion
@@ -28,7 +28,11 @@
 def parse_args_executorch(parser):
     required_group = parser.add_argument_group("Required arguments")
     required_group.add_argument(
-        "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
+        "-m",
+        "--model",
+        type=str,
+        required=True,
+        help="Model ID on huggingface.co or path on disk to load model from.",
     )
     required_group.add_argument(
         "-o",
@@ -57,6 +61,12 @@ def parse_args_executorch(parser):
         action="store_true",
         help="For decoder-only models to use custom sdpa with static kv cache to boost performance. Defaults to False.",
     )
+    required_group.add_argument(
+        "--use_custom_kv_cache",
+        required=False,
+        action="store_true",
+        help="For decoder-only models to use custom kv cache for static cache that updates cache using custom op. Defaults to False.",
+    )
     required_group.add_argument(
         "--qlinear",
         required=False,
@@ -84,6 +94,8 @@ def run(self):
         kwargs = {}
         if self.args.use_custom_sdpa:
             kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
+        if self.args.use_custom_kv_cache:
+            kwargs["use_custom_kv_cache"] = self.args.use_custom_kv_cache
         if self.args.qlinear:
             kwargs["qlinear"] = self.args.qlinear
         if self.args.qembedding:
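A hedged sketch of the corresponding command-line usage. Only --model, --use_custom_sdpa, and --use_custom_kv_cache appear in this diff; the entry point, the model ID, and the remaining flags are assumptions based on the project's usual export flow, not on this commit:

optimum-cli export executorch \
    --model HuggingFaceTB/SmolLM2-135M \
    --task text-generation \
    --recipe xnnpack \
    --output_dir ./smollm2_et \
    --use_custom_sdpa \
    --use_custom_kv_cache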
optimum/executorch/attentions/custom_kv_cache.py (new file; path inferred from the import in integrations.py below)

Lines changed: 256 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional, Tuple, Union

import torch


try:
    from transformers.cache_utils import StaticCache
except ImportError:
    # If transformers is not installed, raise an ImportError
    raise ImportError("transformers is not installed. Please install it to use StaticCache.")


class ETCustomStaticCache(StaticCache):
    """
    Custom KV cache implementation for ExecuTorch that inherits from Hugging Face's StaticCache
    but uses custom operations for cache updates, similar to ExecuTorch's CustomStaticCache.
    """

    def __init__(
        self,
        config,
        max_batch_size: int,
        max_cache_len: Optional[int] = None,
        device: Union[torch.device, str, None] = None,
        dtype: torch.dtype = torch.float32,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ):
        super().__init__(
            config=config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
            layer_device_map=layer_device_map,
        )

        # Make sure layer_device_map is None
        assert layer_device_map is None

        # Clear existing caches
        self.key_cache = []
        self.value_cache = []

        # Initialize cache buffers with our custom shape
        cache_shape = (
            self.max_batch_size,
            self.max_cache_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        assert device is None or device == "cpu", "Device must be None or 'cpu'"

        for _ in range(config.num_hidden_layers):
            self.new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device="cpu")
            self.new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device="cpu")

            self.key_cache.append(self.new_layer_key_cache)
            self.value_cache.append(self.new_layer_value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
        using custom operations.

        Args:
            key_states (`torch.Tensor`):
                The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
            value_states (`torch.Tensor`):
                The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim]
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache update.

        Returns:
            A tuple containing the updated key and value states.
        """
        assert cache_kwargs is not None

        # Get cache position from cache_kwargs (used by StaticCache)
        cache_position = cache_kwargs.get("cache_position")
        assert cache_position is not None

        # Get the current cache for this layer
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        # Transpose key and value states to match our cache shape
        # From [batch_size, n_heads, seq_len, head_dim] to [batch_size, seq_len, n_heads, head_dim]
        k_val = key_states.transpose(1, 2)
        v_val = value_states.transpose(1, 2)

        # Use custom operations to update the cache
        # Update cache with indices for more complex update patterns
        assert isinstance(cache_position, torch.Tensor)
        start_pos = cache_position[0].item()
        _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
        _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)

        # Return the updated cache in the format expected by the model
        # Transpose back from [batch_size, seq_len, n_heads, head_dim] to [batch_size, n_heads, seq_len, head_dim]
        return k_out.transpose(1, 2), v_out.transpose(1, 2)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Occupied cache == any slot in the 2nd dim (sequence length) holds a non-zero value
        # This is different from StaticCache, which checks the 3rd dim
        return (self.key_cache[layer_idx][0, :, 0].any(dim=-1)).sum()

    @classmethod
    def from_legacy_cache(
        cls,
        config,
        legacy_cache,
        max_cache_len=None,
        device=None,
        dtype=None,
    ):
        """
        Create an ETCustomStaticCache from a legacy cache implementation.

        Args:
            config: The model configuration
            legacy_cache: The legacy cache implementation
            max_cache_len: The maximum cache length
            device: The device for the new cache
            dtype: The data type for the new cache

        Returns:
            A new ETCustomStaticCache instance
        """
        assert hasattr(legacy_cache, "k_cache") and hasattr(legacy_cache, "v_cache")
        # Extract dimensions from the legacy cache
        assert len(legacy_cache.k_cache.shape) == 4
        if legacy_cache.k_cache.shape[1] == legacy_cache.n_heads:
            # Shape is [batch_size, n_heads, seq_len, head_dim]
            max_batch_size = legacy_cache.k_cache.shape[0]
        else:
            # Shape is [batch_size, seq_len, n_heads, head_dim]
            max_batch_size = legacy_cache.k_cache.shape[0]

        # Use the legacy cache's device and dtype if not specified
        if device is None and hasattr(legacy_cache, "device"):
            device = legacy_cache.device
        elif device is None and hasattr(legacy_cache.k_cache, "device"):
            device = legacy_cache.k_cache.device

        if dtype is None and hasattr(legacy_cache, "dtype"):
            dtype = legacy_cache.dtype
        elif dtype is None and hasattr(legacy_cache.k_cache, "dtype"):
            dtype = legacy_cache.k_cache.dtype

        assert device is None or device == "cpu"
        assert dtype is None or dtype == torch.float32

        # Use the legacy cache's max_seq_len if max_cache_len is not specified
        if max_cache_len is None and hasattr(legacy_cache, "max_seq_len"):
            max_cache_len = legacy_cache.max_seq_len
        elif max_cache_len is None and hasattr(legacy_cache, "max_cache_len"):
            max_cache_len = legacy_cache.max_cache_len

        return cls(
            config=config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
        )


def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
    """
    Replace all KV caches in the module with ETCustomStaticCache.
    This modifies the model in place.

    Args:
        module: The module to modify
        config: The model configuration

    Returns:
        The modified module
    """
    # Ensure custom ops are registered
    try:
        op = torch.ops.llama.update_cache
        assert op is not None
    except Exception:
        try:
            from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

            op = torch.ops.llama.update_cache
            assert op is not None
        except ImportError:
            raise ImportError(
                "ExecuTorch custom operations are not available. "
                "Please install executorch with custom operations support."
            )

    # Recursively replace KV caches
    return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype)


def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
    """
    Helper function to recursively replace KV caches in the module.

    Args:
        module: The module to modify
        config: The model configuration

    Returns:
        The modified module
    """
    assert hasattr(module, "static_cache")
    assert isinstance(
        module.static_cache, StaticCache
    ), "Only StaticCache transform is supported. Hybrid cache with local-global attention is not yet supported."
    # TODO: Add replace_cache to the exported module in transformers' executorch.py
    if getattr(module, "replace_cache", None) is not None:
        static_cache = ETCustomStaticCache(
            config=config,
            max_batch_size=generation_config.cache_config.batch_size,
            max_cache_len=generation_config.cache_config.max_cache_len,
            device=generation_config.cache_config.device,
            dtype=cache_dtype,
        )
        module.replace_cache(static_cache)
    else:
        module.static_cache = ETCustomStaticCache(
            config=config,
            max_batch_size=generation_config.cache_config.batch_size,
            max_cache_len=generation_config.cache_config.max_cache_len,
            device=generation_config.cache_config.device,
            dtype=cache_dtype,
        )
        for i in range(len(module.static_cache.key_cache)):
            setattr(module, f"key_cache_{i}", module.static_cache.key_cache[i])
            setattr(module, f"value_cache_{i}", module.static_cache.value_cache[i])

    return module
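A minimal usage sketch for the new cache, assuming a decoder-only config and that ExecuTorch's custom ops are importable (the model ID and tensor sizes here are illustrative, not taken from the commit):

import torch
from transformers import AutoConfig
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401 -- registers torch.ops.llama.update_cache
from optimum.executorch.attentions.custom_kv_cache import ETCustomStaticCache

config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")  # illustrative model ID
cache = ETCustomStaticCache(config, max_batch_size=1, max_cache_len=128, dtype=torch.float32)

# New key/value states arrive in the usual [batch, n_kv_heads, seq_len, head_dim] layout.
head_dim = config.hidden_size // config.num_attention_heads
k = torch.randn(1, config.num_key_value_heads, 4, head_dim)
v = torch.randn(1, config.num_key_value_heads, 4, head_dim)

# update() transposes into the [batch, seq_len, n_kv_heads, head_dim] cache layout,
# writes in place via torch.ops.llama.update_cache, then transposes back on return.
k_out, v_out = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.arange(4)})
print(k_out.shape)  # torch.Size([1, num_key_value_heads, 128, head_dim])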

optimum/exporters/executorch/integrations.py

Lines changed: 37 additions & 4 deletions
@@ -37,10 +37,11 @@ class CausalLMExportableModule(torch.nn.Module):
     This module ensures that the exported model is compatible with ExecuTorch.
     """

-    def __init__(self, model):
+    def __init__(self, model, use_custom_kv_cache=False):
         super().__init__()
         self.model = model
         self.config = model.config
+        self.use_custom_kv_cache = use_custom_kv_cache
         self.metadata = save_config_to_constant_methods(model.config, model.generation_config)

     def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
@@ -55,9 +56,34 @@ def export(self, input_ids=None, cache_position=None) -> Dict[str, ExportedProgram]:
             max_batch_size = 1
             max_cache_len = 4094
             exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len)
+            if self.use_custom_kv_cache:
+                from optimum.executorch.attentions.custom_kv_cache import (
+                    replace_with_et_custom_kv_cache,
+                )
+
+                replace_with_et_custom_kv_cache(
+                    exportable_module.model,
+                    self.model.config,
+                    self.model.generation_config,
+                    self.model.dtype,
+                )

             with torch.no_grad():
                 exported_program = exportable_module.export(example_input_ids, example_cache_position)
+                # Apply the RemoveRedundantTransposes pass to remove any back-to-back
+                # transpose ops that are not needed, e.g. the output of update_cache is
+                # transposed and the input to custom_sdpa is transposed.
+                from executorch.extension.llm.export.export_passes import (
+                    RemoveRedundantTransposes,
+                )
+
+                mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0]
+                exported_program = torch.export.export(
+                    mutated_gm,
+                    args=(example_input_ids, example_cache_position),
+                    kwargs={},
+                )
         else:
             from transformers.integrations.executorch import (
                 convert_and_export_with_cache,
@@ -285,7 +311,10 @@ def _export_encoder(self, encoder_input_ids):
         # Export the encoder
         with torch.no_grad():
             exported_encoder = torch.export.export(
-                wrapped_encoder, (encoder_input_ids,), dynamic_shapes=dynamic_shapes, strict=True
+                wrapped_encoder,
+                (encoder_input_ids,),
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
             )
         return exported_encoder

@@ -354,7 +383,9 @@ def export(
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

         self.exported_decoder = self._export_decoder(
-            example_decoder_input_ids, example_encoder_hidden_states, example_cache_position
+            example_decoder_input_ids,
+            example_encoder_hidden_states,
+            example_cache_position,
         )

         return {
@@ -375,7 +406,9 @@ def generate(self, prompt_token_ids, max_new_tokens):
         for i in range(max_new_tokens - 1):
             # Run decoder for next token prediction
             logits = self.exported_decoder.module()(
-                decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long)
+                decoder_input_ids,
+                encoder_output,
+                torch.tensor([i], dtype=torch.long),
             )

             # Get next token
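To illustrate why the graph is re-exported after the pass above, a hedged toy sketch (the module here is made up; only the RemoveRedundantTransposes call pattern is taken from the diff):

import torch
from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes

class Toy(torch.nn.Module):
    # Hypothetical module with a back-to-back transpose pair, similar to what a
    # transposed update_cache output feeding a transposed custom_sdpa input produces.
    def forward(self, x):
        return x.transpose(1, 2).transpose(1, 2) + 1

ep = torch.export.export(Toy(), (torch.randn(1, 4, 8),))
gm = RemoveRedundantTransposes()(ep.module())[0]   # same call pattern as in export() above
ep_clean = torch.export.export(gm, (torch.randn(1, 4, 8),))  # re-export the mutated graph module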

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 2 additions & 1 deletion
@@ -55,6 +55,7 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportableModule:
     batch_size = 1
     dtype = kwargs.get("dtype", "float32")
     use_custom_sdpa = kwargs.get("use_custom_sdpa", False)
+    use_custom_kv_cache = kwargs.get("use_custom_kv_cache", False)
     attn_implementation = kwargs.get("attn_implementation", "custom_sdpa" if use_custom_sdpa else "sdpa")
     cache_implementation = kwargs.get("cache_implementation", "static")
     max_length = kwargs.get("max_length", 2048)
@@ -120,4 +121,4 @@ def load_causal_lm_model(model_name_or_path: str, **kwargs) -> CausalLMExportableModule:
     unwrap_tensor_subclass(eager_model)

-    return CausalLMExportableModule(eager_model)
+    return CausalLMExportableModule(eager_model, use_custom_kv_cache)
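A hedged sketch of how these kwargs reach the task loader (the model ID is illustrative; load_causal_lm_model and its kwargs are taken from the diff):

from optimum.exporters.executorch.tasks.causal_lm import load_causal_lm_model

# use_custom_sdpa selects the "custom_sdpa" attn_implementation; use_custom_kv_cache is
# forwarded to CausalLMExportableModule so export() swaps in ETCustomStaticCache.
exportable = load_causal_lm_model(
    "HuggingFaceTB/SmolLM2-135M",  # illustrative model ID
    use_custom_sdpa=True,
    use_custom_kv_cache=True,
)
programs = exportable.export()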

setup.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
     "optimum~=1.24",
     "executorch>=0.6.0",
     "transformers==4.51.0",
+    "tiktoken",
 ]

 TESTS_REQUIRE = [
