diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 1e7c6eae0ee719..5f8eaf89ed9353 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -30,6 +30,7 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -37,11 +38,11 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -49,16 +50,6 @@
 from .configuration_glm import GlmConfig


-if is_flash_attn_2_available():
-    from ...modeling_flash_attention_utils import _flash_attention_forward
-
-from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
-from ...processing_utils import Unpack
-
-
-_CHECKPOINT_FOR_DOC = "dummy"
-
-
 _CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"


diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py
index 9cfd617eeb2353..39ee4a2ad5803e 100644
--- a/src/transformers/models/glm/modular_glm.py
+++ b/src/transformers/models/glm/modular_glm.py
@@ -44,11 +44,9 @@
 from .configuration_glm import GlmConfig


-_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"
-
 logger = logging.get_logger(__name__)

-_CHECKPOINT_FOR_DOC = "dummy"
+_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"


 class GlmRMSNorm(Phi3RMSNorm):
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 9e638c27afa41d..a1a86e3672d5fc 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -28,6 +28,7 @@
 from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import _flash_attention_forward
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -39,7 +40,6 @@
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -47,9 +47,6 @@
 from .configuration_phi3 import Phi3Config


-if is_flash_attn_2_available():
-    from ...modeling_flash_attention_utils import _flash_attention_forward
-
 logger = logging.get_logger(__name__)

 _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"