Skip to content

Commit

Permalink
Revert "[Performance improvement] "Bad tokens ids" optimization (huggingface#6064)"
Browse files Browse the repository at this point in the history

This reverts commit 633dd89.
  • Loading branch information
fabiocapsouza authored Nov 15, 2020
1 parent 0583d15 commit 302fceb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 121 deletions.
90 changes: 0 additions & 90 deletions src/transformers/data/test_generation_utils.py

This file was deleted.

37 changes: 6 additions & 31 deletions src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, Optional, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -89,12 +89,11 @@ def postprocess_next_token_scores(
scores[i, banned_tokens] = -float("inf")

if bad_words_ids is not None:
# Exclude EOS token (already processed)
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
# Modify the scores in place by setting the banned tokens logits to `-inf`
set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

for i, banned_tokens in enumerate(banned_tokens):
scores[i, banned_tokens] = -float("inf")

return scores

Expand Down Expand Up @@ -894,7 +893,7 @@ def _tokens_match(prev_tokens, tokens):
bad_words_ids
)

if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue

Expand All @@ -905,30 +904,6 @@ def _tokens_match(prev_tokens, tokens):
return banned_tokens


def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
    """Modify ``scores`` in place by setting banned token positions to ``-inf``.

    Args:
        scores: logits distribution of shape (batch size, vocabulary size).
        banned_tokens: one list of banned token ids per batch item
            (length == batch size); inner lists may be empty.
    """
    # Direct advanced indexing replaces the previous sparse-mask construction:
    # the legacy `torch.sparse.LongTensor` constructor is deprecated, was fed
    # float values from `torch.ones` into a Long sparse tensor, and the mask
    # had to be built on CPU, moved to `scores.device`, densified, and cast to
    # bool before `masked_fill_`. Indexed assignment writes straight into
    # `scores` on whatever device it already lives on, with identical results
    # (duplicate token ids are harmless in both versions).
    for batch_idx, batch_banned_tokens in enumerate(banned_tokens):
        if batch_banned_tokens:
            scores[batch_idx, batch_banned_tokens] = -float("inf")


def top_k_top_p_filtering(
logits: Tensor,
top_k: int = 0,
Expand Down

0 comments on commit 302fceb

Please sign in to comment.