Skip to content

Commit

Permalink
Merge pull request #417 from alexander-zap/fix_transformation_masked_lm
Browse files (browse the repository at this point in the history)
Fixed a bug in the masked_lm transformations where only subwords were considered as candidates for top_words
  • Loading branch information
jinyongyoo authored Feb 15, 2021
2 parents ad2cf4a + 203dba9 commit eebf207
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,21 @@ def _get_new_words(self, current_text, indices_to_modify):
top_words = []
for _id in ranked_indices:
_id = _id.item()
token = self._lm_tokenizer.convert_ids_to_tokens(_id)
word = self._lm_tokenizer.convert_ids_to_tokens(_id)
if utils.check_if_subword(
token,
word,
self._language_model.config.model_type,
(masked_index == 1),
):
word = utils.strip_BPE_artifacts(
token, self._language_model.config.model_type
word, self._language_model.config.model_type
)
if (
mask_token_probs[_id] >= self.min_confidence
and utils.is_one_word(word)
and not utils.check_if_punctuations(word)
):
top_words.append(word)
if (
mask_token_probs[_id] >= self.min_confidence
and utils.is_one_word(word)
and not utils.check_if_punctuations(word)
):
top_words.append(word)

if (
len(top_words) >= self.max_candidates
Expand Down
18 changes: 9 additions & 9 deletions textattack/transformations/word_merges/word_merge_masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,21 +112,21 @@ def _get_merged_words(self, current_text, indices_to_modify):
top_words = []
for _id in ranked_indices:
_id = _id.item()
token = self._lm_tokenizer.convert_ids_to_tokens(_id)
word = self._lm_tokenizer.convert_ids_to_tokens(_id)
if utils.check_if_subword(
token,
word,
self._language_model.config.model_type,
(masked_index == 1),
):
word = utils.strip_BPE_artifacts(
token, self._language_model.config.model_type
word, self._language_model.config.model_type
)
if (
mask_token_probs[_id] >= self.min_confidence
and utils.is_one_word(word)
and not utils.check_if_punctuations(word)
):
top_words.append(word)
if (
mask_token_probs[_id] >= self.min_confidence
and utils.is_one_word(word)
and not utils.check_if_punctuations(word)
):
top_words.append(word)

if (
len(top_words) >= self.max_candidates
Expand Down
18 changes: 9 additions & 9 deletions textattack/transformations/word_swaps/word_swap_masked_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,21 +136,21 @@ def _bae_replacement_words(self, current_text, indices_to_modify):
top_words = []
for _id in ranked_indices:
_id = _id.item()
token = self._lm_tokenizer.convert_ids_to_tokens(_id)
word = self._lm_tokenizer.convert_ids_to_tokens(_id)
if utils.check_if_subword(
token,
word,
self._language_model.config.model_type,
(masked_index == 1),
):
word = utils.strip_BPE_artifacts(
token, self._language_model.config.model_type
word, self._language_model.config.model_type
)
if (
mask_token_probs[_id] >= self.min_confidence
and utils.is_one_word(word)
and not utils.check_if_punctuations(word)
):
top_words.append(word)
if (
mask_token_probs[_id] >= self.min_confidence
and utils.is_one_word(word)
and not utils.check_if_punctuations(word)
):
top_words.append(word)

if (
len(top_words) >= self.max_candidates
Expand Down

0 comments on commit eebf207

Please sign in to comment.