diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py
index b3c9ff7cc640..352bc82a1e40 100644
--- a/src/transformers/integrations/flash_paged.py
+++ b/src/transformers/integrations/flash_paged.py
@@ -53,7 +53,7 @@ def paged_attention_forward(
     sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
     if implementation is not None:
         flash_attn_varlen_func = implementation.flash_attn_varlen_func
-    custom_kwargs = {"s_aux": kwargs.get("s_aux")}
+    custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
     attn_output = flash_attn_varlen_func(
         q.transpose(1, 2).squeeze(0).contiguous(),
         k.transpose(1, 2).squeeze(0).contiguous(),
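
For context, a minimal sketch of the conditional-kwarg pattern this change applies: `s_aux` is only forwarded when the caller actually passed it, so a backend whose `flash_attn_varlen_func` does not expect `s_aux` never sees a spurious `s_aux=None`. The `fake_varlen_attn` function below is a hypothetical stand-in, not the library's API.

```python
# Hypothetical stand-in for an implementation-provided flash_attn_varlen_func
# that may or may not understand the `s_aux` keyword.
def fake_varlen_attn(q, k, v, **extra):
    if "s_aux" in extra and extra["s_aux"] is None:
        raise TypeError("s_aux=None is not a valid auxiliary tensor")
    return q  # placeholder result


def forward(q, k, v, **kwargs):
    # Old behavior: always built {"s_aux": kwargs.get("s_aux")}, so the kernel
    # received s_aux=None even when the caller never set it.
    # New behavior (as in the diff): only forward s_aux when it was provided.
    custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
    return fake_varlen_attn(q, k, v, **custom_kwargs)


forward(1, 2, 3)                # ok: no s_aux key is forwarded at all
forward(1, 2, 3, s_aux="bias")  # ok: s_aux is forwarded exactly as given
```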