
deprecation cycle for seen_tokens
gante committed Mar 14, 2024
1 parent e6277fa commit 8c29e49
Showing 4 changed files with 33 additions and 10 deletions.
23 changes: 19 additions & 4 deletions src/transformers/cache_utils.py
@@ -4,6 +4,10 @@
import torch

from .configuration_utils import PretrainedConfig
+ from .utils import logging
+
+
+ logger = logging.get_logger(__name__)


@dataclass
@@ -57,6 +61,17 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
return max_length - new_seq_length
return previous_seq_length

+ @property
+ def seen_tokens(self):
+     logger.warning_once(
+         "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+         "model input instead."
+     )
+     if hasattr(self, "_seen_tokens"):
+         return self._seen_tokens
+     else:
+         return None
+

class DynamicCache(Cache):
"""
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
def __init__(self) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
- self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+ self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
@@ -121,7 +136,7 @@ def update(
"""
# Update the number of seen tokens
if layer_idx == 0:
- self.seen_tokens += key_states.shape[-2]
+ self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
@@ -191,7 +206,7 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_cache = {}
- self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+ self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

@staticmethod
def _rotate_half(x):
@@ -272,7 +287,7 @@ def update(

# Update the number of seen tokens
if layer_idx == 0:
- self.seen_tokens += key_states.shape[-2]
+ self._seen_tokens += key_states.shape[-2]

# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
8 changes: 4 additions & 4 deletions src/transformers/generation/utils.py
@@ -1932,7 +1932,7 @@ def _contrastive_search(

# keep track of which sequences are already finished
batch_size, cur_len = (
model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
@@ -2394,7 +2394,7 @@ def _greedy_search(

# keep track of which sequences are already finished
batch_size, cur_len = (
model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
@@ -2696,7 +2696,7 @@ def _sample(

# keep track of which sequences are already finished
batch_size, cur_len = (
model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
@@ -4537,7 +4537,7 @@ def _assisted_decoding(

# keep track of which sequences are already finished
batch_size, cur_len = (
model_kwargs["attention_mask"].shape if "attention_mask" in model_kwargs else input_ids.shape
model_kwargs["attention_mask"].shape if model_kwargs.get("attention_mask", None) else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
6 changes: 5 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -1145,7 +1145,11 @@ def prepare_inputs_for_generation(
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
- max_cache_length = torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+ max_cache_length = (
+     torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+     if past_key_values.get_max_length() is not None
+     else None
+ )
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
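
The guard added above (and repeated for llama below) matters because `get_max_length()` returns `None` for caches without a fixed window, such as `DynamicCache`, and `torch.tensor(None)` raises; the following `cache_length = ...` line already handles a `None` bound. A minimal sketch of the pattern outside any model code, where `max_length` stands in for `past_key_values.get_max_length()`:

    # Minimal sketch of the guarded pattern; `max_length` may be None.
    import torch

    def as_optional_tensor(max_length, device="cpu"):
        # torch.tensor(None) raises, so only wrap the value when a bound exists.
        return torch.tensor(max_length, device=device) if max_length is not None else None

    print(as_optional_tensor(4096))  # tensor(4096)
    print(as_optional_tensor(None))  # None
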
6 changes: 5 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -1257,7 +1257,11 @@ def prepare_inputs_for_generation(
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
- max_cache_length = torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+ max_cache_length = (
+     torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+     if past_key_values.get_max_length() is not None
+     else None
+ )
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
