Skip to content

Commit

Permalink
Graduate MaskTransform from prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
joecummings committed Aug 17, 2022
1 parent 2fd12f3 commit 7e27a5b
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 264 deletions.
120 changes: 0 additions & 120 deletions test/prototype/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import shutil
import tempfile
from unittest.mock import patch

import torch
from test.common.assets import get_asset_path
Expand All @@ -10,12 +9,9 @@
sentencepiece_processor,
sentencepiece_tokenizer,
VectorTransform,
MaskTransform,
)
from torchtext.prototype.vectors import FastText

from ..common.parameterized_utils import nested_params


class TestTransforms(TorchtextTestCase):
def test_sentencepiece_processor(self) -> None:
Expand Down Expand Up @@ -140,119 +136,3 @@ def test_sentencepiece_load_and_save(self) -> None:
torch.save(spm, save_path)
loaded_spm = torch.load(save_path)
self.assertEqual(expected, loaded_spm(input))


class TestMaskTransform(TorchtextTestCase):
    """Exercise the prototype MaskTransform against a tiny fixed vocabulary.

    Assumed vocab (token -> id):
        ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6]
    The sample token sequences (stored as ids in ``sample_token_ids``) are:
        [["[BOS]", "a", "b", "c", "d"],
         ["[BOS]", "a", "b", "[PAD]", "[PAD]"]]
    """

    sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])

    vocab_len = 7
    pad_idx = 4
    mask_idx = 5
    bos_idx = 6

    @nested_params([0.0, 1.0])
    def test_mask_transform_probs(self, test_mask_prob):
        # The transform is constructed with (vocab_len - 1) so that the
        # distribution used for random replacement excludes the largest token
        # id, which this vocab assigns to [BOS]. Whenever the leading [BOS] of
        # the first sequence is picked for random replacement we therefore know
        # with certainty that the substitute id differs from [BOS]. In real
        # usage the full vocab length should be passed instead, so that random
        # replacement can draw from the whole vocab.
        transform = MaskTransform(
            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=False, mask_prob=test_mask_prob
        )

        if test_mask_prob == 0.0:
            # With mask_prob = 0 only the first token of the first sample
            # sequence is expected to be selected for replacement.

            # mask_mask_prob = rand_mask_prob = 0: nothing should change.
            with patch.multiple(
                "torchtext.prototype.transforms.MaskTransform", mask_mask_prob=0.0, rand_mask_prob=0.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                self.assertEqual(self.sample_token_ids, masked_tokens)

            # mask_mask_prob = 0, rand_mask_prob = 1: every selected token is
            # swapped for a random token id.
            with patch.multiple(
                "torchtext.prototype.transforms.MaskTransform", mask_mask_prob=0.0, rand_mask_prob=1.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)

                # The leading token of the first sequence must have changed ...
                self.assertNotEqual(masked_tokens[0, 0], self.sample_token_ids[0, 0])
                # ... to an in-vocab id other than [BOS].
                assert masked_tokens[0, 0] in range(self.vocab_len - 1)

                # Every other position stays untouched.
                self.assertEqual(self.sample_token_ids[0, 1:], masked_tokens[0, 1:])
                self.assertEqual(self.sample_token_ids[1], masked_tokens[1])

            # mask_mask_prob = 1, rand_mask_prob = 0: every selected token
            # becomes [MASK].
            with patch.multiple(
                "torchtext.prototype.transforms.MaskTransform", mask_mask_prob=1.0, rand_mask_prob=0.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                exp_tokens = torch.tensor([[5, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
                self.assertEqual(exp_tokens, masked_tokens)

        if test_mask_prob == 1.0:
            # With mask_prob = 1 (and mask_bos=False, the default used here)
            # every token that is neither [BOS] nor [PAD] is selected.

            # mask_mask_prob = rand_mask_prob = 0: nothing should change.
            with patch.multiple(
                "torchtext.prototype.transforms.MaskTransform", mask_mask_prob=0.0, rand_mask_prob=0.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                self.assertEqual(self.sample_token_ids, masked_tokens)

            # mask_mask_prob = 0, rand_mask_prob = 1: selected tokens become
            # random ids (possibly equal to the originals by chance), but the
            # [BOS] and [PAD] positions deterministically stay put.
            with patch.multiple(
                "torchtext.prototype.transforms.MaskTransform", mask_mask_prob=0.0, rand_mask_prob=1.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                self.assertEqual(masked_tokens[:, 0], torch.full_like(masked_tokens[:, 0], self.bos_idx))
                self.assertEqual(masked_tokens[1, 3:], torch.full_like(masked_tokens[1, 3:], self.pad_idx))

            # mask_mask_prob = 1, rand_mask_prob = 0: every selected token
            # becomes [MASK].
            with patch.multiple(
                "torchtext.prototype.transforms.MaskTransform", mask_mask_prob=1.0, rand_mask_prob=0.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                exp_tokens = torch.tensor([[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]])
                self.assertEqual(exp_tokens, masked_tokens)

    def test_mask_transform_mask_bos(self) -> None:
        # The boolean parameter mask_bos controls whether [BOS] tokens are
        # eligible for replacement. The tests above use the default
        # mask_bos=False; here we flip it to True.
        transform = MaskTransform(
            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=True, mask_prob=1.0
        )

        # mask_mask_prob = 1, rand_mask_prob = 0: all selected tokens --
        # now including [BOS] -- become [MASK].
        with patch.multiple(
            "torchtext.prototype.transforms.MaskTransform", mask_mask_prob=1.0, rand_mask_prob=0.0
        ):
            masked_tokens, _, _ = transform(self.sample_token_ids)
            exp_tokens = torch.tensor([[5, 5, 5, 5, 5], [5, 5, 5, 4, 4]])
            self.assertEqual(exp_tokens, masked_tokens)
119 changes: 118 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from collections import OrderedDict
from unittest.mock import patch

import torch
from torchtext import transforms
from torchtext.transforms import RegexTokenizer
from torchtext.transforms import MaskTransform, RegexTokenizer
from torchtext.vocab import vocab

from .common.assets import get_asset_path
Expand Down Expand Up @@ -750,3 +751,119 @@ def test_regex_tokenizer_save_load(self) -> None:
loaded_tokenizer = torch.jit.load(save_path)
results = loaded_tokenizer(self.test_sample)
self.assertEqual(results, self.ref_results)


class TestMaskTransform(TorchtextTestCase):
    """Exercise MaskTransform against a tiny fixed vocabulary.

    Assumed vocab (token -> id):
        ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6]
    The sample token sequences (stored as ids in ``sample_token_ids``) are:
        [["[BOS]", "a", "b", "c", "d"],
         ["[BOS]", "a", "b", "[PAD]", "[PAD]"]]
    """

    sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])

    vocab_len = 7
    pad_idx = 4
    mask_idx = 5
    bos_idx = 6

    @nested_params([0.0, 1.0])
    def test_mask_transform_probs(self, test_mask_prob):
        # The transform is constructed with (vocab_len - 1) so that the
        # distribution used for random replacement excludes the largest token
        # id, which this vocab assigns to [BOS]. Whenever the leading [BOS] of
        # the first sequence is picked for random replacement we therefore know
        # with certainty that the substitute id differs from [BOS]. In real
        # usage the full vocab length should be passed instead, so that random
        # replacement can draw from the whole vocab.
        transform = MaskTransform(
            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=False, mask_prob=test_mask_prob
        )

        if test_mask_prob == 0.0:
            # With mask_prob = 0 only the first token of the first sample
            # sequence is expected to be selected for replacement.

            # mask_mask_prob = rand_mask_prob = 0: nothing should change.
            with patch.multiple(
                "torchtext.transforms.MaskTransform", mask_mask_prob=0.0, rand_mask_prob=0.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                self.assertEqual(self.sample_token_ids, masked_tokens)

            # mask_mask_prob = 0, rand_mask_prob = 1: every selected token is
            # swapped for a random token id.
            with patch.multiple(
                "torchtext.transforms.MaskTransform", mask_mask_prob=0.0, rand_mask_prob=1.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)

                # The leading token of the first sequence must have changed ...
                self.assertNotEqual(masked_tokens[0, 0], self.sample_token_ids[0, 0])
                # ... to an in-vocab id other than [BOS].
                assert masked_tokens[0, 0] in range(self.vocab_len - 1)

                # Every other position stays untouched.
                self.assertEqual(self.sample_token_ids[0, 1:], masked_tokens[0, 1:])
                self.assertEqual(self.sample_token_ids[1], masked_tokens[1])

            # mask_mask_prob = 1, rand_mask_prob = 0: every selected token
            # becomes [MASK].
            with patch.multiple(
                "torchtext.transforms.MaskTransform", mask_mask_prob=1.0, rand_mask_prob=0.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                exp_tokens = torch.tensor([[5, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
                self.assertEqual(exp_tokens, masked_tokens)

        if test_mask_prob == 1.0:
            # With mask_prob = 1 (and mask_bos=False, the default used here)
            # every token that is neither [BOS] nor [PAD] is selected.

            # mask_mask_prob = rand_mask_prob = 0: nothing should change.
            with patch.multiple(
                "torchtext.transforms.MaskTransform", mask_mask_prob=0.0, rand_mask_prob=0.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                self.assertEqual(self.sample_token_ids, masked_tokens)

            # mask_mask_prob = 0, rand_mask_prob = 1: selected tokens become
            # random ids (possibly equal to the originals by chance), but the
            # [BOS] and [PAD] positions deterministically stay put.
            with patch.multiple(
                "torchtext.transforms.MaskTransform", mask_mask_prob=0.0, rand_mask_prob=1.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                self.assertEqual(masked_tokens[:, 0], torch.full_like(masked_tokens[:, 0], self.bos_idx))
                self.assertEqual(masked_tokens[1, 3:], torch.full_like(masked_tokens[1, 3:], self.pad_idx))

            # mask_mask_prob = 1, rand_mask_prob = 0: every selected token
            # becomes [MASK].
            with patch.multiple(
                "torchtext.transforms.MaskTransform", mask_mask_prob=1.0, rand_mask_prob=0.0
            ):
                masked_tokens, _, _ = transform(self.sample_token_ids)
                exp_tokens = torch.tensor([[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]])
                self.assertEqual(exp_tokens, masked_tokens)

    def test_mask_transform_mask_bos(self) -> None:
        # The boolean parameter mask_bos controls whether [BOS] tokens are
        # eligible for replacement. The tests above use the default
        # mask_bos=False; here we flip it to True.
        transform = MaskTransform(
            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=True, mask_prob=1.0
        )

        # mask_mask_prob = 1, rand_mask_prob = 0: all selected tokens --
        # now including [BOS] -- become [MASK].
        with patch.multiple(
            "torchtext.transforms.MaskTransform", mask_mask_prob=1.0, rand_mask_prob=0.0
        ):
            masked_tokens, _, _ = transform(self.sample_token_ids)
            exp_tokens = torch.tensor([[5, 5, 5, 5, 5], [5, 5, 5, 4, 4]])
            self.assertEqual(exp_tokens, masked_tokens)
Loading

0 comments on commit 7e27a5b

Please sign in to comment.