Generate: handle cache_position update in generate (#29467)
gante committed Mar 14, 2024
1 parent 7b87ecb commit 23db187
Showing 5 changed files with 155 additions and 78 deletions.
38 changes: 24 additions & 14 deletions src/transformers/cache_utils.py
@@ -4,6 +4,10 @@
import torch

from .configuration_utils import PretrainedConfig
from .utils import logging


logger = logging.get_logger(__name__)


@dataclass
@@ -57,6 +61,17 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
return max_length - new_seq_length
return previous_seq_length

@property
def seen_tokens(self):
logger.warning_once(
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
"model input instead."
)
if hasattr(self, "_seen_tokens"):
return self._seen_tokens
else:
return None
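For readers tracking the deprecation: the base `Cache` class now exposes the old counter through a warning property, while the concrete caches keep it under `_seen_tokens`. A minimal sketch of the resulting behavior (the dummy tensor shapes are illustrative, not taken from the diff):

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()

# Simulate one prefill step with a dummy key/value of shape [batch, num_heads, seq_len, head_dim].
key = value = torch.zeros(1, 2, 5, 8)
cache.update(key, value, layer_idx=0)

print(cache.get_seq_length())  # 5 -- sequence length tracked per layer
print(cache._seen_tokens)      # 5 -- new private counter, bumped only on layer 0

# The old public attribute still works, but emits a one-time deprecation warning and is
# slated for removal in v4.41; callers should switch to the `cache_position` model input.
print(cache.seen_tokens)       # 5, with a warning
```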


class DynamicCache(Cache):
"""
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
def __init__(self) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
@@ -121,7 +136,7 @@ def update(
"""
# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
@@ -191,7 +206,7 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_cache = {}
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

@staticmethod
def _rotate_half(x):
@@ -272,7 +287,7 @@ def update(

# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
self._seen_tokens += key_states.shape[-2]

# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
@@ -398,16 +413,11 @@ def update(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)

def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
79 changes: 63 additions & 16 deletions src/transformers/generation/utils.py
@@ -633,7 +633,6 @@ def _update_model_kwargs_for_generation(
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -663,7 +662,8 @@ def _update_model_kwargs_for_generation(
dim=-1,
)

model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

return model_kwargs
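Instead of copying `cache_position` back out of `model_inputs`, the kwargs update now keeps only the last position and advances it by one per generated token. A toy illustration of the indexing (values are illustrative):

```python
import torch

cache_position = torch.arange(5)          # prefill over a 5-token prompt: [0, 1, 2, 3, 4]
cache_position = cache_position[-1:] + 1  # after the first new token:     [5]
cache_position = cache_position[-1:] + 1  # after the second new token:    [6]
print(cache_position)                     # tensor([6])
```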

@@ -1931,10 +1931,15 @@ def _contrastive_search(
)

# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
batch_size = input_ids.shape[0]

while True:
if synced_gpus:
@@ -1975,7 +1980,6 @@ def _contrastive_search(
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
model_inputs=model_inputs,
)
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2170,7 +2174,9 @@ def _contrastive_search(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
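The same prologue recurs in the greedy, sampling, and assisted-decoding loops below: `batch_size` and the current length come from the attention mask when one is supplied, otherwise from `input_ids`, and `cache_position` is seeded with the prompt positions. A self-contained sketch with illustrative values:

```python
import torch

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # [batch=1, cur_len=4], dummy token ids
attention_mask = torch.ones_like(input_ids)

# Prefer the attention mask's shape; fall back to input_ids when no mask is given.
batch_size, cur_len = (
    attention_mask.shape if attention_mask is not None else input_ids.shape
)

unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
cache_position = torch.arange(cur_len, device=input_ids.device)
print(cache_position)  # tensor([0, 1, 2, 3])
```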
@@ -2389,7 +2395,13 @@ def _greedy_search(
)

# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
while True:
@@ -2459,7 +2471,6 @@ def _greedy_search(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
model_inputs=model_inputs,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2688,7 +2699,13 @@ def _sample(
)

# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
# auto-regressive generation
@@ -2758,7 +2775,9 @@ def _sample(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -3003,6 +3022,7 @@ def _beam_search(
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
raise ValueError(
@@ -3156,7 +3176,9 @@ def _beam_search(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3397,6 +3419,7 @@ def _beam_sample(
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

[Review comment] minostauros (Apr 24, 2024): This always overwrites model_kwargs["cache_position"]. Why don't we first check whether the user has specified an initial "cache_position"?


# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
@@ -3510,7 +3533,9 @@ def _beam_sample(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3747,6 +3772,7 @@ def _group_beam_search(
device = input_ids.device

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
Expand Down Expand Up @@ -3916,7 +3942,9 @@ def _group_beam_search(
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4155,6 +4183,7 @@ def _constrained_beam_search(
num_beams = constrained_beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
raise ValueError(
@@ -4275,7 +4304,9 @@ def _constrained_beam_search(

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4511,7 +4542,13 @@ def _assisted_decoding(
)

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

# other auxiliary variables
max_len = stopping_criteria[0].max_length
@@ -4555,6 +4592,14 @@ def _assisted_decoding(
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
)
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
if "cache_position" in candidate_kwargs:
candidate_kwargs["cache_position"] = torch.cat(
(
candidate_kwargs["cache_position"],
torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
),
dim=0,
)

model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
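In assisted decoding the main model scores `candidate_length` draft tokens in a single forward pass, so the candidate kwargs extend `cache_position` with those extra positions before `prepare_inputs_for_generation` is called. The arithmetic, with illustrative values (the starting `cache_position` below is a stand-in for whatever the loop currently holds):

```python
import torch

cur_len = 6           # length already covered by the running cache_position
candidate_length = 3  # number of draft tokens proposed by the assistant model

cache_position = torch.tensor([5])  # illustrative current value inside the loop

# Append one position per candidate token so they are all scored in one pass.
cache_position = torch.cat(
    (cache_position, torch.arange(cur_len, cur_len + candidate_length, dtype=torch.long)),
    dim=0,
)
print(cache_position)  # tensor([5, 6, 7, 8])
```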

@@ -4673,7 +4718,9 @@ def _assisted_decoding(
)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
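None of this requires changes in user code: `generate` seeds `cache_position` from the prompt and advances it internally. A hedged end-to-end sketch (the checkpoint name and prompt are placeholders, not part of this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder checkpoint, chosen only for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("The capital of France is", return_tensors="pt")

# cache_position is created internally (an arange over the prompt length) and advanced
# by one after each decoding step, as shown in the hunks above.
output_ids = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```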
