# modular_dummy_bert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
-import os
from typing import Optional, Union

import torch
-from packaging import version
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import auto_docstring, get_torch_version, logging
+from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_dummy_bert import DummyBertConfig

@@ -36,8 +34,7 @@ def __init__(self, config):
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

-        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
-        # any TensorFlow checkpoint file
+        # self.LayerNorm is not snake-cased due to old tensorflow checkpoint name matching
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
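
# --- Illustrative aside, not part of the diff above: the `position_ids` comment refers to
# the usual BERT-style buffer registration, whose actual line is elided in this hunk. A
# minimal, hedged sketch of that pattern (class name and default size are made up here):
class _PositionIdsSketch(nn.Module):
    def __init__(self, max_position_embeddings: int = 512):
        super().__init__()
        # A (1, max_position_embeddings) index buffer: contiguous in memory and, when
        # registered as a persistent buffer, included in the state_dict on serialization.
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))
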
@@ -228,7 +225,6 @@ class DummyBertSdpaSelfAttention(DummyBertSelfAttention):
    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)
        self.dropout_prob = config.attention_probs_dropout_prob
-        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    # Adapted from DummyBertSelfAttention
    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
@@ -308,14 +304,6 @@ def forward(
        if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
            past_key_values.is_updated[self.layer_idx] = True

-        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
-        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
-        # Reference: https://github.com/pytorch/pytorch/issues/112577
-        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
-            query_layer = query_layer.contiguous()
-            key_layer = key_layer.contiguous()
-            value_layer = value_layer.contiguous()
-
        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
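
# --- Illustrative aside, not part of the diff above: the comment explains why the causal
# flag is computed with an explicit `if` (a plain Python bool) rather than an inline
# conditional inside the SDPA call. A hedged, standalone sketch of that dispatch; the
# argument names below stand in for the elided surrounding code, not this file's API.
def _sdpa_with_explicit_causal_flag(query, key, value, attention_mask, is_decoder, is_cross_attention, dropout_p=0.0):
    tgt_len = query.shape[2]
    # An explicit bool keeps torch.compile's dynamic shapes and fullgraph tracing working;
    # `tgt_len > 1` mirrors AttentionMaskConverter.to_causal_4d, which builds no causal
    # mask for single-token decoding steps.
    if is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1:
        is_causal = True
    else:
        is_causal = False
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=is_causal
    )
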
@@ -655,92 +643,16 @@ def forward(self, hidden_states):
        return hidden_states


-def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path):
-    """Load tf checkpoints in a pytorch model."""
-    try:
-        import re
-
-        import numpy as np
-        import tensorflow as tf
-    except ImportError:
-        logger.error(
-            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
-            "https://www.tensorflow.org/install/ for installation instructions."
-        )
-        raise
-    tf_path = os.path.abspath(tf_checkpoint_path)
-    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
-    # Load weights from TF model
-    init_vars = tf.train.list_variables(tf_path)
-    names = []
-    arrays = []
-    for name, shape in init_vars:
-        logger.info(f"Loading TF weight {name} with shape {shape}")
-        array = tf.train.load_variable(tf_path, name)
-        names.append(name)
-        arrays.append(array)
-
-    for name, array in zip(names, arrays):
-        name = name.split("/")
-        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
-        # which are not required for using pretrained model
-        if any(
-            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
-            for n in name
-        ):
-            logger.info(f"Skipping {'/'.join(name)}")
-            continue
-        pointer = model
-        for m_name in name:
-            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
-                scope_names = re.split(r"_(\d+)", m_name)
-            else:
-                scope_names = [m_name]
-            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
-                pointer = getattr(pointer, "weight")
-            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
-                pointer = getattr(pointer, "bias")
-            elif scope_names[0] == "output_weights":
-                pointer = getattr(pointer, "weight")
-            elif scope_names[0] == "squad":
-                pointer = getattr(pointer, "classifier")
-            else:
-                try:
-                    pointer = getattr(pointer, scope_names[0])
-                except AttributeError:
-                    logger.info(f"Skipping {'/'.join(name)}")
-                    continue
-            if len(scope_names) >= 2:
-                num = int(scope_names[1])
-                pointer = pointer[num]
-        if m_name[-11:] == "_embeddings":
-            pointer = getattr(pointer, "weight")
-        elif m_name == "kernel":
-            array = np.transpose(array)
-        try:
-            if pointer.shape != array.shape:
-                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
-        except ValueError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
-        logger.info(f"Initialize PyTorch weight {name}")
-        pointer.data = torch.from_numpy(array)
-    return model
-
-
@auto_docstring
class DummyBertPreTrainedModel(PreTrainedModel):
    config: DummyBertConfig
-    load_tf_weights = load_tf_weights_in_dummy_bert
    base_model_prefix = "dummy_bert"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
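
# --- Illustrative aside, not part of the diff above: the deleted comment contrasted this
# `normal_` init with the original TF BERT's truncated normal. A hedged sketch of both
# options for an nn.Linear, assuming the common initializer_range of 0.02:
_linear = nn.Linear(768, 768)
nn.init.normal_(_linear.weight, mean=0.0, std=0.02)          # what _init_weights uses here
# nn.init.trunc_normal_(_linear.weight, mean=0.0, std=0.02)  # closer to the TF behaviour
if _linear.bias is not None:
    _linear.bias.data.zero_()
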