Generate: handle cache_position update in generate #29467

Merged (6 commits) on Mar 14, 2024
38 changes: 24 additions & 14 deletions src/transformers/cache_utils.py
@@ -4,6 +4,10 @@
import torch

from .configuration_utils import PretrainedConfig
from .utils import logging


logger = logging.get_logger(__name__)


@dataclass
@@ -57,6 +61,17 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
return max_length - new_seq_length
return previous_seq_length

@property
def seen_tokens(self):
logger.warning_once(
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
"model input instead."
)
if hasattr(self, "_seen_tokens"):
return self._seen_tokens
else:
return None
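As a side note on the deprecation above, here is a minimal sketch of the migration it implies for callers (assuming a recent `transformers` that exports `DynamicCache` at the top level): reading `seen_tokens` still works but warns, while `get_seq_length()` (or the `cache_position` model input inside `generate`) is the supported way to get the same information.

```python
from transformers import DynamicCache

cache = DynamicCache()

# Deprecated path: triggers the warning_once above and reads the private `_seen_tokens`.
print(cache.seen_tokens)       # 0

# Supported path: query the cache directly (or rely on `cache_position` inside generate).
print(cache.get_seq_length())  # 0
```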


class DynamicCache(Cache):
"""
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
def __init__(self) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
@@ -121,7 +136,7 @@ def update(
"""
# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
@@ -191,7 +206,7 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_cache = {}
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

@staticmethod
def _rotate_half(x):
@@ -272,7 +287,7 @@ def update(

# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
self._seen_tokens += key_states.shape[-2]

# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
@@ -398,16 +413,11 @@ def update(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)

def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()
Comment on lines +418 to +420
Collaborator:
alright, we are deprecating this anyways
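To make the behaviour under discussion concrete, here is a small self-contained sketch of the occupancy check introduced above, using a toy zero-initialized cache tensor (the shapes and names are illustrative, not the actual `StaticCache` internals):

```python
import torch

# Toy pre-allocated key cache: (batch, num_heads, max_cache_len, head_dim), zero-initialized.
key_cache = torch.zeros(2, 4, 8, 16)

# Pretend the first 3 positions have already been written.
key_cache[:, :, :3, :] = torch.randn(2, 4, 3, 16)

# Same idea as the diff: check only batch member 0 / head 0, and count the positions
# where any entry along head_dim is non-zero.
print(key_cache[0, 0].any(dim=-1).sum())  # tensor(3)

# Caveat from the TODO: a position whose cached key happens to be all zeros would be
# missed, hence the plan to switch to a stateless integer once pytorch/pytorch#120248 lands.
```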


def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
79 changes: 63 additions & 16 deletions src/transformers/generation/utils.py
Collaborator:
We should also set the dtype of the cache positions to int32 wdyt?

Member Author:
Our integers inputs (input_ids, attention_mask, ...) are all int64, I think we should keep a consistent type :p
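For reference, a quick sketch of the dtype argument being made here: `torch.arange` over Python integer bounds already returns `torch.int64`, so the new `cache_position` matches the dtype of the other integer model inputs without any explicit cast.

```python
import torch

input_ids = torch.tensor([[101, 2009, 2003, 102]])  # int64 by default
cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)

print(input_ids.dtype, cache_position.dtype)  # torch.int64 torch.int64
```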

@@ -633,7 +633,6 @@ def _update_model_kwargs_for_generation(
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -663,7 +662,8 @@ def _update_model_kwargs_for_generation(
dim=-1,
)

model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
Collaborator:
my single worry here is potential stride, adding a .contiguous() might be needed

Member Author:
I've double-checked, it's always (1,) 🤗 (which makes sense, since it's a 1D tensor)

Its shape will indeed be different, at least between prefill and subsequent generation
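A tiny sketch of the shape and stride behaviour described in this thread, assuming a 5-token prompt: after prefill, `cache_position` spans the whole prompt, and the per-step update `cache_position[-1:] + 1` produces a contiguous one-element tensor, so no extra `.contiguous()` is needed.

```python
import torch

# Prefill: one position per prompt token.
cache_position = torch.arange(5)          # tensor([0, 1, 2, 3, 4])

# Decoding step, as in _update_model_kwargs_for_generation.
cache_position = cache_position[-1:] + 1  # tensor([5])

print(cache_position.shape, cache_position.stride(), cache_position.is_contiguous())
# torch.Size([1]) (1,) True
```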


return model_kwargs

@@ -1931,10 +1931,15 @@ def _contrastive_search(
)

# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
batch_size = input_ids.shape[0]

while True:
if synced_gpus:
@@ -1975,7 +1980,6 @@ def _contrastive_search(
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
model_inputs=model_inputs,
)
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2170,7 +2174,9 @@ def _contrastive_search(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2389,7 +2395,13 @@ def _greedy_search(
)

# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
while True:
@@ -2459,7 +2471,6 @@ def _greedy_search(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
model_inputs=model_inputs,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -2688,7 +2699,13 @@ def _sample(
)

# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

this_peer_finished = False # used by synced_gpus only
# auto-regressive generation
@@ -2758,7 +2775,9 @@ def _sample(
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
@@ -3003,6 +3022,7 @@ def _beam_search(
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
raise ValueError(
@@ -3156,7 +3176,9 @@ def _beam_search(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3397,6 +3419,7 @@ def _beam_sample(
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
@@ -3510,7 +3533,9 @@ def _beam_sample(
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3747,6 +3772,7 @@ def _group_beam_search(
device = input_ids.device

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@@ -3916,7 +3942,9 @@ def _group_beam_search(
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4155,6 +4183,7 @@ def _constrained_beam_search(
num_beams = constrained_beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
raise ValueError(
@@ -4275,7 +4304,9 @@ def _constrained_beam_search(

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4511,7 +4542,13 @@ def _assisted_decoding(
)

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

# other auxiliary variables
max_len = stopping_criteria[0].max_length
@@ -4555,6 +4592,14 @@ def _assisted_decoding(
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
)
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
if "cache_position" in candidate_kwargs:
candidate_kwargs["cache_position"] = torch.cat(
(
candidate_kwargs["cache_position"],
torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
),
dim=0,
)

model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)

@@ -4673,7 +4718,9 @@ def _assisted_decoding(
)

model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# if eos_token was found in one sentence, set sentence to finished
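Taken together, the changes in this PR touch `cache_position` in three places: it is initialized as an `arange` over the prompt length in each decoding loop, advanced by one slot per step in `_update_model_kwargs_for_generation`, and extended by the candidate length in assisted decoding. The helper below is a hedged, standalone sketch of that flow; it is illustrative only and not part of `GenerationMixin`.

```python
import torch

def simulate_cache_position(prompt_len: int, new_tokens: int = 0, candidate_length: int = 0) -> torch.Tensor:
    """Illustrative walk-through of the cache_position updates added in this PR."""
    # Prefill (e.g. in _greedy_search / _sample): one slot per prompt token.
    cache_position = torch.arange(prompt_len)

    # Regular decoding: per-step update from _update_model_kwargs_for_generation.
    for _ in range(new_tokens):
        cache_position = cache_position[-1:] + 1

    # Assisted decoding: append positions for the drafted candidate tokens.
    if candidate_length:
        cur_len = int(cache_position[-1]) + 1
        cache_position = torch.cat(
            (cache_position, torch.arange(cur_len, cur_len + candidate_length, dtype=torch.long)),
            dim=0,
        )
    return cache_position

print(simulate_cache_position(prompt_len=4, new_tokens=2))        # tensor([5])
print(simulate_cache_position(prompt_len=4, candidate_length=3))  # tensor([0, 1, 2, 3, 4, 5, 6])
```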