Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
add version check
Browse files Browse the repository at this point in the history
  • Loading branch information
mikecovlee committed Jul 18, 2024
1 parent af1457e commit 415c286
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion mlora/models/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from mlora.models.modeling_gemma import GemmaEmbedding, GemmaRMSNorm
from mlora.models.modeling_llama import LlamaMLP
from mlora.utils import copy_parameters
from mlora.utils import copy_parameters, is_package_available

if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
Expand All @@ -36,6 +36,10 @@
inspect.signature(flash_attn_func).parameters
)

assert is_package_available(
"flash_attn", "2.6.0"
), "Gemma2 requires flash_attn>=2.6.0"


@dataclass
class Gemma2Config(LLMModelConfig):
Expand Down

0 comments on commit 415c286

Please sign in to comment.