Adding PrefixConstrainedLogitsProcessor #8529

Merged · 13 commits · Nov 18, 2020
30 changes: 29 additions & 1 deletion src/transformers/generation_logits_process.py
@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from abc import ABC
from typing import Iterable, List
from typing import Callable, Iterable, List

import numpy as np
import torch
@@ -372,3 +373,30 @@ def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_toke
)
scores = scores.masked_fill(banned_mask, -float("inf"))
return scores


class PrefixConstrainedLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` that enforces contrained generation and is useful for prefix-conditioned
constrained generation. See `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__ for more
information.

Args:
prefix_allowed_tokens_fn: (:obj:`Callable[[int, torch.Tensor], List[int]]`):
This function constraints the beam search to allowed tokens only at each step. This function takes 2
arguments :obj:`inputs_ids` and the batch ID :obj:`batch_id`. It has to return a list with the allowed
tokens for the next generation step conditioned on the previously generated tokens :obj:`inputs_ids` and
the batch ID :obj:`batch_id`.
"""

    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        mask = torch.full_like(scores, -math.inf)
        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
Review comment (Contributor):
In a future PR we could probably speed this up by using torch.Tensor operations instead of Python loops, since Python loops apparently slow down computation on GPU quite a bit (see: #6064). But that can be left for a future PR.
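One direction such a speed-up could take is to collect all (row, token) indices first and fill the mask with a single indexing call, so that only the Python callback itself remains in the loop. The standalone helper `build_prefix_mask` below is a rough, hypothetical sketch and not code from this PR:

```python
import math
from typing import Callable, List

import torch


def build_prefix_mask(
    scores: torch.FloatTensor,
    input_ids: torch.LongTensor,
    prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
    num_beams: int,
) -> torch.FloatTensor:
    """Same masking as PrefixConstrainedLogitsProcessor, but with one tensor write instead of one per beam."""
    mask = torch.full_like(scores, -math.inf)
    row_idx, col_idx = [], []
    for batch_id, beam_sent in enumerate(input_ids.view(-1, num_beams, input_ids.shape[-1])):
        for beam_id, sent in enumerate(beam_sent):
            allowed = prefix_allowed_tokens_fn(batch_id, sent)
            # Remember which (row, token) entries should stay unmasked.
            row_idx.extend([batch_id * num_beams + beam_id] * len(allowed))
            col_idx.extend(allowed)
    if row_idx:
        rows = torch.tensor(row_idx, dtype=torch.long, device=scores.device)
        cols = torch.tensor(col_idx, dtype=torch.long, device=scores.device)
        mask[rows, cols] = 0  # single indexing kernel instead of one per beam
    return scores + mask
```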

Reply from @nicola-decao (Contributor Author), Nov 16, 2020:
I wanted to keep the same signature as in fairseq, so that anyone who has already implemented such a function there can reuse it here.

            for beam_id, sent in enumerate(beam_sent):
                mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0

        return scores + mask
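As a usage sketch of the new argument, the callback can simply return a fixed allowed set at every step. The `gpt2` checkpoint, the prompt, and the forced string below are illustrative assumptions, not part of this PR:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Only ever allow the tokens of " Paris" (plus EOS so that beams can finish).
allowed_ids = tokenizer(" Paris", add_special_tokens=False).input_ids + [tokenizer.eos_token_id]

def prefix_allowed_tokens_fn(batch_id, input_ids):
    # Called for every beam at each step with the full sequence so far
    # (prompt + generated tokens); must return the token ids allowed next.
    return allowed_ids

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    max_length=input_ids.shape[-1] + 3,
    num_beams=4,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```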
17 changes: 16 additions & 1 deletion src/transformers/generation_utils.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch.nn import functional as F
@@ -26,6 +26,7 @@
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
@@ -258,6 +259,8 @@ def _get_logits_processor(
bad_words_ids: List[List[int]],
min_length: int,
eos_token_id: int,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
num_beams: int,
) -> LogitsProcessorList:
"""
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
@@ -285,6 +288,8 @@
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if prefix_allowed_tokens_fn is not None:
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
return processors

@torch.no_grad()
@@ -309,6 +314,7 @@ def generate(
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**model_kwargs
) -> torch.LongTensor:
r"""
@@ -375,6 +381,13 @@
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past key/value attentions (if applicable to the model) to
speed up decoding.
prefix_allowed_tokens_fn (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided, no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
:obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step,
conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`input_ids`. This
argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific
@@ -494,6 +507,8 @@ def generate(
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
)

if is_greedy_gen_mode:
12 changes: 11 additions & 1 deletion src/transformers/models/rag/modeling_rag.py
@@ -15,7 +15,7 @@
"""RAG model implementation."""

from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Callable, List, Optional, Tuple

import torch

@@ -1229,6 +1229,7 @@ def generate(
num_return_sequences=None,
decoder_start_token_id=None,
n_docs=None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
**model_kwargs
):
"""
@@ -1302,6 +1303,13 @@
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
prefix_allowed_tokens_fn (:obj:`Callable[[int, torch.Tensor], List[int]]`, `optional`):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided, no constraint is applied. This function takes 2 arguments: the batch ID :obj:`batch_id` and
:obj:`input_ids`. It has to return a list with the allowed tokens for the next generation step,
conditioned on the batch ID :obj:`batch_id` and the previously generated tokens :obj:`input_ids`. This
argument is useful for constrained generation conditioned on the prefix, as described in
`Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__.

Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@@ -1395,6 +1403,8 @@ def extend_enc_output(tensor, num_beams=None):
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
)

if num_beams == 1:
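The docstrings above describe the callback contract but leave its construction to the user. As a purely hypothetical sketch, not part of this PR, one way to build such a callback from a fixed list of candidate token sequences, in the spirit of the entity-constrained decoding of the linked paper (the helper name `candidates_to_prefix_fn` and its `prompt_len` handling are assumptions):

```python
from typing import List, Optional

import torch


def candidates_to_prefix_fn(candidate_token_ids: List[List[int]], prompt_len: int, eos_token_id: Optional[int] = None):
    """Build a prefix_allowed_tokens_fn that only allows continuations of the given candidates.

    candidate_token_ids: tokenized candidate strings (e.g. entity names), one list of ids each.
    prompt_len: number of tokens already present in input_ids before generation starts
                (the prompt length for decoder-only models, usually 1 for encoder-decoder models).
    """

    def prefix_allowed_tokens_fn(batch_id: int, input_ids: torch.Tensor) -> List[int]:
        generated = input_ids[prompt_len:].tolist()
        allowed = set()
        for cand in candidate_token_ids:
            # If what has been generated so far is a prefix of this candidate,
            # its next token (or EOS, once complete) becomes allowed.
            if cand[: len(generated)] == generated:
                if len(generated) < len(cand):
                    allowed.add(cand[len(generated)])
                elif eos_token_id is not None:
                    allowed.add(eos_token_id)
        return list(allowed)

    return prefix_allowed_tokens_fn
```

Note that returning an empty list masks every token with -inf, so a practical version would also want a fallback such as always allowing the EOS token.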
21 changes: 21 additions & 0 deletions tests/test_generation_logits_process.py
@@ -31,6 +31,7 @@
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
@@ -281,3 +282,23 @@ def test_processor_list(self):

# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())

    def test_prefix_constrained_logits_processor(self):
        vocab_size = 5
        batch_size = 2

        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
        scores = self._get_uniform_logits(batch_size, vocab_size)

        def prefix_allowed_tokens_fn(batch_id, input_ids):
            return [[0, 1], [2, 3]][batch_id]

        prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1)

        filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone())

        # batch 1: only the 1st and 2nd tokens (ids 0, 1) are allowed
        # batch 2: only the 3rd and 4th tokens (ids 2, 3) are allowed
        self.assertListEqual(
            torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
        )
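A possible follow-up check, sketched here as a standalone snippet rather than as part of the PR's test suite, could exercise the batch/beam flattening with num_beams > 1:

```python
import torch

from transformers.generation_logits_process import PrefixConstrainedLogitsProcessor

vocab_size, num_beams = 5, 2

# Two beams of a single batch element -> batch_size * num_beams = 2 rows.
input_ids = torch.tensor([[0, 1], [0, 2]], dtype=torch.long)
scores = torch.zeros(2, vocab_size)

def prefix_allowed_tokens_fn(batch_id, input_ids):
    return [0, 1]  # same allowed set for every beam of the (single) batch element

proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams)
filtered = proc(input_ids, scores.clone())

# Both beam rows keep tokens 0 and 1 and mask out 2, 3, 4.
assert torch.isinf(filtered).tolist() == [[False, False, True, True, True]] * 2
```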