Commit f1dd752 (parent: d3bf3c7)

Unit test whole word masking

File tree: 1 file changed, +113 −36 lines

tests/trainer/test_data_collator.py

Lines changed: 113 additions & 36 deletions
@@ -30,8 +30,8 @@
     DataCollatorWithFlattening,
     DataCollatorWithPadding,
     default_data_collator,
-    is_torch_available,
     is_tf_available,
+    is_torch_available,
     set_seed,
 )
 from transformers.testing_utils import require_torch
@@ -531,9 +531,7 @@ def test_data_collator_for_whole_word_mask(self):

         input_tokens = [f"token_{i}" for i in range(8)]
         tokenizer.add_tokens(input_tokens)
-        features = [
-            tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)
-        ]
+        features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]

         data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")

@@ -543,8 +541,7 @@ def test_data_collator_for_whole_word_mask(self):

         # Features can already be tensors
         features = [
-            tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("np")
-            for _ in range(2)
+            tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("np") for _ in range(2)
         ]
         batch = data_collator(features)
         self.assertEqual(batch["input_ids"].shape, (2, 10))
@@ -553,7 +550,7 @@ def test_data_collator_for_whole_word_mask(self):
         if is_torch_available():
             # Features can already be tensors
             features = [
-                tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("pt")
+                tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("pt")
                 for _ in range(2)
             ]
             data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
@@ -563,9 +560,10 @@ def test_data_collator_for_whole_word_mask(self):

         if is_tf_available():
             import tensorflow as tf
+
             # Features can already be tensors
             features = [
-                tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("tf")
+                tokenizer(" ".join(input_tokens), return_offsets_mapping=True).convert_to_tensors("tf")
                 for _ in range(2)
             ]
             data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
@@ -578,9 +576,7 @@ def test_data_collator_for_whole_word_mask_with_seed(self):

         input_tokens = [f"token_{i}" for i in range(998)]
         tokenizer.add_tokens(input_tokens)
-        features = [
-            tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)
-        ]
+        features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]

         # check if seed is respected between two different DataCollatorForWholeWordMask instances
         data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
@@ -598,9 +594,7 @@ def test_data_collator_for_whole_word_mask_with_seed(self):

         # check if seed is respected in multiple workers situation
         if is_torch_available():
-            features = [
-                tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(10)
-            ]
+            features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(10)]
             dataloader = torch.utils.data.DataLoader(
                 features,
                 batch_size=2,
@@ -975,17 +969,13 @@ def test_whole_world_masking_collator_immutability(self):

         input_tokens = [f"token_{i}" for i in range(8)]
         tokenizer.add_tokens(input_tokens)
-        original_data = [
-            tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)
-        ]
+        original_data = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]
         for feature in original_data:
-            feature['labels'] = (1,)
+            feature["labels"] = (1,)

-        batch_data = [
-            tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)
-        ]
+        batch_data = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]
         for feature in batch_data:
-            feature['labels'] = (1,)
+            feature["labels"] = (1,)

         whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer)

@@ -1450,9 +1440,7 @@ def test_data_collator_for_whole_word_mask(self):

         input_tokens = [f"token_{i}" for i in range(8)]
         tokenizer.add_tokens(input_tokens)
-        features = [
-            tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)
-        ]
+        features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]

         batch = data_collator(features)
         self.assertEqual(batch["input_ids"].shape, (2, 10))
@@ -1471,9 +1459,7 @@ def test_data_collator_for_whole_word_mask_with_seed(self):

         input_tokens = [f"token_{i}" for i in range(998)]
         tokenizer.add_tokens(input_tokens)
-        features = [
-            tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)
-        ]
+        features = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]

         # check if seed is respected between two different DataCollatorForWholeWordMask instances
         data_collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
@@ -1816,17 +1802,13 @@ def test_whole_world_masking_collator_immutability(self):

         input_tokens = [f"token_{i}" for i in range(8)]
         tokenizer.add_tokens(input_tokens)
-        original_data = [
-            tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)
-        ]
+        original_data = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]
         for feature in original_data:
-            feature['labels'] = (1,)
+            feature["labels"] = (1,)

-        batch_data = [
-            tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)
-        ]
+        batch_data = [tokenizer(" ".join(input_tokens), return_offsets_mapping=True) for _ in range(2)]
         for feature in batch_data:
-            feature['labels'] = (1,)
+            feature["labels"] = (1,)

         whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")

@@ -1902,3 +1884,98 @@ def test_sentence_order_prediction_collator_immutability(self):
         self._validate_original_data_against_collated_data(
             collator=sop_collator, original_data=features_original, batch_data=features_batch
         )
+
+
+class DataCollatorForLanguageModelingUnitTest(unittest.TestCase):
+    def test__calc_word_ids_and_prob_mask(self):
+        offsets = np.array(
+            [
+                [(0, 0), (0, 3), (3, 4), (5, 6), (6, 7), (8, 9)],
+                [(0, 0), (0, 3), (3, 4), (5, 6), (6, 7), (0, 0)],
+                [(0, 0), (0, 3), (3, 4), (0, 0), (6, 7), (0, 0)],
+                [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)],
+                [(1, 1), (2, 2), (3, 4), (5, 6), (7, 8), (9, 10)],
+                [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
+            ]
+        )
+
+        special_tokens_mask = np.array(
+            [
+                [1, 0, 0, 0, 0, 0],
+                [1, 0, 0, 0, 0, 1],
+                [1, 0, 0, 1, 0, 1],
+                [0, 0, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1],
+            ]
+        )
+
+        output_word_ids, output_prob_mask = DataCollatorForLanguageModeling._calc_word_ids_and_prob_mask(
+            offsets, special_tokens_mask
+        )
+
+        expected_word_ids = np.array(
+            [
+                [-1, 1, 1, 2, 2, 3],
+                [-1, 1, 1, 2, 2, -1],
+                [-1, 1, 1, -1, 2, -1],
+                [1, 1, 1, 1, 1, 1],
+                [1, 2, 3, 4, 5, 6],
+                [-1, -1, -1, -1, -1, -1],
+            ]
+        )
+
+        expected_prob_mask = np.array(
+            [
+                [1, 0, 1, 0, 1, 0],
+                [1, 0, 1, 0, 1, 1],
+                [1, 0, 1, 1, 0, 1],
+                [0, 1, 1, 1, 1, 1],
+                [0, 0, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1],
+            ]
+        )
+
+        assert np.array_equal(output_word_ids, expected_word_ids)
+        assert np.array_equal(output_prob_mask, expected_prob_mask)
+
+    def test__whole_word_mask(self):
+        word_ids = np.array(
+            [
+                [-1, 1, 1, 2, 2, 3],
+                [-1, 1, 1, 2, 2, -1],
+                [-1, 1, 1, -1, 2, -1],
+                [1, 1, 1, 1, 1, 1],
+                [1, 2, 3, 4, 5, 6],
+                [1, 2, 3, 4, 5, 6],
+                [-1, -1, -1, -1, -1, -1],
+            ]
+        )
+
+        mask = np.array(
+            [
+                [0, 1, 0, 0, 0, 0],
+                [0, 1, 0, 1, 0, 0],
+                [0, 0, 0, 0, 1, 0],
+                [1, 0, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0, 0],
+                [0, 1, 0, 1, 0, 1],
+                [0, 0, 0, 0, 0, 0],
+            ]
+        ).astype(bool)
+
+        output_mask = DataCollatorForLanguageModeling._whole_word_mask(word_ids, mask)
+
+        expected_mask = np.array(
+            [
+                [0, 1, 1, 0, 0, 0],
+                [0, 1, 1, 1, 1, 0],
+                [0, 0, 0, 0, 1, 0],
+                [1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 0, 0, 0],
+                [0, 1, 0, 1, 0, 1],
+                [0, 0, 0, 0, 0, 0],
+            ]
+        ).astype(bool)
+
+        np.testing.assert_array_equal(output_mask, expected_mask)
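
For context, a minimal sketch of the call pattern these tests exercise: features are built with return_offsets_mapping=True and passed to DataCollatorForWholeWordMask, which groups sub-word pieces into whole words before sampling the mask. The checkpoint name and sample sentence below are illustrative assumptions, not taken from the test file.

    # Hedged usage sketch; "bert-base-uncased" and the sample text are assumptions.
    from transformers import AutoTokenizer, DataCollatorForWholeWordMask

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Offset mappings let the collator regroup sub-word tokens into whole words
    # before deciding which words to mask, mirroring the tests above.
    features = [
        tokenizer("whole word masking groups sub-word tokens", return_offsets_mapping=True)
        for _ in range(2)
    ]

    collator = DataCollatorForWholeWordMask(tokenizer, seed=42, return_tensors="np")
    batch = collator(features)
    print(batch["input_ids"].shape)  # (batch_size, padded_sequence_length)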
