Commit cdadf0b

Vectorize whole word masking functions

1 parent 26c92b4, commit cdadf0b

1 file changed: +45 / -42 lines

src/transformers/data/data_collator.py
Lines changed: 45 additions & 42 deletions
@@ -21,8 +21,7 @@
 
 import numpy as np
 
-from ..models.bert import BertTokenizer, BertTokenizerFast
-from ..tokenization_utils_base import LARGE_INTEGER, PreTrainedTokenizerBase
+from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import PaddingStrategy
 
 
@@ -797,7 +796,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
         return batch
 
     def torch_mask_tokens(
-        self, inputs: Any, special_tokens_mask: Optional[Any] = None, offset_mapping: Optional[Any] = None
+        self, inputs: Any, special_tokens_mask: Optional[Any] = None, offset_mapping: Optional[Any] = None
     ) -> tuple[Any, Any]:
         """
         Prepare masked tokens inputs/labels for masked language modeling.
@@ -811,10 +810,10 @@ def torch_mask_tokens(
             special_tokens_mask = [
                 self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
             ]
-
+
         if self.whole_word_mask:
             word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
-                tolist(offset_mapping), tolist(special_tokens_mask)
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
             )
             no_mask_mask = torch.tensor(no_mask_mask, dtype=torch.bool)
         else:
@@ -827,7 +826,7 @@ def torch_mask_tokens(
         probability_matrix.masked_fill_(no_mask_mask, value=0.0)
         masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
         if self.whole_word_mask:
-            masked_indices = self._whole_word_mask(word_ids, masked_indices)
+            masked_indices = torch.BoolTensor(self._whole_word_mask(word_ids, masked_indices))
 
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
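
Because the vectorized `_whole_word_mask` now returns a NumPy boolean array, the torch path wraps the result back into a boolean tensor before it is used to index `labels`. A minimal standalone sketch of that round trip (illustrative only, not the collator code; the array values are made up):

    import numpy as np
    import torch

    # Pretend _whole_word_mask produced this batch of per-token masking decisions
    np_mask = np.array([[False, True, True, False]])

    # torch.from_numpy yields a torch.bool tensor usable for indexing;
    # the diff itself uses torch.BoolTensor(...), which copies into a new bool tensor
    masked_indices = torch.from_numpy(np_mask)

    labels = torch.tensor([[101, 2054, 2003, 102]])
    labels[~masked_indices] = -100  # loss is only computed on masked tokens
    print(labels)  # tensor([[-100, 2054, 2003, -100]])
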

@@ -906,9 +905,11 @@ def numpy_mask_tokens(
             special_tokens_mask = [
                 self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
             ]
-
+
         if self.whole_word_mask:
-            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(tolist(offset_mapping), tolist(special_tokens_mask))
+            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
+            )
             no_mask_mask = np.array(no_mask_mask, dtype=bool)
         else:
             no_mask_mask = (
@@ -970,61 +971,52 @@ def numpy_mask_tokens(
 
         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
         return inputs, labels
-
+
     @staticmethod
     def _calc_word_ids_and_prob_mask(
-        offsets: list[list[tuple[int, int]]], special_tokens_mask: list[list[int]]
-    ) -> tuple[list[list[int]], list[list[int]]]:
+        offsets: np.ndarray[np.ndarray[tuple[int, int]]], special_tokens_mask: np.ndarray[np.ndarray[int]]
+    ) -> tuple[np.ndarray[np.ndarray[int]], np.ndarray[np.ndarray[int]]]:
         """
         Map tokens to word ids and create mask of tokens to not mask.
         Tokens that are part of the same word will have the same word id and we will only
         set a mask probability for the first token of each word.
         """
 
-        batch_word_ids = []
-        batch_prob_mask = [[1] * len(o) for o in offsets]  # Initialize with 1s, meaning no tokens can be masked
+        token_starts = offsets[:, :, 0]
+        token_ends = offsets[:, :, 1]
 
-        for seq_idx, offset in enumerate(offsets):
-            word_ids = []
-            current_word_id = 0
-            prev_token_end = None
+        prev_token_ends = np.roll(token_ends, 1, axis=1)
+        prev_token_ends[:, 0] = -1  # First token has no previous token
 
-            for token_idx, (token_start, token_end) in enumerate(offset):
-                if special_tokens_mask[seq_idx][token_idx] == 1:
-                    word_ids.append(-1)
-                    prev_token_end = None
-                    continue
+        prev_token_special = np.roll(special_tokens_mask, 1, axis=1)
+        prev_token_special[:, 0] = 0
 
-                if (prev_token_end is None) or (prev_token_end != token_start):
-                    current_word_id += 1
-                    batch_prob_mask[seq_idx][token_idx] = 0  # This token can be masked
+        # Not special token AND (gap from previous or previous token was special)
+        special_tokens_mask = special_tokens_mask.astype(bool)
+        is_new_word = (~special_tokens_mask) & ((token_starts != prev_token_ends) | (prev_token_special == 1))
 
-                word_ids.append(current_word_id)
-                prev_token_end = token_end
+        word_ids = np.cumsum(is_new_word, axis=1)
+        word_ids[special_tokens_mask] = -1
 
-            batch_word_ids.append(word_ids)
+        prob_mask = (~is_new_word).astype(int)
 
-        return batch_word_ids, batch_prob_mask
+        return word_ids, prob_mask
 
     @staticmethod
-    def _whole_word_mask(word_ids: list[list[int]], mask: Any) -> Any:
+    def _whole_word_mask(word_ids: np.ndarray[np.ndarray[int]], mask: Any) -> Any:
         """
         Mask whole words based on word ids and mask.
         """
-        for seq_idx, (word_ids, mask_values) in enumerate(zip(word_ids, mask)):
-            for word_idx, id in enumerate(word_ids):
-                # Skip first word
-                if word_idx == 0:
-                    continue
+        mask = to_numpy(mask)
 
-                # If the current token is the same as the previous token's word id
-                # and the previous token is masked, then mask the current token too
-                if (id == word_ids[word_idx - 1]) and (mask_values[word_idx - 1]):
-                    # Previous token for same word is masked, so this one should be too
-                    mask[seq_idx][word_idx] = True
+        valid_ids = word_ids != -1
+
+        # Create 3D mask where [batch, token_i, token_j] is True if token_i and token_j are the same word
+        same_word = (word_ids[:, :, None] == word_ids[:, None, :]) & valid_ids[:, :, None] & valid_ids[:, None, :]
+
+        # For each token, set True if any token in the same word is masked
+        return np.any(same_word & mask[:, None, :], axis=2)
 
-        return mask
-
 
 @dataclass
 class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
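
For reference, here is a small self-contained walk-through of the vectorized logic above on a toy batch (a sketch reusing the same NumPy operations outside the class; the token strings, offsets, and the sampled mask are invented for illustration):

    import numpy as np

    # One sequence: [CLS] "un" "##related" "words" "here" [SEP]; special tokens get offset (0, 0)
    offsets = np.array([[(0, 0), (0, 2), (2, 9), (10, 15), (16, 20), (0, 0)]])
    special_tokens_mask = np.array([[1, 0, 0, 0, 0, 1]])

    token_starts, token_ends = offsets[:, :, 0], offsets[:, :, 1]
    prev_token_ends = np.roll(token_ends, 1, axis=1)
    prev_token_ends[:, 0] = -1
    prev_token_special = np.roll(special_tokens_mask, 1, axis=1)
    prev_token_special[:, 0] = 0

    special_bool = special_tokens_mask.astype(bool)
    is_new_word = (~special_bool) & ((token_starts != prev_token_ends) | (prev_token_special == 1))
    word_ids = np.cumsum(is_new_word, axis=1)
    word_ids[special_bool] = -1
    print(word_ids)      # [[-1  1  1  2  3 -1]] -> "un" and "##related" share word id 1
    print(~is_new_word)  # True where no masking probability is assigned (specials and non-first sub-tokens)

    # Suppose Bernoulli sampling masked only the first sub-token of "unrelated"
    mask = np.array([[False, True, False, False, False, False]])
    valid_ids = word_ids != -1
    same_word = (word_ids[:, :, None] == word_ids[:, None, :]) & valid_ids[:, :, None] & valid_ids[:, None, :]
    print(np.any(same_word & mask[:, None, :], axis=2))  # [[False  True  True False False False]]
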
@@ -1054,6 +1046,17 @@ def tolist(x) -> list[Any]:
     return x.tolist()
 
 
+def to_numpy(x) -> np.ndarray[Any]:
+    if isinstance(x, np.ndarray):
+        return x
+    elif hasattr(x, "numpy"):
+        return x.numpy()
+    elif hasattr(x, "detach"):
+        return x.detach().cpu().numpy()
+    else:
+        return np.array(x)
+
+
 @dataclass
 class DataCollatorForSOP(DataCollatorForLanguageModeling):
     """
