
Commit 5337f30

🚨🚨 [generate] ignore cache_implementation="hybrid" hub defaults (#40135)
* working?
* fix tests
1 parent e4223fa commit 5337f30

2 files changed (+60 additions, -12 deletions)

src/transformers/generation/utils.py

Lines changed: 7 additions & 4 deletions
@@ -1742,6 +1742,13 @@ def _prepare_generation_config(
            generation_config = self.generation_config
            using_model_generation_config = True

+            # Related to #40039: prior to this PR, models with sliding window attention were forced to have
+            # `cache_implementation="hybrid"` (the static sliding window cache). For these models, we now want to use
+            # the dynamic sliding window cache by default, so we UNSET `cache_implementation` if it is a default value.
+            # (if we're inside this branch, then it is because we're using default values from the Hub)
+            if generation_config.cache_implementation == "hybrid":
+                generation_config.cache_implementation = None
+
        # `torch.export.export` usually raises an exception if it is called
        # with ``strict=True``. deepcopy can only be processed if ``strict=False``.
        generation_config = copy.deepcopy(generation_config)
@@ -1954,10 +1961,6 @@ def _prepare_cache_for_generation(
            )
            generation_config.cache_implementation = None

-        generation_config.cache_implementation = generation_config.cache_implementation or getattr(
-            self.config.get_text_config(decoder=True), "cache_implementation", None
-        )
-
        # assisted decoding and contrastive search need to roll-back the Cache, which is not supported if
        # it has sliding layers - so if we use any of those 2, do not pass the config to DynamicCache, which
        # will result in creating a Cache with only full layers even if model uses sliding window
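For context, here is a minimal usage sketch of the behavior change. It is not part of the commit; it assumes the `google/gemma-3-1b-it` checkpoint (whose Hub generation config sets `cache_implementation="hybrid"`), the same checkpoint exercised by the new Gemma3 test below. With the Hub default now ignored, `generate()` falls back to the dynamic sliding window cache, and the static sliding window cache has to be requested explicitly.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3-1b-it"  # checkpoint whose Hub config ships cache_implementation="hybrid"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)

# After this change, the "hybrid" default coming from the Hub is ignored, so generate()
# builds a dynamic (sliding window) cache by default.
default_outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False, return_dict_in_generate=True)

# The static sliding window cache can still be requested explicitly, which is what the
# updated integration tests below do.
hybrid_outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    do_sample=False,
    return_dict_in_generate=True,
    cache_implementation="hybrid",
)
```

The new `test_dynamic_sliding_window_is_default` test further down checks these two paths by inspecting `past_key_values` on the returned outputs.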

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 53 additions & 8 deletions
@@ -500,7 +500,8 @@ def test_model_4b_bf16(self):
            add_generation_prompt=True,
        ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        EXPECTED_TEXTS = Expectations(
@@ -545,7 +546,8 @@ def test_model_4b_batch(self):
            add_generation_prompt=True,
        ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        EXPECTED_TEXTS = Expectations(
@@ -599,7 +601,8 @@ def test_model_4b_crops(self):
            **crop_config,
        ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        EXPECTED_NUM_IMAGES = 3  # one for the origin image and two crops of images
@@ -654,7 +657,8 @@ def test_model_4b_batch_crops(self):
            **crop_config,
        ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
        EXPECTED_NUM_IMAGES = 9  # 3 * (one for the origin image and two crops of images) = 9
        EXPECTED_TEXTS = Expectations(
@@ -708,7 +712,8 @@ def test_model_4b_multiimage(self):
            add_generation_prompt=True,
        ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
        EXPECTED_TEXTS = Expectations(
            {
@@ -729,7 +734,8 @@ def test_model_1b_text_only(self):
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

        EXPECTED_TEXTS = Expectations(
@@ -763,7 +769,8 @@ def test_model_4b_flash_attn(self):
            add_generation_prompt=True,
        ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        EXPECTED_TEXTS = Expectations(
@@ -803,7 +810,10 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str):
        input_size = inputs.input_ids.shape[-1]
        self.assertTrue(input_size > model.config.sliding_window)

-        out = model.generate(**inputs, max_new_tokens=20, do_sample=False)[:, input_size:]
+        # cache_implementation="hybrid" as in the original transformers implementation
+        out = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="hybrid")[
+            :, input_size:
+        ]
        output_text = tokenizer.batch_decode(out)

        EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"]  # fmt: skip
@@ -844,9 +854,44 @@ def test_export_text_only_with_hybrid_cache(self):
            **input_text,
            max_new_tokens=max_new_tokens_to_generate,
            do_sample=False,  # Use greedy decoding to match the exported model
+            cache_implementation="hybrid",
        )

        eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
        logging.info(f"\nEager generated texts: '{eager_generated_text}'")

        self.assertEqual(export_generated_text, eager_generated_text)
+
+    def test_dynamic_sliding_window_is_default(self):
+        """
+        Test that the dynamic sliding window cache (added in #40039) is the default cache implementation for Gemma3
+        models, despite the fact that Hub checkpoints may have `cache_implementation="hybrid"` (static sliding window).
+        """
+        model_id = "google/gemma-3-1b-it"
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+        # the Hub checkpoint still sets the static sliding window cache ("hybrid") as its default
+        self.assertEqual(model.config.cache_implementation, "hybrid")
+        self.assertEqual(model.generation_config.cache_implementation, "hybrid")
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        prompt = "What is the capital of France?"
+        model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        forward_outputs = model(**model_inputs)
+        self.assertIn("DynamicSlidingWindowLayer", str(forward_outputs.past_key_values))
+
+        generate_outputs = model.generate(
+            **model_inputs, max_new_tokens=2, do_sample=False, return_dict_in_generate=True
+        )
+        self.assertIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))
+
+        # If we manually specify cache_implementation="hybrid", the static sliding window cache is used instead
+        generate_outputs = model.generate(
+            **model_inputs,
+            max_new_tokens=2,
+            do_sample=False,
+            return_dict_in_generate=True,
+            cache_implementation="hybrid",
+        )
+        self.assertNotIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))
