Commit d3bf3c7

Vectorize whole word masking functions
1 parent af9c275 commit d3bf3c7

File tree

1 file changed: +51 −48 lines changed

src/transformers/data/data_collator.py

Lines changed: 51 additions & 48 deletions
@@ -21,8 +21,7 @@
 
 import numpy as np
 
-from ..models.bert import BertTokenizer, BertTokenizerFast
-from ..tokenization_utils_base import LARGE_INTEGER, PreTrainedTokenizerBase
+from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import PaddingStrategy
 
 
@@ -932,10 +931,10 @@ def tf_bernoulli(shape, probability, generator=None):
         return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
 
     def tf_mask_tokens(
-        self,
-        inputs: Any,
-        vocab_size,
-        mask_token_id,
+        self,
+        inputs: Any,
+        vocab_size,
+        mask_token_id,
         special_tokens_mask: Optional[Any] = None,
         offset_mapping: Optional[Any] = None,
     ) -> tuple[Any, Any]:
@@ -948,7 +947,7 @@ def tf_mask_tokens(
 
         if self.whole_word_mask:
             word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
-                tolist(offset_mapping), tolist(special_tokens_mask)
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
             )
             no_mask_mask = tf.cast(tf.constant(no_mask_mask), tf.bool)
         else:
@@ -1029,7 +1028,7 @@ def tf_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict
                 special_tokens_mask=special_tokens_mask,
                 mask_token_id=self.tokenizer.mask_token_id,
                 vocab_size=len(self.tokenizer),
-                offset_mapping=offset_mapping
+                offset_mapping=offset_mapping,
             )
         else:
             labels = batch["input_ids"]
@@ -1073,7 +1072,7 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
         return batch
 
     def torch_mask_tokens(
-        self, inputs: Any, special_tokens_mask: Optional[Any] = None, offset_mapping: Optional[Any] = None
+        self, inputs: Any, special_tokens_mask: Optional[Any] = None, offset_mapping: Optional[Any] = None
     ) -> tuple[Any, Any]:
         """
         Prepare masked tokens inputs/labels for masked language modeling.
@@ -1087,10 +1086,10 @@ def torch_mask_tokens(
             special_tokens_mask = [
                 self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
             ]
-
+
         if self.whole_word_mask:
             word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
-                tolist(offset_mapping), tolist(special_tokens_mask)
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
             )
             no_mask_mask = torch.tensor(no_mask_mask, dtype=torch.bool)
         else:
@@ -1103,7 +1102,7 @@ def torch_mask_tokens(
         probability_matrix.masked_fill_(no_mask_mask, value=0.0)
         masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
         if self.whole_word_mask:
-            masked_indices = self._whole_word_mask(word_ids, masked_indices)
+            masked_indices = torch.BoolTensor(self._whole_word_mask(word_ids, masked_indices))
 
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
 
@@ -1182,9 +1181,11 @@ def numpy_mask_tokens(
             special_tokens_mask = [
                 self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
             ]
-
+
         if self.whole_word_mask:
-            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(tolist(offset_mapping), tolist(special_tokens_mask))
+            word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
+                to_numpy(offset_mapping), to_numpy(special_tokens_mask)
+            )
             no_mask_mask = np.array(no_mask_mask, dtype=bool)
         else:
             no_mask_mask = (
@@ -1246,61 +1247,52 @@ def numpy_mask_tokens(
 
         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
         return inputs, labels
-
+
     @staticmethod
     def _calc_word_ids_and_prob_mask(
-        offsets: list[list[tuple[int, int]]], special_tokens_mask: list[list[int]]
-    ) -> tuple[list[list[int]], list[list[int]]]:
+        offsets: np.ndarray[np.ndarray[tuple[int, int]]], special_tokens_mask: np.ndarray[np.ndarray[int]]
+    ) -> tuple[np.ndarray[np.ndarray[int]], np.ndarray[np.ndarray[int]]]:
         """
         Map tokens to word ids and create mask of tokens to not mask.
         Tokens that are part of the same word will have the same word id and we will only
         set a mask probability for the first token of each word.
        """
 
-        batch_word_ids = []
-        batch_prob_mask = [[1] * len(o) for o in offsets]  # Initialize with 1s, meaning no tokens can be masked
+        token_starts = offsets[:, :, 0]
+        token_ends = offsets[:, :, 1]
 
-        for seq_idx, offset in enumerate(offsets):
-            word_ids = []
-            current_word_id = 0
-            prev_token_end = None
+        prev_token_ends = np.roll(token_ends, 1, axis=1)
+        prev_token_ends[:, 0] = -1  # First token has no previous token
 
-            for token_idx, (token_start, token_end) in enumerate(offset):
-                if special_tokens_mask[seq_idx][token_idx] == 1:
-                    word_ids.append(-1)
-                    prev_token_end = None
-                    continue
+        prev_token_special = np.roll(special_tokens_mask, 1, axis=1)
+        prev_token_special[:, 0] = 0
 
-                if (prev_token_end is None) or (prev_token_end != token_start):
-                    current_word_id += 1
-                    batch_prob_mask[seq_idx][token_idx] = 0  # This token can be masked
+        # Not special token AND (gap from previous or previous token was special)
+        special_tokens_mask = special_tokens_mask.astype(bool)
+        is_new_word = (~special_tokens_mask) & ((token_starts != prev_token_ends) | (prev_token_special == 1))
 
-                word_ids.append(current_word_id)
-                prev_token_end = token_end
+        word_ids = np.cumsum(is_new_word, axis=1)
+        word_ids[special_tokens_mask] = -1
 
-            batch_word_ids.append(word_ids)
+        prob_mask = (~is_new_word).astype(int)
 
-        return batch_word_ids, batch_prob_mask
+        return word_ids, prob_mask
 
     @staticmethod
-    def _whole_word_mask(word_ids: list[list[int]], mask: Any) -> Any:
+    def _whole_word_mask(word_ids: np.ndarray[np.ndarray[int]], mask: Any) -> Any:
         """
         Mask whole words based on word ids and mask.
         """
-        for seq_idx, (word_ids, mask_values) in enumerate(zip(word_ids, mask)):
-            for word_idx, id in enumerate(word_ids):
-                # Skip first word
-                if word_idx == 0:
-                    continue
+        mask = to_numpy(mask)
+
+        valid_ids = word_ids != -1
 
-                # If the current token is the same as the previous token's word id
-                # and the previous token is masked, then mask the current token too
-                if (id == word_ids[word_idx - 1]) and (mask_values[word_idx - 1]):
-                    # Previous token for same word is masked, so this one should be too
-                    mask[seq_idx][word_idx] = True
+        # Create 3D mask where [batch, token_i, token_j] is True if token_i and token_j are the same word
+        same_word = (word_ids[:, :, None] == word_ids[:, None, :]) & valid_ids[:, :, None] & valid_ids[:, None, :]
+
+        # For each token, set True if any token in the same word is masked
+        return np.any(same_word & mask[:, None, :], axis=2)
 
-        return mask
-
 
 @dataclass
 class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
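
Read on their own, the new bodies of _calc_word_ids_and_prob_mask and _whole_word_mask are fairly dense. The standalone sketch below replays the same roll/cumsum and broadcasting steps on a made-up single-sequence batch; the word pieces, offsets, and printed values are illustrative only and are not part of the commit.

import numpy as np

# Toy batch of one sequence: [CLS] un ##want ##ed runs [SEP]
offsets = np.array([[(0, 0), (0, 2), (2, 6), (6, 8), (9, 13), (0, 0)]])
special_tokens_mask = np.array([[1, 0, 0, 0, 0, 1]])

token_starts = offsets[:, :, 0]
token_ends = offsets[:, :, 1]

# Shift the end offsets right by one so each token sees where its predecessor ended
prev_token_ends = np.roll(token_ends, 1, axis=1)
prev_token_ends[:, 0] = -1
prev_token_special = np.roll(special_tokens_mask, 1, axis=1)
prev_token_special[:, 0] = 0

special = special_tokens_mask.astype(bool)
# A token opens a new word if it is not special and either leaves a character
# gap to the previous token or directly follows a special token
is_new_word = (~special) & ((token_starts != prev_token_ends) | (prev_token_special == 1))

word_ids = np.cumsum(is_new_word, axis=1)
word_ids[special] = -1
prob_mask = (~is_new_word).astype(int)

print(word_ids)   # [[-1  1  1  1  2 -1]]  -> the three pieces of "unwanted" share word id 1
print(prob_mask)  # [[1 0 1 1 0 1]]        -> only the first piece of a word can be drawn

# Suppose the Bernoulli draw selected the word-initial piece "un"
mask = np.array([[False, True, False, False, False, False]])
valid_ids = word_ids != -1
same_word = (word_ids[:, :, None] == word_ids[:, None, :]) & valid_ids[:, :, None] & valid_ids[:, None, :]
print(np.any(same_word & mask[:, None, :], axis=2))
# [[False  True  True  True False False]]  -> the draw spreads to the whole word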
@@ -1330,6 +1322,17 @@ def tolist(x) -> list[Any]:
     return x.tolist()
 
 
+def to_numpy(x) -> np.ndarray[Any]:
+    if isinstance(x, np.ndarray):
+        return x
+    elif hasattr(x, "numpy"):
+        return x.numpy()
+    elif hasattr(x, "detach"):
+        return x.detach().cpu().numpy()
+    else:
+        return np.array(x)
+
+
 @dataclass
 class DataCollatorForSOP(DataCollatorForLanguageModeling):
     """

0 commit comments
