diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 31426435d09f..0c28f46d5ec2 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -436,7 +436,11 @@ def test_flash_attn_2_generate_use_cache(self):
 
             # Just test that a large cache works as expected
             _ = model.generate(
-                dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+                dummy_input,
+                attention_mask=dummy_attention_mask,
+                max_new_tokens=max_new_tokens,
+                do_sample=False,
+                use_cache=True,
             )
 
     @require_flash_attn
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 9d9e96db4347..c69b5ed77fe5 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3166,7 +3166,11 @@ def test_flash_attn_2_generate_use_cache(self):
 
             # Just test that a large cache works as expected
             _ = model.generate(
-                dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+                dummy_input,
+                attention_mask=dummy_attention_mask,
+                max_new_tokens=max_new_tokens,
+                do_sample=False,
+                use_cache=True,
             )
 
     @require_flash_attn