Merged
Changes from all commits
Commits
51 commits
100ae2c
Simplify the logic quite a bit
Cyrilvallez Jul 30, 2025
19ecbd8
Update cache_utils.py
Cyrilvallez Jul 30, 2025
7b3d65c
continue work
Cyrilvallez Jul 30, 2025
f385ac7
continue simplifying a lot
Cyrilvallez Jul 30, 2025
d54e338
style
Cyrilvallez Jul 30, 2025
2a7aac7
Update cache_utils.py
Cyrilvallez Jul 30, 2025
ec96c77
offloading much simpler
Cyrilvallez Jul 31, 2025
2081941
style
Cyrilvallez Jul 31, 2025
2592240
Update cache_utils.py
Cyrilvallez Jul 31, 2025
37bd555
update inits
Cyrilvallez Jul 31, 2025
c0c964f
Update cache_utils.py
Cyrilvallez Jul 31, 2025
9fd8803
consistency
Cyrilvallez Jul 31, 2025
2518e75
Update cache_utils.py
Cyrilvallez Jul 31, 2025
17ca71e
update generate
Cyrilvallez Jul 31, 2025
8dade3d
style
Cyrilvallez Jul 31, 2025
a404dba
fix
Cyrilvallez Jul 31, 2025
74ab8c8
fix
Cyrilvallez Jul 31, 2025
78ffd4c
add early_initialization
Cyrilvallez Jul 31, 2025
19fef9d
fix
Cyrilvallez Jul 31, 2025
c0ce446
fix mamba caches
Cyrilvallez Jul 31, 2025
b051526
update
Cyrilvallez Jul 31, 2025
3dc2538
fix
Cyrilvallez Jul 31, 2025
ccda84d
fix
Cyrilvallez Jul 31, 2025
8ee7cc9
fix
Cyrilvallez Jul 31, 2025
11a8f97
fix tests
Cyrilvallez Jul 31, 2025
b41a4b9
fix configs
Cyrilvallez Jul 31, 2025
b57cedf
revert
Cyrilvallez Jul 31, 2025
709e51f
fix tests
Cyrilvallez Aug 1, 2025
11e22b6
alright
Cyrilvallez Aug 1, 2025
f890769
Update modeling_gptj.py
Cyrilvallez Aug 1, 2025
4f9581a
fix the constructors
Cyrilvallez Aug 1, 2025
9c4ce68
cache tests
Cyrilvallez Aug 1, 2025
d990e80
Update test_cache_utils.py
Cyrilvallez Aug 1, 2025
0c1f41a
fix
Cyrilvallez Aug 1, 2025
36d2470
simplify
Cyrilvallez Aug 3, 2025
241d48a
back to before -> avoid compile bug
Cyrilvallez Aug 4, 2025
03b8401
doc
Cyrilvallez Aug 4, 2025
2d007c1
mistral test
Cyrilvallez Aug 4, 2025
71ada77
llama4 test dtype
Cyrilvallez Aug 4, 2025
23054e2
Update test_modeling_llama4.py
Cyrilvallez Aug 4, 2025
e8ceb9d
CIs
Cyrilvallez Aug 4, 2025
d0763b8
Finally find a nice impl
Cyrilvallez Aug 5, 2025
06fd9e4
Update cache_utils.py
Cyrilvallez Aug 5, 2025
b6eeae2
Update cache_utils.py
Cyrilvallez Aug 5, 2025
ca32e1f
add lazy methods in autodoc
Cyrilvallez Aug 5, 2025
a173a64
typo
Cyrilvallez Aug 5, 2025
1f7dd27
better doc
Cyrilvallez Aug 5, 2025
203ab69
Add detailed docstring for lazy init
Cyrilvallez Aug 5, 2025
48e78d0
CIs
Cyrilvallez Aug 5, 2025
236bf9d
style
Cyrilvallez Aug 7, 2025
0630cd2
fix
Cyrilvallez Aug 8, 2025
29 changes: 11 additions & 18 deletions docs/source/en/internal/generation_utils.md
@@ -363,37 +363,34 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_max_cache_shape
- reset
- reorder_cache
- lazy_initialization

[[autodoc]] DynamicLayer
- update
- lazy_initialization
- crop
- batch_repeat_interleave
- batch_select_indices

[[autodoc]] StaticLayer
- update
- lazy_initialization

[[autodoc]] SlidingWindowLayer
- update
- lazy_initialization

[[autodoc]] CacheProcessor
- pre_update
- post_update

[[autodoc]] OffloadedCacheProcessor
- pre_update

[[autodoc]] QuantizedCacheProcessor
- post_update

[[autodoc]] QuantoQuantizedCacheProcessor
- post_update
[[autodoc]] QuantoQuantizedLayer
- update
- lazy_initialization

[[autodoc]] HQQQuantizedCacheProcessor
- post_update
[[autodoc]] HQQQuantizedLayer
- update
- lazy_initialization

[[autodoc]] Cache
- update
- early_initialization
- get_seq_length
- get_mask_sizes
- get_max_cache_shape
@@ -411,12 +408,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] QuantoQuantizedCache

[[autodoc]] QuantoQuantizedCacheProcessor

[[autodoc]] HQQQuantizedCache

[[autodoc]] HQQQuantizedCacheProcessor

[[autodoc]] OffloadedCache

[[autodoc]] StaticCache
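For context, here is a minimal sketch of what the lazy initialization documented above amounts to for a statically-shaped layer. It is illustrative only: the class name and buffer layout are assumptions, not the actual cache_utils.py code.

import torch

class LazyStaticLayerSketch:
    """Illustrative only: a static KV-cache layer that allocates its buffers
    on the first update() call instead of in __init__."""

    def __init__(self, max_cache_len: int):
        self.max_cache_len = max_cache_len
        self.keys = None  # allocated lazily, on first use
        self.values = None

    def lazy_initialization(self, key_states: torch.Tensor) -> None:
        # Batch size, head count, head dim, device and dtype are all read off
        # the first incoming tensor, which is why the constructors in the doc
        # diffs below no longer take max_batch_size, device or dtype.
        batch_size, num_heads, _, head_dim = key_states.shape
        shape = (batch_size, num_heads, self.max_cache_len, head_dim)
        self.keys = torch.zeros(shape, dtype=key_states.dtype, device=key_states.device)
        self.values = torch.zeros_like(self.keys)

    def update(self, key_states, value_states, cache_position):
        if self.keys is None:
            self.lazy_initialization(key_states)
        self.keys[:, :, cache_position] = key_states
        self.values[:, :, cache_position] = value_states
        return self.keys, self.values

layer = LazyStaticLayerSketch(max_cache_len=128)
k = v = torch.randn(1, 8, 4, 64)  # (batch, heads, seq_len, head_dim)
keys, values = layer.update(k, v, cache_position=torch.arange(4))

The payoff of this design is that batch size, device and dtype no longer have to be threaded through every cache constructor, which is exactly what the documentation diffs below remove.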
2 changes: 1 addition & 1 deletion docs/source/en/kv_cache.md
@@ -312,7 +312,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)

# Init StaticCache with big enough max-length (1024 tokens for the below example)
# You can also init a DynamicCache, if that suits you better
prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device=model.device.type, dtype=torch.bfloat16)
prompt_cache = StaticCache(config=model.config, max_cache_len=1024)

INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(model.device.type)
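Read end to end, the snippet this hunk touches presumably becomes something like the sketch below. The model id, prompt and max_new_tokens are illustrative; only the two-argument StaticCache call comes from the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-chat-hf"  # illustrative; any causal LM works
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# No max_batch_size, device or dtype here: with lazy initialization they are
# inferred from the model's first forward pass.
prompt_cache = StaticCache(config=model.config, max_cache_len=1024)

inputs = tokenizer("You are a helpful assistant. ", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, past_key_values=prompt_cache, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))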
5 changes: 1 addition & 4 deletions docs/source/en/llm_optims.md
@@ -93,11 +93,8 @@ model.generation_config.max_new_tokens = 16

past_key_values = StaticCache(
config=model.config,
max_batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
device=model.device,
dtype=model.dtype
)
outputs = model.generate(**input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
@@ -159,7 +156,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
past_key_values = StaticCache(
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
config=model.config, max_cache_len=4096
)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
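The cache_position tensor set up in this hunk is what lets a compiled decode step address the static buffers directly. A condensed sketch of that bookkeeping, with the model call left as a comment since the compiled helper is defined elsewhere in that doc:

import torch

seq_length, max_new_tokens = 16, 8
cache_position = torch.arange(seq_length)  # slots written during prefill
for _ in range(max_new_tokens):
    cache_position = cache_position[-1:] + 1  # one new slot per decoded token
    # logits = model(next_token, cache_position=cache_position,
    #                past_key_values=past_key_values, use_cache=True).logits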
3 changes: 1 addition & 2 deletions docs/source/en/model_doc/gemma2.md
@@ -138,8 +138,7 @@ visualizer("You are an assistant. Make sure you print me")

inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
max_generated_length = inputs.input_ids.shape[1] + 10
past_key_values = HybridCache(config=model.config, max_batch_size=1,
max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
past_key_values = HybridCache(config=model.config, max_cache_len=max_generated_length)
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
```

22 changes: 4 additions & 18 deletions docs/source/ko/internal/generation_utils.md
@@ -362,21 +362,11 @@ generation_output[:2]
[[autodoc]] SlidingWindowLayer
- update

[[autodoc]] CacheProcessor
- pre_update
- post_update

[[autodoc]] OffloadedCacheProcessor
- pre_update

[[autodoc]] QuantizedCacheProcessor
- post_update

[[autodoc]] QuantoQuantizedCacheProcessor
- post_update
[[autodoc]] QuantoQuantizedLayer
- update

[[autodoc]] HQQQuantizedCacheProcessor
- post_update
[[autodoc]] HQQQuantizedLayer
- update

[[autodoc]] Cache
- update
@@ -397,12 +387,8 @@

[[autodoc]] QuantoQuantizedCache

[[autodoc]] QuantoQuantizedCacheProcessor

[[autodoc]] HQQQuantizedCache

[[autodoc]] HQQQuantizedCacheProcessor

[[autodoc]] OffloadedCache

[[autodoc]] StaticCache
5 changes: 1 addition & 4 deletions docs/source/ko/llm_optims.md
@@ -99,11 +99,8 @@ model.generation_config.max_new_tokens = 16

past_key_values = StaticCache(
config=model.config,
max_batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
device=model.device,
dtype=model.dtype
)
outputs = model.generate(**input_ids, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
@@ -161,7 +158,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
past_key_values = StaticCache(
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
config=model.config, max_cache_len=4096
)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
15 changes: 8 additions & 7 deletions src/transformers/__init__.py
@@ -377,23 +377,18 @@
"StaticLayer",
"SlidingWindowLayer",
"ChunkedSlidingLayer",
"CacheProcessor",
"OffloadedCacheProcessor",
"QuantizedCacheProcessor",
"QuantoQuantizedCacheProcessor",
"HQQQuantizedCacheProcessor",
"QuantoQuantizedLayer",
"HQQQuantizedLayer",
"Cache",
"CacheConfig",
"DynamicCache",
"EncoderDecoderCache",
"HQQQuantizedCache",
"HQQQuantizedCacheProcessor",
"HybridCache",
"HybridChunkedCache",
"OffloadedCache",
"OffloadedStaticCache",
"QuantizedCache",
"QuantoQuantizedCacheProcessor",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
"SinkCache",
@@ -586,19 +581,25 @@
# All modeling imports
from .cache_utils import Cache as Cache
from .cache_utils import CacheConfig as CacheConfig
from .cache_utils import ChunkedSlidingLayer as ChunkedSlidingLayer
from .cache_utils import DynamicCache as DynamicCache
from .cache_utils import DynamicLayer as DynamicLayer
from .cache_utils import EncoderDecoderCache as EncoderDecoderCache
from .cache_utils import HQQQuantizedCache as HQQQuantizedCache
from .cache_utils import HQQQuantizedLayer as HQQQuantizedLayer
from .cache_utils import HybridCache as HybridCache
from .cache_utils import MambaCache as MambaCache
from .cache_utils import OffloadedCache as OffloadedCache
from .cache_utils import OffloadedStaticCache as OffloadedStaticCache
from .cache_utils import QuantizedCache as QuantizedCache
from .cache_utils import QuantizedCacheConfig as QuantizedCacheConfig
from .cache_utils import QuantoQuantizedCache as QuantoQuantizedCache
from .cache_utils import QuantoQuantizedLayer as QuantoQuantizedLayer
from .cache_utils import SinkCache as SinkCache
from .cache_utils import SlidingWindowCache as SlidingWindowCache
from .cache_utils import SlidingWindowLayer as SlidingWindowLayer
from .cache_utils import StaticCache as StaticCache
from .cache_utils import StaticLayer as StaticLayer
from .configuration_utils import PretrainedConfig as PretrainedConfig
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS as SLOW_TO_FAST_CONVERTERS
from .convert_slow_tokenizer import convert_slow_tokenizer as convert_slow_tokenizer
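The net effect on the public namespace, with the *CacheProcessor classes giving way to *Layer classes, can be smoke-tested with an import like this (assuming a transformers build that includes this PR; all names are taken from the __init__.py diff above):

from transformers import (
    ChunkedSlidingLayer,
    DynamicLayer,
    HQQQuantizedLayer,
    QuantoQuantizedLayer,
    SlidingWindowLayer,
    StaticLayer,
)

print([cls.__name__ for cls in (DynamicLayer, StaticLayer, QuantoQuantizedLayer)])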