|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from typing import Any, Dict, Optional, Tuple, Union |
| 8 | + |
| 9 | +import torch |
| 10 | + |
| 11 | + |
| 12 | +try: |
| 13 | + from transformers.cache_utils import StaticCache |
| 14 | +except ImportError: |
| 15 | + # If transformers is not installed, raise an ImportError |
| 16 | + try: |
| 17 | + from transformers.cache_utils import StaticCache |
| 18 | + except ImportError: |
| 19 | + raise ImportError("transformers is not installed. Please install it to use StaticCache.") |
| 20 | + |
| 21 | + |
| 22 | +class ETCustomStaticCache(StaticCache): |
| 23 | + """ |
| 24 | + Custom KV Cache implementation for ExecutorTorch that inherits from Hugging Face's StaticCache |
| 25 | + but uses custom operations for cache updates similar to ExecutorTorch's CustomStaticCache. |
| 26 | + """ |
| 27 | + |
| 28 | + def __init__( |
| 29 | + self, |
| 30 | + config, |
| 31 | + max_batch_size: int, |
| 32 | + max_cache_len: Optional[int] = None, |
| 33 | + device: Union[torch.device, str, None] = None, |
| 34 | + dtype: torch.dtype = torch.float32, |
| 35 | + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, |
| 36 | + ): |
| 37 | + super().__init__( |
| 38 | + config=config, |
| 39 | + max_batch_size=max_batch_size, |
| 40 | + max_cache_len=max_cache_len, |
| 41 | + device=device, |
| 42 | + dtype=dtype, |
| 43 | + layer_device_map=layer_device_map, |
| 44 | + ) |
| 45 | + |
| 46 | + # make sure layer_device_map is none |
| 47 | + assert layer_device_map is None |
| 48 | + |
| 49 | + # Clear existing caches |
| 50 | + self.key_cache = [] |
| 51 | + self.value_cache = [] |
| 52 | + |
| 53 | + # Initialize cache buffers with our custom shape |
| 54 | + cache_shape = ( |
| 55 | + self.max_batch_size, |
| 56 | + self.max_cache_len, |
| 57 | + self.num_key_value_heads, |
| 58 | + self.head_dim, |
| 59 | + ) |
| 60 | + assert device is None or device == "cpu", "Device must be None or 'cpu'" |
| 61 | + |
| 62 | + for _ in range(config.num_hidden_layers): |
| 63 | + self.new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device="cpu") |
| 64 | + self.new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device="cpu") |
| 65 | + |
| 66 | + self.key_cache.append(self.new_layer_key_cache) |
| 67 | + self.value_cache.append(self.new_layer_value_cache) |
| 68 | + |
| 69 | + def update( |
| 70 | + self, |
| 71 | + key_states: torch.Tensor, |
| 72 | + value_states: torch.Tensor, |
| 73 | + layer_idx: int, |
| 74 | + cache_kwargs: Optional[Dict[str, Any]] = None, |
| 75 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 76 | + """ |
| 77 | + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` |
| 78 | + using custom operations. |
| 79 | +
|
| 80 | + Args: |
| 81 | + key_states (`torch.Tensor`): |
| 82 | + The new key states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] |
| 83 | + value_states (`torch.Tensor`): |
| 84 | + The new value states to cache. Shape: [batch_size, n_heads, seq_len, head_dim] |
| 85 | + layer_idx (`int`): |
| 86 | + The index of the layer to cache the states for. |
| 87 | + cache_kwargs (`Dict[str, Any]`, `optional`): |
| 88 | + Additional arguments for the cache update. |
| 89 | +
|
| 90 | + Returns: |
| 91 | + A tuple containing the updated key and value states. |
| 92 | + """ |
| 93 | + assert cache_kwargs is not None |
| 94 | + |
| 95 | + # Get cache position from cache_kwargs (used by StaticCache) |
| 96 | + cache_position = cache_kwargs.get("cache_position") |
| 97 | + assert cache_position is not None |
| 98 | + |
| 99 | + # Get the current cache for this layer |
| 100 | + k_out = self.key_cache[layer_idx] |
| 101 | + v_out = self.value_cache[layer_idx] |
| 102 | + |
| 103 | + # Transpose key and value states to match our cache shape |
| 104 | + # From [batch_size, n_heads, seq_len, head_dim] to [batch_size, seq_len, n_heads, head_dim] |
| 105 | + k_val = key_states.transpose(1, 2) |
| 106 | + v_val = value_states.transpose(1, 2) |
| 107 | + |
| 108 | + # Use custom operations to update the cache |
| 109 | + # Update cache with indices for more complex update patterns |
| 110 | + assert isinstance(cache_position, torch.Tensor) |
| 111 | + start_pos = cache_position[0].item() |
| 112 | + _ = torch.ops.llama.update_cache(k_val, k_out, start_pos) |
| 113 | + _ = torch.ops.llama.update_cache(v_val, v_out, start_pos) |
| 114 | + |
| 115 | + # Return the updated cache in the format expected by the model |
| 116 | + # Transpose back from [batch_size, seq_len, n_heads, head_dim] to [batch_size, n_heads, seq_len, head_dim] |
| 117 | + return k_out.transpose(1, 2), v_out.transpose(1, 2) |
| 118 | + |
| 119 | + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
| 120 | + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
| 121 | + # Occupied cache == any slot in the 2nd dim (sequence length) holds a non-zero value |
| 122 | + # This is different from StaticCache which checks the 3rd dim |
| 123 | + return (self.key_cache[layer_idx][0, :, 0].any(dim=-1)).sum() |
| 124 | + |
| 125 | + @classmethod |
| 126 | + def from_legacy_cache( |
| 127 | + cls, |
| 128 | + config, |
| 129 | + legacy_cache, |
| 130 | + max_cache_len=None, |
| 131 | + device=None, |
| 132 | + dtype=None, |
| 133 | + ): |
| 134 | + """ |
| 135 | + Create an ETCustomStaticCache from a legacy cache implementation. |
| 136 | +
|
| 137 | + Args: |
| 138 | + config: The model configuration |
| 139 | + legacy_cache: The legacy cache implementation |
| 140 | + max_cache_len: The maximum cache length |
| 141 | + device: The device for the new cache |
| 142 | + dtype: The data type for the new cache |
| 143 | +
|
| 144 | + Returns: |
| 145 | + A new ETCustomStaticCache instance |
| 146 | + """ |
| 147 | + assert hasattr(legacy_cache, "k_cache") and hasattr(legacy_cache, "v_cache") |
| 148 | + # Extract dimensions from the legacy cache |
| 149 | + assert len(legacy_cache.k_cache.shape) == 4 |
| 150 | + if legacy_cache.k_cache.shape[1] == legacy_cache.n_heads: |
| 151 | + # Shape is [batch_size, n_heads, seq_len, head_dim] |
| 152 | + max_batch_size = legacy_cache.k_cache.shape[0] |
| 153 | + else: |
| 154 | + # Shape is [batch_size, seq_len, n_heads, head_dim] |
| 155 | + max_batch_size = legacy_cache.k_cache.shape[0] |
| 156 | + |
| 157 | + # Use the legacy cache's device and dtype if not specified |
| 158 | + if device is None and hasattr(legacy_cache, "device"): |
| 159 | + device = legacy_cache.device |
| 160 | + elif device is None and hasattr(legacy_cache.k_cache, "device"): |
| 161 | + device = legacy_cache.k_cache.device |
| 162 | + |
| 163 | + if dtype is None and hasattr(legacy_cache, "dtype"): |
| 164 | + dtype = legacy_cache.dtype |
| 165 | + elif dtype is None and hasattr(legacy_cache.k_cache, "dtype"): |
| 166 | + dtype = legacy_cache.k_cache.dtype |
| 167 | + |
| 168 | + assert device is None or device == "cpu" |
| 169 | + assert dtype is None or dtype == torch.float32 |
| 170 | + |
| 171 | + # Use the legacy cache's max_seq_len if max_cache_len is not specified |
| 172 | + if max_cache_len is None and hasattr(legacy_cache, "max_seq_len"): |
| 173 | + max_cache_len = legacy_cache.max_seq_len |
| 174 | + elif max_cache_len is None and hasattr(legacy_cache, "max_cache_len"): |
| 175 | + max_cache_len = legacy_cache.max_cache_len |
| 176 | + |
| 177 | + return cls( |
| 178 | + config=config, |
| 179 | + max_batch_size=max_batch_size, |
| 180 | + max_cache_len=max_cache_len, |
| 181 | + device=device, |
| 182 | + dtype=dtype, |
| 183 | + ) |
| 184 | + |
| 185 | + |
| 186 | +def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype): |
| 187 | + """ |
| 188 | + Replace all KV caches in the module with ETCustomStaticCache. |
| 189 | + This modifies the model in place. |
| 190 | +
|
| 191 | + Args: |
| 192 | + module: The module to modify |
| 193 | + config: The model configuration |
| 194 | +
|
| 195 | + Returns: |
| 196 | + The modified module |
| 197 | + """ |
| 198 | + # Ensure custom ops are registered |
| 199 | + try: |
| 200 | + op = torch.ops.llama.update_cache |
| 201 | + assert op is not None |
| 202 | + except: |
| 203 | + try: |
| 204 | + from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 |
| 205 | + |
| 206 | + op = torch.ops.llama.update_cache |
| 207 | + assert op is not None |
| 208 | + except ImportError: |
| 209 | + raise ImportError( |
| 210 | + "ExecutorTorch custom operations are not available. " |
| 211 | + "Please install executorch with custom operations support." |
| 212 | + ) |
| 213 | + |
| 214 | + # Recursively replace KV caches |
| 215 | + return _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype) |
| 216 | + |
| 217 | + |
| 218 | +def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype): |
| 219 | + """ |
| 220 | + Helper function to recursively replace KV caches in the module. |
| 221 | +
|
| 222 | + Args: |
| 223 | + module: The module to modify |
| 224 | + config: The model configuration |
| 225 | +
|
| 226 | + Returns: |
| 227 | + The modified module |
| 228 | + """ |
| 229 | + assert hasattr(module, "static_cache") |
| 230 | + assert isinstance( |
| 231 | + module.static_cache, StaticCache |
| 232 | + ), "Only StaticCache transform is supported. Hybrid cache with local global attention is not yet supported" |
| 233 | + # TODO: Add replace_cache to exported module |
| 234 | + # in transformer's executorch.py |
| 235 | + if getattr(module, "replace_cache", None) is not None: |
| 236 | + static_cache = ETCustomStaticCache( |
| 237 | + config=config, |
| 238 | + max_batch_size=generation_config.cache_config.batch_size, |
| 239 | + max_cache_len=generation_config.cache_config.max_cache_len, |
| 240 | + device=generation_config.cache_config.device, |
| 241 | + dtype=cache_dtype, |
| 242 | + ) |
| 243 | + module.replace_cache(static_cache) |
| 244 | + else: |
| 245 | + module.static_cache = ETCustomStaticCache( |
| 246 | + config=config, |
| 247 | + max_batch_size=generation_config.cache_config.batch_size, |
| 248 | + max_cache_len=generation_config.cache_config.max_cache_len, |
| 249 | + device=generation_config.cache_config.device, |
| 250 | + dtype=cache_dtype, |
| 251 | + ) |
| 252 | + for i in range(len(module.static_cache.key_cache)): |
| 253 | + setattr(module, f"key_cache_{i}", module.static_cache.key_cache[i]) |
| 254 | + setattr(module, f"value_cache_{i}", module.static_cache.value_cache[i]) |
| 255 | + |
| 256 | + return module |
0 commit comments