11 changes: 7 additions & 4 deletions src/transformers/generation/utils.py
@@ -1742,6 +1742,13 @@ def _prepare_generation_config(
generation_config = self.generation_config
using_model_generation_config = True

# Related to #40039: prior to this PR, models with sliding window attention were forced to have
# `cache_implementation="hybrid"` (the static sliding window cache). For these models, we now want to use
# the dynamic sliding window cache by default, so we UNSET `cache_implementation` if it is a default value.
# (if we're inside this branch, then it is because we're using default values from the Hub)
if generation_config.cache_implementation == "hybrid":
generation_config.cache_implementation = None

# `torch.export.export` usually raises an exception if it is called
# with ``strict=True``. deepcopy can only be processed if ``strict=False``.
generation_config = copy.deepcopy(generation_config)
@@ -1954,10 +1961,6 @@ def _prepare_cache_for_generation(
)
generation_config.cache_implementation = None

generation_config.cache_implementation = generation_config.cache_implementation or getattr(
self.config.get_text_config(decoder=True), "cache_implementation", None
)
Comment on lines -1957 to -1959 (Member Author):

This was originally added in #37353, but shouldn't be needed: on L1731, inside `GenerationConfig.from_model_config()`, we pull data from the text config into the generation config.

Conceptually, all `config` -> `generation_config` updates should happen inside `_prepare_generation_config`, otherwise it becomes hard to track all changes to `generation_config` 🤗

If llama4 still has issues with this, then it likely means there is a configuration issue in its files 🤔
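For illustration, a minimal hedged sketch of that path (the checkpoint id is borrowed from the Gemma3 tests below, and whether `cache_implementation` is among the copied attributes is exactly the claim above rather than something verified here):

```python
# Sketch: GenerationConfig.from_model_config() already derives generation
# defaults from the (text) model config, so the fallback deleted in this hunk
# should be redundant.
from transformers import AutoConfig, GenerationConfig

config = AutoConfig.from_pretrained("google/gemma-3-1b-it")  # example checkpoint
generation_config = GenerationConfig.from_model_config(config)

# If the text config carries `cache_implementation`, it should already be
# reflected here, before _prepare_cache_for_generation ever runs.
print(generation_config.cache_implementation)
```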


# assisted decoding and contrastive search need to roll back the Cache, which is not supported if
# it has sliding layers - so if we use either of those two, do not pass the config to DynamicCache, which
# will result in creating a Cache with only full layers even if the model uses sliding window
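As a usage-level illustration of the behavior described in the comment above, here is a hedged sketch mirroring the new `test_dynamic_sliding_window_is_default` test added below (the checkpoint id and the `DynamicSlidingWindowLayer` check are taken from that test):

```python
# Hedged sketch: after this change, a sliding-window model such as Gemma3
# defaults to the dynamic sliding window cache, while passing
# cache_implementation="hybrid" opts back into the static sliding window cache.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)

# Default: the returned cache contains dynamic sliding window layers
out = model.generate(**inputs, max_new_tokens=2, do_sample=False, return_dict_in_generate=True)
print("DynamicSlidingWindowLayer" in str(out.past_key_values))  # expected: True

# Explicit opt-in to the previous behavior (static sliding window cache)
out = model.generate(
    **inputs,
    max_new_tokens=2,
    do_sample=False,
    return_dict_in_generate=True,
    cache_implementation="hybrid",
)
print("DynamicSlidingWindowLayer" in str(out.past_key_values))  # expected: False
```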
61 changes: 53 additions & 8 deletions tests/models/gemma3/test_modeling_gemma3.py
@@ -500,7 +500,8 @@ def test_model_4b_bf16(self):
add_generation_prompt=True,
).to(torch_device)

output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)

EXPECTED_TEXTS = Expectations(
@@ -545,7 +546,8 @@ def test_model_4b_batch(self):
add_generation_prompt=True,
).to(torch_device)

output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)

EXPECTED_TEXTS = Expectations(
@@ -599,7 +601,8 @@ def test_model_4b_crops(self):
**crop_config,
).to(torch_device)

output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)

EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images
@@ -654,7 +657,8 @@ def test_model_4b_batch_crops(self):
**crop_config,
).to(torch_device)

output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_NUM_IMAGES = 9 # 3 * (one for the origin image and two crops of images) = 9
EXPECTED_TEXTS = Expectations(
@@ -708,7 +712,8 @@ def test_model_4b_multiimage(self):
add_generation_prompt=True,
).to(torch_device)

output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = Expectations(
{
@@ -729,7 +734,8 @@ def test_model_1b_text_only(self):
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)

output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

EXPECTED_TEXTS = Expectations(
@@ -763,7 +769,8 @@ def test_model_4b_flash_attn(self):
add_generation_prompt=True,
).to(torch_device)

output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)

EXPECTED_TEXTS = Expectations(
@@ -803,7 +810,8 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str):
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)

out = model.generate(**inputs, max_new_tokens=20, do_sample=False)[:, input_size:]
# cache_implementation="hybrid" an in the original transformers implementation
out = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="hybrid")[
:, input_size:
]
output_text = tokenizer.batch_decode(out)

EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
@@ -844,9 +854,44 @@ def test_export_text_only_with_hybrid_cache(self):
**input_text,
max_new_tokens=max_new_tokens_to_generate,
do_sample=False, # Use greedy decoding to match the exported model
cache_implementation="hybrid",
)

eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
logging.info(f"\nEager generated texts: '{eager_generated_text}'")

self.assertEqual(export_generated_text, eager_generated_text)

def test_dynamic_sliding_window_is_default(self):
"""
Test that the dynamic sliding window cache (added in #40039) is the default cache implementation for Gemma3
models, despite the fact that Hub checkpoints may have `cache_implementation="hybrid"` (static sliding window).
"""
model_id = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# the Hub checkpoint still specifies the static sliding window cache (`cache_implementation="hybrid"`)
self.assertEqual(model.config.cache_implementation, "hybrid")
self.assertEqual(model.generation_config.cache_implementation, "hybrid")

tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "What is the capital of France?"
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

forward_outputs = model(**model_inputs)
self.assertIn("DynamicSlidingWindowLayer", str(forward_outputs.past_key_values))

generate_outputs = model.generate(
**model_inputs, max_new_tokens=2, do_sample=False, return_dict_in_generate=True
)
self.assertIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))

# If we manually specify `cache_implementation="hybrid"`, the static sliding window cache is used
generate_outputs = model.generate(
**model_inputs,
max_new_tokens=2,
do_sample=False,
return_dict_in_generate=True,
cache_implementation="hybrid",
)
self.assertNotIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))