diff --git a/mlora/models/modeling_gemma2.py b/mlora/models/modeling_gemma2.py index 042bef90..abdd4b3c 100644 --- a/mlora/models/modeling_gemma2.py +++ b/mlora/models/modeling_gemma2.py @@ -26,7 +26,7 @@ ) from mlora.models.modeling_gemma import GemmaEmbedding, GemmaRMSNorm from mlora.models.modeling_llama import LlamaMLP -from mlora.utils import copy_parameters +from mlora.utils import copy_parameters, is_package_available if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -36,6 +36,10 @@ inspect.signature(flash_attn_func).parameters ) + assert is_package_available( + "flash_attn", "2.6.0" + ), "Gemma2 requires flash_attn>=2.6.0" + @dataclass class Gemma2Config(LLMModelConfig):