 
 import numpy as np
 
-from ..models.bert import BertTokenizer, BertTokenizerFast
-from ..tokenization_utils_base import LARGE_INTEGER, PreTrainedTokenizerBase
+from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import PaddingStrategy
 
 
@@ -932,10 +931,10 @@ def tf_bernoulli(shape, probability, generator=None):
         return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
 
     def tf_mask_tokens(
-        self,
-        inputs: Any,
-        vocab_size,
-        mask_token_id,
+        self,
+        inputs: Any,
+        vocab_size,
+        mask_token_id,
         special_tokens_mask: Optional[Any] = None,
         offset_mapping: Optional[Any] = None,
     ) -> tuple[Any, Any]:
@@ -948,7 +947,7 @@ def tf_mask_tokens(
 
         if self.whole_word_mask:
             word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
-                tolist(offset_mapping), tolist(special_tokens_mask)
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
             )
             no_mask_mask = tf.cast(tf.constant(no_mask_mask), tf.bool)
         else:
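
For orientation (not part of the diff), a minimal sketch of the inputs this whole-word-mask branch consumes: a fast tokenizer can return both the character offsets and the special-tokens mask that `_calc_word_ids_and_prob_mask` expects. The checkpoint name is only an example.

    # Sketch only: shows the offset_mapping / special_tokens_mask shapes this
    # branch expects; any fast tokenizer checkpoint would do.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    enc = tok(
        ["playing chess"],
        return_offsets_mapping=True,       # (batch, seq_len, 2) character spans
        return_special_tokens_mask=True,   # (batch, seq_len), 1 for [CLS]/[SEP]/pads
        return_tensors="np",
    )
    print(enc["offset_mapping"].shape, enc["special_tokens_mask"].shape)
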
@@ -1029,7 +1028,7 @@ def tf_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict
                 special_tokens_mask=special_tokens_mask,
                 mask_token_id=self.tokenizer.mask_token_id,
                 vocab_size=len(self.tokenizer),
-                offset_mapping=offset_mapping
+                offset_mapping=offset_mapping,
             )
         else:
             labels = batch["input_ids"]
@@ -1073,7 +1072,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
         return batch
 
     def torch_mask_tokens(
-        self, inputs: Any, special_tokens_mask: Optional[Any] = None, offset_mapping: Optional[Any] = None
+        self, inputs: Any, special_tokens_mask: Optional[Any] = None, offset_mapping: Optional[Any] = None
     ) -> tuple[Any, Any]:
         """
         Prepare masked tokens inputs/labels for masked language modeling.
@@ -1087,10 +1086,10 @@ def torch_mask_tokens(
             special_tokens_mask = [
                 self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
             ]
-
+
         if self.whole_word_mask:
             word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
-                tolist(offset_mapping), tolist(special_tokens_mask)
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
             )
             no_mask_mask = torch.tensor(no_mask_mask, dtype=torch.bool)
         else:
@@ -1103,7 +1102,7 @@ def torch_mask_tokens(
         probability_matrix.masked_fill_(no_mask_mask, value=0.0)
         masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
         if self.whole_word_mask:
-            masked_indices = self._whole_word_mask(word_ids, masked_indices)
+            masked_indices = torch.BoolTensor(self._whole_word_mask(word_ids, masked_indices))
 
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
 
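
A small aside (not part of the diff): with the vectorized `_whole_word_mask` returning a NumPy boolean array, the torch path wraps the result back into a bool tensor. The sketch below shows an equivalent, well-known conversion, assuming a plain CPU array.

    # Sketch only: converting a NumPy bool mask back to a torch bool tensor.
    import numpy as np
    import torch

    np_mask = np.array([[True, False, True]])
    torch_mask = torch.from_numpy(np_mask)  # dtype=torch.bool, shares memory
    assert torch_mask.dtype == torch.bool
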
@@ -1182,9 +1181,11 @@ def numpy_mask_tokens(
             special_tokens_mask = [
                 self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
             ]
-
+
         if self.whole_word_mask:
-            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(tolist(offset_mapping), tolist(special_tokens_mask))
+            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
+            )
             no_mask_mask = np.array(no_mask_mask, dtype=bool)
         else:
             no_mask_mask = (
@@ -1246,61 +1247,52 @@ def numpy_mask_tokens(
 
         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
         return inputs, labels
-
+
     @staticmethod
     def _calc_word_ids_and_prob_mask(
-        offsets: list[list[tuple[int, int]]], special_tokens_mask: list[list[int]]
-    ) -> tuple[list[list[int]], list[list[int]]]:
+        offsets: np.ndarray[np.ndarray[tuple[int, int]]], special_tokens_mask: np.ndarray[np.ndarray[int]]
+    ) -> tuple[np.ndarray[np.ndarray[int]], np.ndarray[np.ndarray[int]]]:
         """
         Map tokens to word ids and create mask of tokens to not mask.
         Tokens that are part of the same word will have the same word id and we will only
         set a mask probability for the first token of each word.
         """
 
-        batch_word_ids = []
-        batch_prob_mask = [[1] * len(o) for o in offsets]  # Initialize with 1s, meaning no tokens can be masked
+        token_starts = offsets[:, :, 0]
+        token_ends = offsets[:, :, 1]
 
-        for seq_idx, offset in enumerate(offsets):
-            word_ids = []
-            current_word_id = 0
-            prev_token_end = None
+        prev_token_ends = np.roll(token_ends, 1, axis=1)
+        prev_token_ends[:, 0] = -1  # First token has no previous token
 
-            for token_idx, (token_start, token_end) in enumerate(offset):
-                if special_tokens_mask[seq_idx][token_idx] == 1:
-                    word_ids.append(-1)
-                    prev_token_end = None
-                    continue
+        prev_token_special = np.roll(special_tokens_mask, 1, axis=1)
+        prev_token_special[:, 0] = 0
 
-                if (prev_token_end is None) or (prev_token_end != token_start):
-                    current_word_id += 1
-                    batch_prob_mask[seq_idx][token_idx] = 0  # This token can be masked
+        # Not special token AND (gap from previous or previous token was special)
+        special_tokens_mask = special_tokens_mask.astype(bool)
+        is_new_word = (~special_tokens_mask) & ((token_starts != prev_token_ends) | (prev_token_special == 1))
 
-                word_ids.append(current_word_id)
-                prev_token_end = token_end
+        word_ids = np.cumsum(is_new_word, axis=1)
+        word_ids[special_tokens_mask] = -1
 
-            batch_word_ids.append(word_ids)
+        prob_mask = (~is_new_word).astype(int)
 
-        return batch_word_ids, batch_prob_mask
+        return word_ids, prob_mask
 
     @staticmethod
-    def _whole_word_mask(word_ids: list[list[int]], mask: Any) -> Any:
+    def _whole_word_mask(word_ids: np.ndarray[np.ndarray[int]], mask: Any) -> Any:
         """
         Mask whole words based on word ids and mask.
         """
-        for seq_idx, (word_ids, mask_values) in enumerate(zip(word_ids, mask)):
-            for word_idx, id in enumerate(word_ids):
-                # Skip first word
-                if word_idx == 0:
-                    continue
+        mask = to_numpy(mask)
+
+        valid_ids = word_ids != -1
 
-                # If the current token is the same as the previous token's word id
-                # and the previous token is masked, then mask the current token too
-                if (id == word_ids[word_idx - 1]) and (mask_values[word_idx - 1]):
-                    # Previous token for same word is masked, so this one should be too
-                    mask[seq_idx][word_idx] = True
+        # Create 3D mask where [batch, token_i, token_j] is True if token_i and token_j are the same word
+        same_word = (word_ids[:, :, None] == word_ids[:, None, :]) & valid_ids[:, :, None] & valid_ids[:, None, :]
+
+        # For each token, set True if any token in the same word is masked
+        return np.any(same_word & mask[:, None, :], axis=2)
 
-        return mask
-
 
 @dataclass
 class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
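
To make the vectorized rewrite above concrete, here is an illustrative walkthrough (not part of the diff) that re-derives its logic on a toy batch; the token strings and offsets are invented for the example.

    # Toy batch: [CLS], "play", "##ing", "chess", [SEP]. Subwords of one word
    # have touching character offsets (previous end == next start).
    import numpy as np

    offsets = np.array([[(0, 0), (0, 4), (4, 7), (8, 13), (0, 0)]])  # (batch, seq, 2)
    special = np.array([[1, 0, 0, 0, 1]])                            # (batch, seq)

    token_starts, token_ends = offsets[:, :, 0], offsets[:, :, 1]
    prev_ends = np.roll(token_ends, 1, axis=1)
    prev_ends[:, 0] = -1
    prev_special = np.roll(special, 1, axis=1)
    prev_special[:, 0] = 0

    special_bool = special.astype(bool)
    is_new_word = (~special_bool) & ((token_starts != prev_ends) | (prev_special == 1))
    word_ids = np.cumsum(is_new_word, axis=1)
    word_ids[special_bool] = -1
    print(word_ids)  # [[-1  1  1  2 -1]] -> "play" and "##ing" share word id 1

    # Whole-word propagation: if any sub-token of a word is drawn, mask them all.
    masked = np.array([[False, True, False, False, False]])  # only "play" drawn
    valid = word_ids != -1
    same_word = (word_ids[:, :, None] == word_ids[:, None, :]) & valid[:, :, None] & valid[:, None, :]
    print(np.any(same_word & masked[:, None, :], axis=2))  # [[False  True  True False False]]

The printed word ids show why this matches the old per-token loop: only the first sub-token of each word keeps probability mass, and the final step spreads any drawn mask to every token sharing that word id.
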
@@ -1330,6 +1322,17 @@ def tolist(x) -> list[Any]:
     return x.tolist()
 
 
+def to_numpy(x) -> np.ndarray[Any]:
+    if isinstance(x, np.ndarray):
+        return x
+    elif hasattr(x, "numpy"):
+        return x.numpy()
+    elif hasattr(x, "detach"):
+        return x.detach().cpu().numpy()
+    else:
+        return np.array(x)
+
+
 @dataclass
 class DataCollatorForSOP(DataCollatorForLanguageModeling):
     """