@@ -2599,7 +2599,7 @@ def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool:
         BetterTransformer, which are only available later after __init__. This allows raising proper exceptions early
         before instantiating the full models if we know that the model does not support the requested attention.
         """
-        if not self._supports_sdpa:
+        if not self._supports_sdpa and not is_init_check:
             raise ValueError(
                 f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
                 " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
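Reviewer note: a minimal standalone sketch of the gating this hunk introduces (class and names are illustrative, not the library code). During an init-time check, a missing SDPA implementation no longer raises, so the caller can downgrade to eager instead.

```python
# Illustrative sketch only: mirrors the control flow of _sdpa_can_dispatch above.
class FakeModel:
    _supports_sdpa = False

    def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool:
        # Raise only when the check happens after __init__; during init we stay
        # silent so the caller can fall back to "eager" instead of failing.
        if not self._supports_sdpa and not is_init_check:
            raise ValueError(f"{self.__class__.__name__} does not support SDPA.")
        return self._supports_sdpa


m = FakeModel()
print(m._sdpa_can_dispatch(is_init_check=True))  # False, no exception
# m._sdpa_can_dispatch()                         # would raise ValueError
```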
@@ -2683,34 +2683,51 @@ def _check_and_adjust_attn_implementation(
         if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
             if not is_kernels_available():
                 raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
-
+            attention_wrapper = None
+            # FIXME: @ArthurZucker this is dirty, did not want to do a lot of extra work
+            if "|" in applicable_attn_implementation:
+                attention_wrapper, applicable_attn_implementation = applicable_attn_implementation.split("|")
+                # `transformers` has wrappers for sdpa, paged, flash, flex, etc.
+                attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)
             # Extract repo_id and kernel_name from the string
             if ":" in applicable_attn_implementation:
                 repo_id, kernel_name = attn_implementation.split(":")
                 kernel_name = kernel_name.strip()
             else:
-                repo_id = attn_implementation
+                repo_id = applicable_attn_implementation
                 kernel_name = None
             repo_id = repo_id.strip()
             try:
                 kernel = get_kernel(repo_id)
                 if hasattr(kernel, "flash_attn_varlen_func"):
-                    kernel_function = partial(flash_attention_forward, implementation=kernel)
+                    if attention_wrapper is None:
+                        attention_wrapper = flash_attention_forward
+                    kernel_function = partial(attention_wrapper, implementation=kernel)
                 elif kernel_name is not None:
                     kernel_function = getattr(kernel, kernel_name)
-                # Register it
-                ALL_ATTENTION_FUNCTIONS.register(repo_id, kernel_function)
-                ALL_MASK_ATTENTION_FUNCTIONS.register(repo_id, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
-                applicable_attn_implementation = repo_id
+                ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function)
+                ALL_MASK_ATTENTION_FUNCTIONS.register(
+                    applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
+                )
             except Exception as e:
                 logger.warning_once(
                     f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
                     "default attention implementation instead (sdpa if available, eager otherwise)."
                 )
+
                 applicable_attn_implementation = "sdpa"  # Try to fallback to sdpa in this case
-        if applicable_attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
+            return applicable_attn_implementation
+        else:
+            return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
+
+    def get_correct_attn_implementation(self, _requested_attention: str, is_init_check: bool = False) -> str:
+        requested_attention = "sdpa" if _requested_attention is None else _requested_attention
+        if is_init_check and requested_attention == "sdpa":
+            if not self._supports_sdpa:
+                requested_attention = "eager"
+        if requested_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
             message = (
-                f'Specified `attn_implementation="{attn_implementation}"` is not supported. The only possible arguments are '
+                f'Specified `attn_implementation="{requested_attention}"` is not supported. The only possible arguments are '
                 '`attn_implementation="eager"` (manual attention implementation)'
             )
             # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
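Reviewer note: a standalone sketch of the kernel spec grammar this hunk introduces (illustrative only, not the library code). The spec is `"[wrapper|]org/repo[:kernel_name]"`, where the optional `wrapper` names one of the built-in attention wrappers, and the full spec string (wrapper prefix included) is now used as the registration key.

```python
# Illustrative parser mirroring the split logic in the hunk above.
def parse_kernel_spec(spec: str):
    wrapper = None
    if "|" in spec:
        # Optional wrapper prefix, e.g. "sdpa|org/repo"
        wrapper, spec = spec.split("|")
    if ":" in spec:
        # Optional kernel function name, e.g. "org/repo:my_kernel"
        repo_id, kernel_name = spec.split(":")
        kernel_name = kernel_name.strip()
    else:
        repo_id, kernel_name = spec, None
    return wrapper, repo_id.strip(), kernel_name


print(parse_kernel_spec("sdpa|kernels-community/flash-attn:flash_attn_varlen_func"))
# ('sdpa', 'kernels-community/flash-attn', 'flash_attn_varlen_func')
print(parse_kernel_spec("kernels-community/flash-attn"))
# (None, 'kernels-community/flash-attn', None)
```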
@@ -2726,23 +2743,21 @@ def _check_and_adjust_attn_implementation(
             raise ValueError(message + ".")
 
         # Perform relevant checks
-        if applicable_attn_implementation == "flash_attention_2":
+        if requested_attention == "flash_attention_2":
             self._flash_attn_2_can_dispatch(is_init_check)
-        elif applicable_attn_implementation == "flash_attention_3":
+        elif requested_attention == "flash_attention_3":
             self._flash_attn_3_can_dispatch(is_init_check)
-        elif applicable_attn_implementation == "flex_attention":
+        elif requested_attention == "flex_attention":
             self._flex_attn_can_dispatch(is_init_check)
-        elif applicable_attn_implementation == "sdpa":
+        elif requested_attention == "sdpa":
             # Sdpa is the default, so we try it and fallback to eager otherwise when not possible
             try:
                 self._sdpa_can_dispatch(is_init_check)
             except (ValueError, ImportError) as e:
-                # In this case, sdpa was requested explicitly, but we can't use it, so let's raise
-                if attn_implementation == "sdpa":
+                if _requested_attention == "sdpa":
                     raise e
-                applicable_attn_implementation = "eager"
-
-        return applicable_attn_implementation
+                requested_attention = "eager"
+        return requested_attention
 
     @classmethod
     def _can_set_attn_implementation(cls) -> bool:
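Reviewer note: the fallback rule above distinguishes an explicit `"sdpa"` request (which must fail loudly) from `"sdpa"` as the implicit default (which may quietly degrade to eager). A standalone sketch of that rule, with illustrative names:

```python
# Illustrative mirror of the fallback semantics in get_correct_attn_implementation.
def resolve_attention(requested, sdpa_ok: bool) -> str:
    effective = "sdpa" if requested is None else requested
    if effective == "sdpa" and not sdpa_ok:
        if requested == "sdpa":
            # sdpa was asked for explicitly: surface the failure.
            raise ValueError("sdpa is not usable on this model/device")
        return "eager"  # sdpa was only the implicit default: degrade quietly
    return effective


print(resolve_attention(None, sdpa_ok=False))     # 'eager'
print(resolve_attention("eager", sdpa_ok=False))  # 'eager'
# resolve_attention("sdpa", sdpa_ok=False)        # raises ValueError
```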
@@ -2790,7 +2805,7 @@ def set_attn_implementation(self, attn_implementation: Union[str, dict]):
             )
             # Apply the change (on the internal attr, to avoid setting it recursively)
             self.config._attn_implementation_internal = applicable_attn_implementation
-        except (ValueError, ImportError) as e:
+        except Exception as e:
             logger.warning(
                 f"Impossible to set the requested `attn_implementation`. The following error was captured: {str(e)}"
             )
@@ -2814,8 +2829,13 @@ def set_attn_implementation(self, attn_implementation: Union[str, dict]):
                         subconfig_key, submodule.config._attn_implementation
                     )
                     break
-                submodule.set_attn_implementation(sub_implementation)
-                subconfigs_changed.add(submodule.config.__class__)
+                # Check that the module can actually use it; otherwise we would silently set the config without the model using it
+                try:
+                    sub_implementation = submodule.get_correct_attn_implementation(sub_implementation)
+                    submodule.config._attn_implementation = sub_implementation
+                    subconfigs_changed.add(submodule.config.__class__)
+                except Exception:
+                    pass
 
         # We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
         for subconfig_key in self.config.sub_configs:
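Reviewer note: a hedged usage sketch of the per-submodule path this hunk hardens. The dict form of `set_attn_implementation` is in the method signature above; the model id and sub-config keys below are illustrative assumptions, not guaranteed names. With the change, a submodule that cannot honor its requested backend is skipped instead of having its config silently flipped to an implementation the model never uses.

```python
from transformers import AutoModel

# Hypothetical multimodal checkpoint whose config exposes text/vision sub-configs.
model = AutoModel.from_pretrained("some-org/some-multimodal-model")
# One implementation per sub-config key; invalid entries are now skipped, not recorded.
model.set_attn_implementation({"text_config": "flash_attention_2", "vision_config": "eager"})
```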
@@ -5746,6 +5766,8 @@ def supports_tp_plan(self):
         # Check if base model has a TP plan
         if getattr(self.base_model, "_tp_plan", None) is not None:
             return True
+        if self.config.base_model_tp_plan is not None:
+            return True
         return False
 
     @property
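Reviewer note: a tiny standalone mirror of the extended check (illustrative only): a tensor-parallel plan can now come either from the instantiated base model or from the config itself, and either source is enough for `supports_tp_plan` to be true.

```python
# Illustrative mirror of the two-source TP-plan check; the plan value is a made-up example.
def supports_tp_plan(base_model_tp_plan, config_base_model_tp_plan) -> bool:
    if base_model_tp_plan is not None:
        return True
    if config_base_model_tp_plan is not None:
        return True
    return False


print(supports_tp_plan(None, {"layers.*.self_attn.q_proj": "colwise"}))  # True
print(supports_tp_plan(None, None))                                      # False
```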