
Adding PrefixConstrainedLogitsProcessor #8529

Merged
merged 13 commits on Nov 18, 2020
Adding PrefixConstrainedLogitsProcessor
nicola-decao committed Nov 13, 2020

commit 6b3d5bbb33a412699fdde4525cfa3b40609af9e3
25 changes: 24 additions & 1 deletion src/transformers/generation_logits_process.py
@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from abc import ABC
from typing import Iterable, List
from typing import Callable, Iterable, List

import numpy as np
import torch
@@ -372,3 +373,25 @@ def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_toke
)
scores = scores.masked_fill(banned_mask, -float("inf"))
return scores


class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that constrains beam search to allowed tokens only, as given by
    :obj:`prefix_allowed_tokens_fn`.

    Args:
        prefix_allowed_tokens_fn (:obj:`Callable[[int, torch.Tensor], List[int]]`):
            Function that takes the batch ID and the tokens generated so far for one beam and returns the list of
            token ids allowed at the next generation step.
        num_beams (:obj:`int`):
            Number of beams used for beam search.
    """

    def __init__(self, prefix_allowed_tokens_fn: Callable, num_beams: int):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        mask = torch.full_like(scores, -math.inf)
        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
Review comment (Contributor):
In a future PR we could probably speed this up by just using torch.Tensor operations and not Python loops. Python loops really slow down the computation on GPU apparently (see: #6064). But we can do this in a future PR as well

Reply from @nicola-decao (Author), Nov 16, 2020:
I wanted to keep the same signature as in fairseq, so that anyone who has already implemented such a function can reuse it here.

            for beam_id, sent in enumerate(beam_sent):
                mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0

        return scores + mask
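
For reference, a minimal sketch (not part of the diff) of how this processor behaves when called directly. The vocabulary size, the tensors, and the toy prefix_allowed_tokens_fn below are hypothetical and only illustrate the expected signature, i.e. (batch_id, tokens generated so far for one beam) -> list of allowed token ids:

import torch
from transformers.generation_logits_process import PrefixConstrainedLogitsProcessor

num_beams, vocab_size = 2, 10  # hypothetical sizes, for illustration only

def prefix_allowed_tokens_fn(batch_id, sent):
    # toy constraint: only token ids 1, 2 and 3 are ever allowed
    return [1, 2, 3]

processor = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams)

# one batch element expanded into two beams, three tokens generated so far
input_ids = torch.tensor([[0, 4, 5], [0, 4, 6]])
scores = torch.zeros(2, vocab_size)

constrained = processor(input_ids, scores)
# each row keeps its original score at indices 1, 2, 3 and is -inf everywhere else
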
17 changes: 16 additions & 1 deletion src/transformers/generation_utils.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch.nn import functional as F
@@ -26,6 +26,7 @@
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
@@ -249,6 +250,8 @@ def _get_logits_processor(
bad_words_ids: List[List[int]],
min_length: int,
eos_token_id: int,
prefix_allowed_tokens_fn: Callable,
num_beams: int,
) -> LogitsProcessorList:
"""
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
@@ -276,6 +279,8 @@ def _get_logits_processor(
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if prefix_allowed_tokens_fn is not None:
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
return processors

@torch.no_grad()
@@ -300,6 +305,7 @@ def generate(
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
prefix_allowed_tokens_fn: Callable = None,
**model_kwargs
) -> torch.LongTensor:
r"""
@@ -366,6 +372,13 @@ def generate(
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`, defaults to :obj:`None`):
If provided, this function constrains Beam Search to allowed tokens only at each step. It is called once
per beam with two arguments: the batch ID and :obj:`input_ids`, the tensor of tokens generated so far
for that beam. It has to return the list of token ids allowed at the next generation step, conditioned
on the previously generated tokens and the batch ID. This argument is useful for constrained generation
conditioned on the prefix. If not provided, no constraint is applied.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
@@ -485,6 +498,8 @@ def generate(
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
)

if is_greedy_gen_mode:
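
To show how these pieces fit together, here is a hypothetical end-to-end call of generate() with the new argument. The checkpoint name ("t5-small") and the whitelist constraint are illustrative assumptions, not taken from this PR:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # hypothetical checkpoint choice
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# allow only the tokens of a small whitelist (plus EOS so generation can stop)
allowed_ids = tokenizer("Paris London Rome", add_special_tokens=False).input_ids + [tokenizer.eos_token_id]

def prefix_allowed_tokens_fn(batch_id, input_ids):
    # input_ids is the 1-D tensor of tokens generated so far for one beam
    return allowed_ids

inputs = tokenizer("translate English to French: The capital is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    num_beams=2,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))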