Skip to content

Commit

Permalink
Static Cache: load models with MQA or GQA (#28975)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored Feb 13, 2024
1 parent da20209 commit 3e70a20
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,12 @@ def __init__(
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype

cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.seen_tokens = 0
Expand Down
46 changes: 45 additions & 1 deletion tests/test_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
LlamaConfig,
LlamaForCausalLM,
SinkCache,
StaticCache,
)


@require_torch
class CacheTest(unittest.TestCase):
def test_cache_equivalence(self):
def test_dynamic_cache_retrocompatibility(self):
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
legacy_cache = ()
new_cache = DynamicCache()
Expand Down Expand Up @@ -120,6 +122,48 @@ def test_reorder_cache_retrocompatibility(self):
)
)

def test_static_cache_mha_mqa_gqa(self):
"""
Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
attention (MQA)
"""

def _random_kvs(config):
# shape for key and values: (batch_size, num_heads, seq_len, head_dim)
random_keys = torch.rand(
(1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
device=torch_device,
)
random_values = torch.rand(
(1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
device=torch_device,
)
return random_keys, random_values

mha_config = LlamaConfig(num_attention_heads=32)
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mha_static_cache.update(
*_random_kvs(mha_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
)
self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
self.assertTrue(cached_values.shape == (1, 32, 10, 128))

gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = gqa_static_cache.update(
*_random_kvs(gqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
)
self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
self.assertTrue(cached_values.shape == (1, 4, 10, 128))

mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
cached_keys, cached_values = mqa_static_cache.update(
*_random_kvs(mqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
)
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
self.assertTrue(cached_values.shape == (1, 1, 10, 128))


@require_torch_gpu
@slow
Expand Down

0 comments on commit 3e70a20

Please sign in to comment.