 
 import numpy as np
 
-from ..models.bert import BertTokenizer, BertTokenizerFast
-from ..tokenization_utils_base import LARGE_INTEGER, PreTrainedTokenizerBase
+from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import PaddingStrategy
 
 
@@ -797,7 +796,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
         return batch
 
     def torch_mask_tokens(
-        self, inputs: Any, special_tokens_mask: Optional[Any] = None, offset_mapping: Optional[Any] = None
+        self, inputs: Any, special_tokens_mask: Optional[Any] = None, offset_mapping: Optional[Any] = None
     ) -> tuple[Any, Any]:
         """
         Prepare masked tokens inputs/labels for masked language modeling.
@@ -811,10 +810,10 @@ def torch_mask_tokens(
             special_tokens_mask = [
                 self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
             ]
-
+
         if self.whole_word_mask:
             word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
-                tolist(offset_mapping), tolist(special_tokens_mask)
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
             )
             no_mask_mask = torch.tensor(no_mask_mask, dtype=torch.bool)
         else:
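
A note on why `tolist(...)` becomes `to_numpy(...)` here, illustrated with a hand-built toy batch rather than real tokenizer output (my own sketch, not part of the diff): the vectorized `_calc_word_ids_and_prob_mask` further down slices the offsets along the last axis (`offsets[:, :, 0]`), which works on an ndarray of shape (batch, seq_len, 2) but not on nested Python lists.

import numpy as np
import torch

# Toy offset_mapping for one sequence of 6 tokens, shape (1, 6, 2); values are illustrative.
offset_mapping = torch.tensor([[(0, 0), (0, 2), (2, 8), (8, 12), (13, 14), (0, 0)]])

as_array = offset_mapping.numpy()   # what to_numpy() yields for a CPU tensor
print(as_array[:, :, 0])            # per-token start characters, shape (1, 6)
print(as_array[:, :, 1])            # per-token end characters, shape (1, 6)
# offset_mapping.tolist()[:, :, 0] would fail: nested lists do not support this indexing
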
@@ -827,7 +826,7 @@ def torch_mask_tokens(
         probability_matrix.masked_fill_(no_mask_mask, value=0.0)
         masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
         if self.whole_word_mask:
-            masked_indices = self._whole_word_mask(word_ids, masked_indices)
+            masked_indices = torch.BoolTensor(self._whole_word_mask(word_ids, masked_indices))
 
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
 
@@ -906,9 +905,11 @@ def numpy_mask_tokens(
             special_tokens_mask = [
                 self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
             ]
-
+
         if self.whole_word_mask:
-            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(tolist(offset_mapping), tolist(special_tokens_mask))
+            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
+            )
             no_mask_mask = np.array(no_mask_mask, dtype=bool)
         else:
             no_mask_mask = (
@@ -970,61 +971,52 @@ def numpy_mask_tokens(
 
         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
         return inputs, labels
-
+
     @staticmethod
     def _calc_word_ids_and_prob_mask(
-        offsets: list[list[tuple[int, int]]], special_tokens_mask: list[list[int]]
-    ) -> tuple[list[list[int]], list[list[int]]]:
+        offsets: np.ndarray[np.ndarray[tuple[int, int]]], special_tokens_mask: np.ndarray[np.ndarray[int]]
+    ) -> tuple[np.ndarray[np.ndarray[int]], np.ndarray[np.ndarray[int]]]:
         """
         Map tokens to word ids and create mask of tokens to not mask.
         Tokens that are part of the same word will have the same word id and we will only
         set a mask probability for the first token of each word.
         """
 
-        batch_word_ids = []
-        batch_prob_mask = [[1] * len(o) for o in offsets]  # Initialize with 1s, meaning no tokens can be masked
+        token_starts = offsets[:, :, 0]
+        token_ends = offsets[:, :, 1]
 
-        for seq_idx, offset in enumerate(offsets):
-            word_ids = []
-            current_word_id = 0
-            prev_token_end = None
+        prev_token_ends = np.roll(token_ends, 1, axis=1)
+        prev_token_ends[:, 0] = -1  # First token has no previous token
 
-            for token_idx, (token_start, token_end) in enumerate(offset):
-                if special_tokens_mask[seq_idx][token_idx] == 1:
-                    word_ids.append(-1)
-                    prev_token_end = None
-                    continue
+        prev_token_special = np.roll(special_tokens_mask, 1, axis=1)
+        prev_token_special[:, 0] = 0
 
-                if (prev_token_end is None) or (prev_token_end != token_start):
-                    current_word_id += 1
-                    batch_prob_mask[seq_idx][token_idx] = 0  # This token can be masked
+        # Not special token AND (gap from previous or previous token was special)
+        special_tokens_mask = special_tokens_mask.astype(bool)
+        is_new_word = (~special_tokens_mask) & ((token_starts != prev_token_ends) | (prev_token_special == 1))
 
-                word_ids.append(current_word_id)
-                prev_token_end = token_end
+        word_ids = np.cumsum(is_new_word, axis=1)
+        word_ids[special_tokens_mask] = -1
 
-            batch_word_ids.append(word_ids)
+        prob_mask = (~is_new_word).astype(int)
 
-        return batch_word_ids, batch_prob_mask
+        return word_ids, prob_mask
 
     @staticmethod
-    def _whole_word_mask(word_ids: list[list[int]], mask: Any) -> Any:
+    def _whole_word_mask(word_ids: np.ndarray[np.ndarray[int]], mask: Any) -> Any:
         """
         Mask whole words based on word ids and mask.
         """
-        for seq_idx, (word_ids, mask_values) in enumerate(zip(word_ids, mask)):
-            for word_idx, id in enumerate(word_ids):
-                # Skip first word
-                if word_idx == 0:
-                    continue
+        mask = to_numpy(mask)
 
-                # If the current token is the same as the previous token's word id
-                # and the previous token is masked, then mask the current token too
-                if (id == word_ids[word_idx - 1]) and (mask_values[word_idx - 1]):
-                    # Previous token for same word is masked, so this one should be too
-                    mask[seq_idx][word_idx] = True
+        valid_ids = word_ids != -1
+
+        # Create 3D mask where [batch, token_i, token_j] is True if token_i and token_j are the same word
+        same_word = (word_ids[:, :, None] == word_ids[:, None, :]) & valid_ids[:, :, None] & valid_ids[:, None, :]
+
+        # For each token, set True if any token in the same word is masked
+        return np.any(same_word & mask[:, None, :], axis=2)
 
-        return mask
-
 
 @dataclass
 class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
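
To make the vectorized helpers above concrete, here is a small worked example. It is my own sketch, not part of the diff: it calls the two private static methods directly on `DataCollatorForLanguageModeling` and assumes an installation that already contains this change. The toy batch mimics the tokenization `[CLS] un ##believ ##able ! [SEP]`, where the three middle pieces share contiguous character offsets and the special tokens carry `(0, 0)` offsets.

import numpy as np
from transformers import DataCollatorForLanguageModeling as DC

offsets = np.array([[(0, 0), (0, 2), (2, 8), (8, 12), (13, 14), (0, 0)]])  # (batch=1, seq_len=6, 2)
special = np.array([[1, 0, 0, 0, 0, 1]])                                   # special_tokens_mask

word_ids, prob_mask = DC._calc_word_ids_and_prob_mask(offsets, special)
print(word_ids)   # [[-1  1  1  1  2 -1]]  -> the three pieces of "unbelievable" share word id 1
print(prob_mask)  # [[1 0 1 1 0 1]]        -> only the first piece of each word keeps a mask probability

# Suppose Bernoulli sampling masked just the first piece of "unbelievable":
sampled = np.array([[False, True, False, False, False, False]])
print(DC._whole_word_mask(word_ids, sampled))
# [[False  True  True  True False False]]  -> the mask is expanded to the whole word

This is also why the torch path above wraps the result in `torch.BoolTensor(...)`: the helper now returns a NumPy boolean array instead of mutating the torch mask in place.
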
@@ -1054,6 +1046,17 @@ def tolist(x) -> list[Any]:
     return x.tolist()
 
 
+def to_numpy(x) -> np.ndarray[Any]:
+    if isinstance(x, np.ndarray):
+        return x
+    elif hasattr(x, "numpy"):
+        return x.numpy()
+    elif hasattr(x, "detach"):
+        return x.detach().cpu().numpy()
+    else:
+        return np.array(x)
+
+
 @dataclass
 class DataCollatorForSOP(DataCollatorForLanguageModeling):
     """
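
Finally, a minimal usage sketch for the new module-level `to_numpy` helper (mine, not from the PR; it assumes the function above is in scope). It mirrors the existing `tolist` and normalizes whatever the collator receives into an ndarray:

import numpy as np
import torch

print(to_numpy(torch.tensor([[1, 0], [0, 1]])))  # CPU tensor: handled by the .numpy() branch
print(to_numpy(np.zeros((2, 2))))                # ndarray: returned unchanged
print(to_numpy([[1, 0], [0, 1]]))                # plain nested lists: fall back to np.array(...)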