@@ -500,7 +500,8 @@ def test_model_4b_bf16(self):
             add_generation_prompt=True,
         ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)

         EXPECTED_TEXTS = Expectations(
@@ -545,7 +546,8 @@ def test_model_4b_batch(self):
             add_generation_prompt=True,
         ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)

         EXPECTED_TEXTS = Expectations(
@@ -599,7 +601,8 @@ def test_model_4b_crops(self):
             **crop_config,
         ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)

         EXPECTED_NUM_IMAGES = 3  # one for the origin image and two crops of images
@@ -654,7 +657,8 @@ def test_model_4b_batch_crops(self):
             **crop_config,
         ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)
         EXPECTED_NUM_IMAGES = 9  # 3 * (one for the origin image and two crops of images) = 9
         EXPECTED_TEXTS = Expectations(
@@ -708,7 +712,8 @@ def test_model_4b_multiimage(self):
             add_generation_prompt=True,
         ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)
         EXPECTED_TEXTS = Expectations(
             {
@@ -729,7 +734,8 @@ def test_model_1b_text_only(self):
         tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
         inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

         EXPECTED_TEXTS = Expectations(
@@ -763,7 +769,8 @@ def test_model_4b_flash_attn(self):
             add_generation_prompt=True,
         ).to(torch_device)

-        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        # cache_implementation="hybrid" as in the original transformers implementation
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
         output_text = self.processor.batch_decode(output, skip_special_tokens=True)

         EXPECTED_TEXTS = Expectations(
@@ -803,7 +810,10 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str):
         input_size = inputs.input_ids.shape[-1]
         self.assertTrue(input_size > model.config.sliding_window)

-        out = model.generate(**inputs, max_new_tokens=20, do_sample=False)[:, input_size:]
+        # cache_implementation="hybrid" as in the original transformers implementation
+        out = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="hybrid")[
+            :, input_size:
+        ]
         output_text = tokenizer.batch_decode(out)

         EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"]  # fmt: skip
@@ -844,9 +854,44 @@ def test_export_text_only_with_hybrid_cache(self):
             **input_text,
             max_new_tokens=max_new_tokens_to_generate,
             do_sample=False,  # Use greedy decoding to match the exported model
+            cache_implementation="hybrid",
         )

         eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
         logging.info(f"\nEager generated texts: '{eager_generated_text}'")

         self.assertEqual(export_generated_text, eager_generated_text)
+
+    def test_dynamic_sliding_window_is_default(self):
+        """
+        Test that the dynamic sliding window cache (added in #40039) is the default cache implementation for Gemma3
+        models, despite the fact that Hub checkpoints may have `cache_implementation="hybrid"` (static sliding window).
+        """
+        model_id = "google/gemma-3-1b-it"
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+        # the checkpoint's configs still request the static sliding window cache ("hybrid")
+        self.assertEqual(model.config.cache_implementation, "hybrid")
+        self.assertEqual(model.generation_config.cache_implementation, "hybrid")
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        prompt = "What is the capital of France?"
+        model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        forward_outputs = model(**model_inputs)
+        self.assertIn("DynamicSlidingWindowLayer", str(forward_outputs.past_key_values))
+
+        generate_outputs = model.generate(
+            **model_inputs, max_new_tokens=2, do_sample=False, return_dict_in_generate=True
+        )
+        self.assertIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))
+
+        # If we manually specify cache_implementation="hybrid", the static sliding window cache is used instead
+        generate_outputs = model.generate(
+            **model_inputs,
+            max_new_tokens=2,
+            do_sample=False,
+            return_dict_in_generate=True,
+            cache_implementation="hybrid",
+        )
+        self.assertNotIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))
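
Outside the diff, the behaviour the new test pins down can be reproduced with a minimal standalone sketch. It only uses APIs that already appear in the test (AutoModelForCausalLM, AutoTokenizer, and generate's cache_implementation argument); the checkpoint id, prompt, and the expected True/False outcomes are taken from the test's assertions rather than anything verified here.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint and prompt taken from the test above.
model_id = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)

# Default behaviour: the dynamic sliding window cache is used even though the
# checkpoint's config sets cache_implementation="hybrid".
out_default = model.generate(**inputs, max_new_tokens=2, do_sample=False, return_dict_in_generate=True)
print("DynamicSlidingWindowLayer" in str(out_default.past_key_values))  # expected per the test: True

# Explicitly requesting the "hybrid" cache restores the static sliding window behaviour.
out_hybrid = model.generate(
    **inputs, max_new_tokens=2, do_sample=False, return_dict_in_generate=True, cache_implementation="hybrid"
)
print("DynamicSlidingWindowLayer" in str(out_hybrid.past_key_values))  # expected per the test: False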