static cache implementation is not compatible with attn_implementation==flash_attention_2 #32040

Status: Open
Labels: bug, Cache, Feature request

faaany (Contributor) opened this issue Jul 18, 2024 · 3 comments
faaany commented Jul 18, 2024

System Info

  • transformers version: 4.43.0.dev0
  • Platform: Linux-4.18.0-425.3.1.el8.x86_64-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.23.5
  • Safetensors version: 0.4.3
  • Accelerate version: 0.33.0.dev0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100 80GB PCIe

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

pytest -rA tests/test_cache_utils.py::CacheIntegrationTest -k "test_static_cache_greedy_decoding_pad_left and flash_attention"

fails with

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
>           raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )
E           ValueError: `static` cache implementation is not compatible with `attn_implementation==flash_attention_2` make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers

src/transformers/models/llama/modeling_llama.py:388: ValueError

And the right padding test case also fails:

pytest -rA tests/test_cache_utils.py::CacheIntegrationTest -k "test_static_cache_greedy_decoding_pad_right and flash_attention"
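
For reference, a minimal standalone sketch of what these tests exercise (not the actual test code; the checkpoint, prompt, and generation arguments below are illustrative assumptions):

    # Greedy decoding with a static cache on a model loaded with flash_attention_2.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-2-7b-hf"  # assumption: any Llama checkpoint hits the same check
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )

    inputs = tokenizer(["The best color is"], return_tensors="pt").to(model.device)
    # Raises the ValueError shown above once past_key_value becomes a StaticCache
    out = model.generate(
        **inputs, do_sample=False, max_new_tokens=10, cache_implementation="static"
    )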

Expected behavior

Either we shouldn't test flash_attention in this case, or we should add an if check that skips setting cache_implementation to static when flash_attention_2 is used.
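
If the second option is preferred, here is a rough sketch of such a guard inside the test (the attribute `config._attn_implementation` and the skip message are assumptions, not the actual test code):

    import pytest

    # Skip the static-cache variant when the model was loaded with flash_attention_2;
    # otherwise keep the existing behaviour of switching to the static implementation.
    if model.config._attn_implementation == "flash_attention_2":
        pytest.skip("`static` cache is not compatible with `attn_implementation==flash_attention_2`")
    model.generation_config.cache_implementation = "static"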

faaany added the bug label Jul 18, 2024
faaany (Contributor, Author) commented Jul 18, 2024

I made a possible fix suggestion in draft PR #32039, but I am not sure whether it is correct, so I also filed this issue.

amyeroberts (Collaborator) commented

cc @gante too

ArthurZucker added the Feature request label Jul 18, 2024
zucchini-nlp (Member) commented

This incompatibility also affects Gemma2 with flash-attn, since Gemma2 does not support a dynamic cache.
