Skip to content

Commit 49a2cc0

Browse files
rjgleaton authored and Rocketknight1 committed
Remove support for TF in whole word masking
1 parent 9f66754 commit 49a2cc0

File tree

2 files changed

+1
-18
lines changed

2 files changed

+1
-18
lines changed

src/transformers/data/data_collator.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -910,7 +910,6 @@ def numpy_mask_tokens(
910910
word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
911911
to_numpy(offset_mapping), to_numpy(special_tokens_mask)
912912
)
913-
no_mask_mask = np.array(no_mask_mask, dtype=bool)
914913
else:
915914
no_mask_mask = (
916915
special_tokens_mask.astype(bool)
@@ -998,7 +997,7 @@ def _calc_word_ids_and_prob_mask(
998997
word_ids = np.cumsum(is_new_word, axis=1)
999998
word_ids[special_tokens_mask] = -1
1000999

1001-
prob_mask = (~is_new_word).astype(int)
1000+
prob_mask = ~is_new_word
10021001

10031002
return word_ids, prob_mask
10041003

@@ -1049,8 +1048,6 @@ def tolist(x) -> list[Any]:
10491048
def to_numpy(x) -> np.ndarray[Any]:
10501049
if isinstance(x, np.ndarray):
10511050
return x
1052-
elif hasattr(x, "numpy"):
1053-
return x.numpy()
10541051
elif hasattr(x, "detach"):
10551052
return x.detach().cpu().numpy()
10561053
else:

tests/trainer/test_data_collator.py

Lines changed: 0 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@
3030
DataCollatorWithFlattening,
3131
DataCollatorWithPadding,
3232
default_data_collator,
33-
is_tf_available,
3433
is_torch_available,
3534
set_seed,
3635
)
@@ -558,19 +557,6 @@ def test_data_collator_for_whole_word_mask(self):
558557
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
559558
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
560559

561-
if is_tf_available():
562-
import tensorflow as tf
563-
564-
# Features can already be tensors
565-
features = [
566-
tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("tf")
567-
for _ in range(2)
568-
]
569-
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
570-
batch = data_collator(features)
571-
self.assertEqual(batch["input_ids"].shape, tf.TensorShape((2, 10)))
572-
self.assertEqual(batch["labels"].shape, tf.TensorShape((2, 10)))
573-
574560
def test_data_collator_for_whole_word_mask_with_seed(self):
575561
tokenizer = BertTokenizerFast(self.vocab_file)
576562

0 commit comments

Comments
 (0)