
Cache: standardize cache interface #29005

Closed
wants to merge 30 commits into from
Changes from all commits (30 commits)
c518ee3
wow I was scared!
ArthurZucker Feb 9, 2024
0b20058
fix everything
ArthurZucker Feb 9, 2024
1c5b2c0
nits
ArthurZucker Feb 9, 2024
1acc62f
make it BC?
ArthurZucker Feb 12, 2024
bd93ac7
nits
ArthurZucker Feb 12, 2024
9ef722f
is_tracing should still be used to pass tracing tests
ArthurZucker Feb 12, 2024
5078b37
nits
ArthurZucker Feb 12, 2024
10cc68f
some nits to make sure generation works with static cache uncompiled
ArthurZucker Feb 12, 2024
f83592e
fix FA2 for both static and dynamic in a better way?
ArthurZucker Feb 14, 2024
87631c8
fix sequential beam search
ArthurZucker Feb 14, 2024
c3f3c0b
style
ArthurZucker Feb 14, 2024
561fa32
use `keys_to_ignore`
ArthurZucker Feb 14, 2024
ed11a75
nit
ArthurZucker Feb 14, 2024
d623190
correct dtype inference when init
ArthurZucker Feb 14, 2024
c51cc75
:( the fix for FA2 is still not optimal to investigate!
ArthurZucker Feb 14, 2024
8d9e9f4
styling
ArthurZucker Feb 14, 2024
4176694
nits
ArthurZucker Feb 14, 2024
1936cf8
nit
ArthurZucker Feb 14, 2024
c476ad3
this might work better
ArthurZucker Feb 14, 2024
cfbcf6a
add comment
ArthurZucker Feb 14, 2024
3a2a785
Update src/transformers/models/llama/modeling_llama.py
ArthurZucker Feb 14, 2024
ed6c60d
tmp commit
gante Feb 13, 2024
3f0f207
wip
gante Feb 13, 2024
93c9e2e
tmp commit
gante Feb 14, 2024
e69eec2
tmp
gante Feb 14, 2024
5b38bf7
merge errors
gante Feb 15, 2024
bc84704
reduce diff
gante Feb 15, 2024
958875e
smaller llama diff
gante Feb 15, 2024
104b208
make fixup
gante Feb 15, 2024
425c6ed
nearly all tests passing
gante Feb 15, 2024
21 changes: 17 additions & 4 deletions docs/source/en/internal/generation_utils.md
@@ -359,21 +359,34 @@ A [`Constraint`] can be used to force the generation to include specific tokens

## Caches

[[autodoc]] Cache
[[autodoc]] ModelCache
- update
- get_seq_length
- get_max_length
- get_usable_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] Cache

[[autodoc]] DynamicCache
- update
- get_seq_length
- get_max_length
- get_usable_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] SinkCache
- update
- get_seq_length
- get_max_length
- get_usable_length
- reorder_cache

[[autodoc]] StaticCache
- update
- get_seq_length
- get_seq_length
- get_max_length
- get_usable_length
- reorder_cache
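For context, a minimal sketch of how the documented cache methods fit together, using `DynamicCache`; the tensor shapes and standalone usage are illustrative assumptions, not part of this diff:

```python
# Illustrative use of the documented cache methods on a DynamicCache.
# Shapes assume [batch, num_heads, seq_len, head_dim]; values are dummies.
import torch
from transformers import DynamicCache

cache = DynamicCache()
key_states = torch.randn(1, 8, 4, 64)
value_states = torch.randn(1, 8, 4, 64)

# `update` appends new keys/values for a layer and returns everything
# cached for that layer so far.
keys, values = cache.update(key_states, value_states, layer_idx=0)

print(cache.get_seq_length(layer_idx=0))          # 4 cached tokens for layer 0
legacy = cache.to_legacy_cache()                  # legacy tuple-of-tuples format
restored = DynamicCache.from_legacy_cache(legacy)
```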
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -1338,7 +1338,7 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "ModelCache", "SinkCache", "StaticCache"]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -6086,7 +6086,7 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from .cache_utils import Cache, DynamicCache, ModelCache, SinkCache, StaticCache
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
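As a quick check of the new top-level export (note that `ModelCache` exists only on this PR branch, so the import is hypothetical until the PR lands):

```python
# Hypothetical import check for the new export added on this branch.
from transformers import Cache, DynamicCache, ModelCache, SinkCache, StaticCache
```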
291 changes: 163 additions & 128 deletions src/transformers/cache_utils.py

Large diffs are not rendered by default.

42 changes: 28 additions & 14 deletions src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache, StaticCache
from ..cache_utils import Cache, DynamicCache, ModelCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -1431,7 +1431,15 @@ def generate(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
self._setup_cache(
cache_cls=cache_cls,
cache_kwargs={
"max_batch_size": batch_size,
"max_cache_len": generation_config.max_length,
"config": self.config,
"device": self.device,
},
)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

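The new call shape passes a `cache_kwargs` dict instead of positional arguments. A rough sketch of the equivalent manual call follows; the checkpoint name is only an example, and `_setup_cache` is a private method whose availability depends on the model:

```python
# Sketch of the keyword-based _setup_cache call introduced in this hunk.
# The checkpoint is an arbitrary example; any model exposing _setup_cache
# and supporting StaticCache would do.
from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model._setup_cache(
    cache_cls=StaticCache,
    cache_kwargs={
        "max_batch_size": 1,
        "max_cache_len": 256,
        "config": model.config,
        "device": model.device,
    },
)
```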
@@ -1493,7 +1501,7 @@ def generate(
)

# 12. run assisted generate
return self.assisted_decoding(
generate_output = self.assisted_decoding(
input_ids,
candidate_generator=candidate_generator,
do_sample=generation_config.do_sample,
@@ -1510,7 +1518,7 @@
)
if generation_mode == GenerationMode.GREEDY_SEARCH:
# 11. run greedy search
return self.greedy_search(
generate_output = self.greedy_search(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
@@ -1527,7 +1535,7 @@
if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`")

return self.contrastive_search(
generate_output = self.contrastive_search(
input_ids,
top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha,
@@ -1556,7 +1564,7 @@
)

# 13. run sample
return self.sample(
generate_output = self.sample(
input_ids,
logits_processor=prepared_logits_processor,
logits_warper=logits_warper,
@@ -1589,7 +1597,7 @@
**model_kwargs,
)
# 13. run beam search
return self.beam_search(
generate_output = self.beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1627,7 +1635,7 @@
)

# 14. run beam sample
return self.beam_sample(
generate_output = self.beam_sample(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1661,7 +1669,7 @@
**model_kwargs,
)
# 13. run beam search
return self.group_beam_search(
generate_output = self.group_beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1734,7 +1742,7 @@ def typeerror():
**model_kwargs,
)
# 13. run beam search
return self.constrained_beam_search(
generate_output = self.constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
logits_processor=prepared_logits_processor,
@@ -1747,6 +1755,12 @@ def typeerror():
**model_kwargs,
)

# Finally, reset the model cache if it has one
if hasattr(self, "_reset_cache"):
self._reset_cache()

return generate_output

@torch.no_grad()
def contrastive_search(
self,
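User-facing, the change funnels every decoding strategy through a single exit point so the static cache can be reset once generation finishes. A hedged end-to-end sketch, assuming the current `generation_config` interface and an example checkpoint:

```python
# End-to-end sketch: generate with the static cache, which the refactor
# resets automatically once generation finishes. Checkpoint is an example.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("The key/value cache is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```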
@@ -2735,17 +2749,17 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx):
# Exception 1: code path for models using the legacy cache format
if isinstance(past_key_values, (tuple, list)):
past_key_values = self._reorder_cache(past_key_values, beam_idx)
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
# Exception 2: models with different cache formats. These are limited to `DynamicCache` caches until their
# cache format is standardized, to avoid adding complexity to the codebase.
elif "bloom" in model_class or "gptbigcode" in model_class:
if not isinstance(past_key_values, DynamicCache):
if not isinstance(past_key_values.caches[0], DynamicCache):
raise ValueError(
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
"legacy tuple format or `DynamicCache`"
)
past_key_values = self._reorder_cache(past_key_values, beam_idx)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# Standard code path: use the `Cache.reorder_cache`
past_key_values = ModelCache.from_legacy_cache(past_key_values)
# Standard code path: use the cache's `.reorder_cache`
else:
past_key_values.reorder_cache(beam_idx)
return past_key_values
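For reference, the standard reorder path the comment refers to. The `ModelCache.caches` attribute comes only from this diff, so the sketch below sticks to the released `DynamicCache` API and simply illustrates `reorder_cache` on beam indices:

```python
# Beam-search reordering sketch on a DynamicCache: the cache reorders its own
# key/value tensors along the batch dimension given the selected beam indices.
import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(2, 8, 4, 64), torch.randn(2, 8, 4, 64), layer_idx=0)

beam_idx = torch.tensor([1, 0])  # swap the two beams
cache.reorder_cache(beam_idx)    # standard path used by _temporary_reorder_cache
```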