Skip to content

Commit

Permalink
Add ALL_LAYERNORM_LAYERS to match pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
a8nova committed Nov 21, 2023
1 parent b42fe29 commit 7e0a351
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/models/idefics/modeling_tf_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig
from ...modeling_tf_utils import shape_list
#from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...tf_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
Expand Down Expand Up @@ -487,7 +487,7 @@ def call(self, hidden_states):
return self.weight * hidden_states


#ALL_LAYERNORM_LAYERS.append(TFIdeficsRMSNorm)
ALL_LAYERNORM_LAYERS.append(TFIdeficsRMSNorm)


class TFIdeficsEmbedding(tf.keras.layers.Layer):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .utils import logging

ALL_LAYERNORM_LAYERS = [tf.keras.layers.LayerNormalization]

logger = logging.get_logger(__name__)

Expand Down

0 comments on commit 7e0a351

Please sign in to comment.