[RFC] Adding overrides for max cache seq length #1449

Merged · 26 commits · Sep 16, 2024
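In short: this PR lets callers override the maximum sequence length used to allocate KV caches, independently of the model's max_seq_len (decoder_max_seq_len at the model level, plus encoder_max_seq_len/decoder_max_seq_len at the per-layer level). A minimal usage sketch, mirroring the new generation_model_batched_fixed_cache_seq_len test fixture added below; the llama2 import path is assumed and the builder arguments are the fixture's, not a recommendation:

import torch
from torchtune.models.llama2 import llama2

# Build a model whose architecture supports 2048 positions...
model = llama2(
    vocab_size=4_000,
    embed_dim=128,
    num_layers=2,
    num_heads=4,
    num_kv_heads=4,
    max_seq_len=2048,
)
# ...but allocate decoder KV caches for only 1024 positions (the new override).
model.setup_caches(batch_size=3, dtype=torch.float32, decoder_max_seq_len=1024)
model.eval()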
2 changes: 1 addition & 1 deletion recipes/eleuther_eval.py
@@ -141,7 +141,6 @@ def _model_generate(
# Technically this is not necessary, but it's a good way to ensure that
# the caches won't error on a different batch size. In addition, caches
# are not needed for a regular model call, so we just set them up here
# TODO @joecummings this is being called multiple times resulting in many WARNINGs
if self.enable_kv_cache:
with context.device:
self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype)
@@ -163,6 +162,7 @@ def _model_generate(
top_k=None, # do_sample is not supported currently
stop_tokens=self._tokenizer.stop_tokens,
)
self._model.reset_caches()
return torch.tensor(toks, dtype=torch.int32)


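For context, the recipe change above follows a simple per-call lifecycle: size the caches for the current batch, generate, then reset the cache positions so the next eval batch starts from position 0 rather than running into the cache's maximum sequence length. A rough sketch of that lifecycle, assuming a decoder whose forward accepts tokens plus an input_pos tensor (that forward signature is an assumption for illustration; only setup_caches()/reset_caches() come from the diff above):

import torch

@torch.inference_mode()
def greedy_generate_once(model, prompt, max_new_tokens, dtype=torch.float32):
    # prompt: [batch_size, prompt_len] integer token ids
    bsz, prompt_len = prompt.shape
    # Size the caches for this batch so a different batch size next call
    # does not hit a shape mismatch.
    model.setup_caches(batch_size=bsz, dtype=dtype)
    # Prefill: write the whole prompt into the KV cache in one pass.
    logits = model(prompt, input_pos=torch.arange(prompt_len))
    generated = [logits[:, -1].argmax(dim=-1, keepdim=True)]
    # Decode one token at a time, advancing the absolute cache position.
    for step in range(max_new_tokens - 1):
        pos = torch.tensor([prompt_len + step])
        logits = model(generated[-1], input_pos=pos)
        generated.append(logits[:, -1].argmax(dim=-1, keepdim=True))
    # Reset so the next call (possibly a different batch size) starts clean.
    model.reset_caches()
    return torch.cat([prompt, *generated], dim=-1)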
27 changes: 25 additions & 2 deletions tests/torchtune/generation/test_generation.py
@@ -67,6 +67,21 @@ def generation_model_kv_cache_batched(self):
model.eval()
return model

@pytest.fixture
def generation_model_batched_fixed_cache_seq_len(self, dtype=torch.float32):
model = llama2(
vocab_size=4_000,
embed_dim=128,
num_layers=2,
num_heads=4,
num_kv_heads=4,
max_seq_len=2048,
)
fixed_init_model(model)
model.setup_caches(batch_size=3, dtype=dtype, decoder_max_seq_len=1024)
model.eval()
return model

@pytest.fixture
def prompt_tokens(self):
"""
@@ -256,11 +271,19 @@ def test_reproducibility(self, request, model1, model2, prompt_tokens):

@pytest.mark.parametrize(
"model1",
["generation_model_no_kv_cache", "generation_model_kv_cache_batched"],
[
"generation_model_no_kv_cache",
"generation_model_kv_cache_batched",
"generation_model_batched_fixed_cache_seq_len",
],
)
@pytest.mark.parametrize(
"model2",
["generation_model_no_kv_cache", "generation_model_kv_cache_batched"],
[
"generation_model_no_kv_cache",
"generation_model_kv_cache_batched",
"generation_model_batched_fixed_cache_seq_len",
],
)
@pytest.mark.parametrize(
"prompt1", ["prompt_tokens_batched", "prompt_tokens_batched_left_padded"]
46 changes: 41 additions & 5 deletions tests/torchtune/modules/model_fusion/test_fusion_layer.py
@@ -18,14 +18,34 @@ def random():
set_seed(1)


class DummyLayer(nn.Module):
class DummyCrossAttentionLayer(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = nn.Linear(dim, dim)
self.cache_enabled = False
self.encoder_max_seq_len = None

def setup_cache(self, batch_size, dtype):
def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
self.cache_enabled = True
self.encoder_max_seq_len = encoder_max_seq_len

def reset_cache(self):
self.cache_enabled = False

def forward(self, x):
return self.linear(x)


class DummySelfAttentionLayer(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = nn.Linear(dim, dim)
self.cache_enabled = False
self.decoder_max_seq_len = None

def setup_cache(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_len):
self.cache_enabled = True
self.decoder_max_seq_len = decoder_max_seq_len

def reset_cache(self):
self.cache_enabled = False
@@ -45,13 +65,13 @@ def dim(self) -> int:

@pytest.fixture
def layer(self, dim) -> nn.Module:
layer = DummyLayer(dim)
layer = DummySelfAttentionLayer(dim)
fixed_init_model(layer, min_val=-0.1, max_val=0.1)
return layer

@pytest.fixture
def fusion_layer(self, dim) -> nn.Module:
layer = DummyLayer(dim)
layer = DummyCrossAttentionLayer(dim)
fixed_init_model(layer, min_val=-0.2, max_val=0.2)
return layer

@@ -115,7 +135,23 @@ def test_setup_cache(self, fused_layer):
"""
Test that the cache methods work as expected.
"""
fused_layer.setup_cache(2, torch.float32)
fused_layer.setup_cache(
2, torch.float32, encoder_max_seq_len=10, decoder_max_seq_len=10
)
assert fused_layer.cache_enabled
fused_layer.reset_cache()
assert not fused_layer.cache_enabled

def test_setup_cache_different_cache_seq_len(self, fused_layer):
"""
Test that different encoder and decoder cache sequence lengths are routed to the correct layers.
"""
fused_layer.setup_cache(
2, torch.float32, encoder_max_seq_len=5, decoder_max_seq_len=10
)

assert fused_layer.layer.decoder_max_seq_len == 10
assert fused_layer.fusion_layer.encoder_max_seq_len == 5

assert not hasattr(fused_layer.layer, "encoder_max_seq_len")
assert not hasattr(fused_layer.fusion_layer, "decoder_max_seq_len")
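The two assertions above pin down the routing behavior: the wrapped self-attention layer ends up with the decoder length and the cross-attention fusion layer with the encoder length. A minimal sketch of a fusion layer that satisfies this test (an illustration written against the dummy layers above, not the torchtune FusionLayer implementation):

from torch import nn

class FusionLayerSketch(nn.Module):
    def __init__(self, layer: nn.Module, fusion_layer: nn.Module):
        super().__init__()
        self.layer = layer                # self-attention (decoder) layer
        self.fusion_layer = fusion_layer  # cross-attention layer

    def setup_cache(
        self, batch_size, dtype, *, encoder_max_seq_len, decoder_max_seq_len
    ):
        # Both overrides are forwarded to both sub-layers; each layer keeps
        # only the one it needs, as the dummy layers above demonstrate.
        self.layer.setup_cache(
            batch_size,
            dtype,
            encoder_max_seq_len=encoder_max_seq_len,
            decoder_max_seq_len=decoder_max_seq_len,
        )
        self.fusion_layer.setup_cache(
            batch_size,
            dtype,
            encoder_max_seq_len=encoder_max_seq_len,
            decoder_max_seq_len=decoder_max_seq_len,
        )

    @property
    def cache_enabled(self):
        return self.layer.cache_enabled

    def reset_cache(self):
        self.layer.reset_cache()
        self.fusion_layer.reset_cache()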
2 changes: 1 addition & 1 deletion tests/torchtune/modules/model_fusion/test_fusion_models.py
@@ -28,7 +28,7 @@ def __init__(self, dim, vocab_size):
self.v = nn.Linear(dim, dim)
self.output = nn.Linear(dim, vocab_size)

def setup_caches(self, batch_size, dtype):
def setup_caches(self, batch_size, dtype, *args, **kwargs):
self.cache_enabled = True

def caches_are_enabled(self):
153 changes: 153 additions & 0 deletions tests/torchtune/modules/test_kv_cache.py
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from torchtune.modules import KVCache

BSZ = 2
MAX_SEQ_LEN = 16
NUM_HEADS = 4
HEAD_DIM = 256
DTYPE = torch.float32


class TestKVCache:
@pytest.fixture()
def k_vals_full(self):
return (
torch.tril(torch.ones(MAX_SEQ_LEN, HEAD_DIM))[
None,
None,
:,
:,
]
.repeat(BSZ, NUM_HEADS, 1, 1)
.to(DTYPE)
)

@pytest.fixture()
def v_vals_full(self):
return (
torch.tril(torch.ones(MAX_SEQ_LEN, HEAD_DIM))[
None,
None,
:,
:,
].repeat(BSZ, NUM_HEADS, 1, 1)
* 2
).to(DTYPE)

@pytest.fixture()
def kv_cache(self):
return KVCache(BSZ, MAX_SEQ_LEN, NUM_HEADS, HEAD_DIM, DTYPE)

def test_kv_cache_init(self, kv_cache):
# kv cache should be initialized with zeros
assert (kv_cache.k_cache == 0).all() and (kv_cache.v_cache == 0).all()

def test_kv_cache_reset(self, kv_cache, k_vals_full, v_vals_full):
kv_cache.update(k_vals_full, v_vals_full)
kv_cache.reset()
assert (kv_cache.k_cache == 0).all() and (kv_cache.v_cache == 0).all()
assert kv_cache.size == 0

def test_kv_cache_error_when_bsz_exceeded(self, kv_cache, k_vals_full, v_vals_full):
with pytest.raises(ValueError):
kv_cache.update(k_vals_full.repeat(4, 1, 1, 1), v_vals_full)

def test_kv_cache_error_when_seq_len_exceeded(
self, kv_cache, k_vals_full, v_vals_full
):
with pytest.raises(ValueError):
kv_cache.update(k_vals_full.repeat(1, 1, 4, 1), v_vals_full)

def test_kv_cache_error_when_seq_len_exceeded_after_update(
self, kv_cache, k_vals_full, v_vals_full
):
# test that the cache position is correctly being used to check for seq len exceeded
# make a valid update filling half the cache
kv_cache.update(
k_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
v_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
)
with pytest.raises(
ValueError,
match=f"cache has reached a sequence length of {MAX_SEQ_LEN + MAX_SEQ_LEN // 2}",
):
# now an invalid update exceeding the cache
kv_cache.update(k_vals_full, v_vals_full)

def test_kv_cache_size_update(self, kv_cache, k_vals_full, v_vals_full):
# tests that the kv_cache is correctly tracking the cache position

# make a valid update filling half the cache - like a prefill
kv_cache.update(
k_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
v_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
)
assert kv_cache.size == MAX_SEQ_LEN // 2
# now one update with the next key and value
kv_cache.update(
k_vals_full[:, :, (MAX_SEQ_LEN // 2) + 1].unsqueeze(-2),
v_vals_full[:, :, (MAX_SEQ_LEN // 2) + 1].unsqueeze(-2),
)
assert kv_cache.size == (MAX_SEQ_LEN // 2) + 1

def test_kv_cache_single_update(self, kv_cache, k_vals_full, v_vals_full):
# tests that the kv_cache is correctly returning the updated cache values
# after a single cache update

# make a valid update filling half the cache - like a prefill
k_out, v_out = kv_cache.update(
k_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
v_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
)

expected_k_out = torch.zeros_like(k_vals_full)
expected_v_out = torch.zeros_like(v_vals_full)

expected_k_out[:, :, torch.arange(0, (MAX_SEQ_LEN // 2))] = k_vals_full[
:, :, : (MAX_SEQ_LEN // 2)
]
expected_v_out[:, :, torch.arange(0, (MAX_SEQ_LEN // 2))] = v_vals_full[
:, :, : (MAX_SEQ_LEN // 2)
]

assert torch.equal(expected_k_out, k_out)
assert torch.equal(expected_v_out, v_out)

def test_kv_cache_multiple_updates(self, kv_cache, k_vals_full, v_vals_full):
# tests that the kv_cache is correctly returning the updated cache values
# after a single cache update, followed by another cache update with seq_len=1

# make an update filling half the cache - like a prefill
# fills position 0 through to (MAX_SEQ_LEN // 2) - 1
kv_cache.update(
k_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
v_vals_full[:, :, : (MAX_SEQ_LEN // 2)],
)

# make an update for one more token, which is the value at
# (MAX_SEQ_LEN // 2)
k_out, v_out = kv_cache.update(
k_vals_full[:, :, (MAX_SEQ_LEN // 2)].unsqueeze(2),
v_vals_full[:, :, (MAX_SEQ_LEN // 2)].unsqueeze(2),
)

expected_k_out = torch.zeros_like(k_vals_full)
expected_v_out = torch.zeros_like(v_vals_full)

# cache should be incremented by one position
expected_k_out[:, :, torch.arange(0, ((MAX_SEQ_LEN // 2) + 1))] = k_vals_full[
:, :, : ((MAX_SEQ_LEN // 2) + 1)
]
expected_v_out[:, :, torch.arange(0, ((MAX_SEQ_LEN // 2) + 1))] = v_vals_full[
:, :, : ((MAX_SEQ_LEN // 2) + 1)
]

assert torch.equal(expected_k_out, k_out)
assert torch.equal(expected_v_out, v_out)
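For readers skimming these tests, they pin down a small contract: fixed-size, zero-initialized k/v buffers, a running size that tracks the write position, a reset that clears both, and ValueError checks for batch-size and sequence-length overflow, with update returning the full cache tensors. A minimal cache that satisfies the assertions above; this is an illustrative sketch, not the torchtune KVCache implementation:

import torch
from torch import nn

class SketchKVCache(nn.Module):
    def __init__(self, batch_size, max_seq_len, num_heads, head_dim, dtype):
        super().__init__()
        shape = (batch_size, num_heads, max_seq_len, head_dim)
        # Zero-initialized buffers, matching test_kv_cache_init.
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))
        self.max_seq_len = max_seq_len
        self.size = 0  # number of positions written so far

    def reset(self):
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.size = 0

    def update(self, k_val, v_val):
        bsz, _, seq_len, _ = k_val.shape
        if bsz > self.k_cache.shape[0]:
            raise ValueError(
                f"batch size {bsz} exceeds the cache batch size {self.k_cache.shape[0]}"
            )
        if self.size + seq_len > self.max_seq_len:
            # Message shape mirrors what the tests match against.
            raise ValueError(
                f"cache has reached a sequence length of {self.size + seq_len}"
            )
        # Write the new keys/values at the current position and advance it.
        positions = torch.arange(self.size, self.size + seq_len)
        self.k_cache[:, :, positions] = k_val
        self.v_cache[:, :, positions] = v_val
        self.size += seq_len
        return self.k_cache, self.v_cache

Returning the full cache tensors (rather than only the newly written slice) is what lets attention run over every position written so far, which is exactly what the single- and multiple-update tests assert.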