Adding FlaxNoRepeatNGramLogitsProcessor #29677

Merged
merged 14 commits on Apr 2, 2024

Changes from 12 commits
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
@@ -161,6 +161,7 @@
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
"FlaxWhisperTimeStampLogitsProcessor",
"FlaxNoRepeatNGramLogitsProcessor",
]
_import_structure["flax_utils"] = [
"FlaxGenerationMixin",
@@ -292,6 +293,7 @@
        FlaxLogitsProcessorList,
        FlaxLogitsWarper,
        FlaxMinLengthLogitsProcessor,
        FlaxNoRepeatNGramLogitsProcessor,
        FlaxSuppressTokensAtBeginLogitsProcessor,
        FlaxSuppressTokensLogitsProcessor,
        FlaxTemperatureLogitsWarper,
85 changes: 85 additions & 0 deletions src/transformers/generation/flax_logits_process.py
@@ -18,6 +18,7 @@
import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.experimental import sparse

from ..utils import add_start_docstrings
from ..utils.logging import get_logger
@@ -455,3 +456,87 @@ def handle_cumulative_probs(logprobs_k, scores_k):
        scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)

        return scores


class FlaxNoRepeatNGramLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def get_previous_ngrams(self, input_ids: jnp.ndarray, vocab_size: int, cur_len: int):
        """
        Get a matrix of size (batch_size,) + (vocab_size,)*n (for n-grams) that represents the n-grams that
        occurred previously. The BCOO representation allows storing only the few non-zero entries, instead of
        the full (huge) matrix.
        """
        batch_size, seq_len = input_ids.shape

        def body_fun(i, val):
            b = i % batch_size
            pos = i // batch_size
            return val.at[i].set(
                jnp.array(
                    [
                        b,
                    ]
                    + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
                )
            )
Comment on lines +491 to +498
Collaborator

nit - can all be one line

Suggested change
            return val.at[i].set(
                jnp.array(
                    [
                        b,
                    ]
                    + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
                )
            )
            return val.at[i].set(jnp.array([b] + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]))

Contributor Author

I think this was formatted like this by the black / ruff formatter. But if that passes the tests, I agree that it is clearer this way.


        shape = (batch_size * (seq_len - (self.ngram_size - 1)), self.ngram_size + 1)
        all_update_indices = jax.lax.fori_loop(
            0, batch_size * (cur_len - (self.ngram_size - 1)), body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
        )

        # ignore the n-grams not yet generated
        data = (
            jnp.arange(batch_size * (seq_len - (self.ngram_size - 1))) < batch_size * (cur_len - (self.ngram_size - 1))
        ).astype("float32")

        return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)
Collaborator

Rather than calculating (seq_len - (self.ngram_size - 1)) and (cur_len - (self.ngram_size - 1)) several times, it'll be easier to read and follow this code if they're set to variables and then used.
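
A possible shape for that refactor, as a sketch only (the hoisted variable names are illustrative, not part of the PR):

        # hoist the repeated expressions into named variables (illustrative names)
        seen_ngram_slots = seq_len - (self.ngram_size - 1)  # n-gram start positions in the padded sequence
        generated_ngram_slots = cur_len - (self.ngram_size - 1)  # n-gram start positions generated so far

        shape = (batch_size * seen_ngram_slots, self.ngram_size + 1)
        all_update_indices = jax.lax.fori_loop(
            0, batch_size * generated_ngram_slots, body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
        )

        # ignore the n-grams not yet generated
        data = (jnp.arange(batch_size * seen_ngram_slots) < batch_size * generated_ngram_slots).astype("float32")

        return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)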


    def get_banned_tokens_mask(self, latest_tokens: jnp.ndarray, previous_ngrams) -> jnp.ndarray:
        """
        Determines which tokens must be banned given the latest tokens and the previously seen n-grams.
        """

        @sparse.sparsify
        @jax.vmap
        def inner_fn(latest_tokens, previous_ngrams):
            return previous_ngrams[tuple(latest_tokens)]

        return sparse.bcoo_todense(inner_fn(latest_tokens, previous_ngrams))

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        def true_fn():
            _, vocab_size = scores.shape
            # store the previously seen n-grams
            previous_ngrams = self.get_previous_ngrams(input_ids, vocab_size, cur_len)

            # get the n-1 last tokens that prefix the n-gram being generated
            latest_tokens = jnp.zeros((input_ids.shape[0], self.ngram_size - 1), dtype=input_ids.dtype)
            latest_tokens = jax.lax.dynamic_update_slice(
                latest_tokens,
                jax.lax.dynamic_slice(
                    input_ids, (0, cur_len - (self.ngram_size - 1)), (input_ids.shape[0], (self.ngram_size - 1))
                ),
                (0, 0),
            )

            # compute the banned tokens, i.e. all tokens that, when appended to the latest tokens, form an n-gram that was previously generated
            banned_tokens_indices_mask = self.get_banned_tokens_mask(latest_tokens, previous_ngrams).astype("bool")
            return jnp.where(banned_tokens_indices_mask, -float("inf"), scores)

        output = jax.lax.cond((cur_len >= self.ngram_size - 1), true_fn, lambda: scores)
        return output
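
For reference, a minimal sketch of calling the new processor directly (not part of the diff; it assumes a transformers build that includes this PR and mirrors the toy setup used in the tests below):

import numpy as np
import jax.numpy as jnp

from transformers.generation import FlaxNoRepeatNGramLogitsProcessor

# one sequence over a 3-token vocabulary; the bigrams (1, 1), (1, 2) and (2, 1) have been seen
input_ids = np.array([[1, 1, 2, 1]], dtype="i4")
scores = jnp.zeros((1, 3))  # uniform logits
processor = FlaxNoRepeatNGramLogitsProcessor(2)

filtered = processor(input_ids, scores, cur_len=4)
# the last token is 1, so continuing with 1 or 2 would repeat a seen bigram and is set to -inf
print(jnp.isinf(filtered))  # [[False  True  True]]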
3 changes: 3 additions & 0 deletions src/transformers/generation/flax_utils.py
@@ -40,6 +40,7 @@
    FlaxForceTokensLogitsProcessor,
    FlaxLogitsProcessorList,
    FlaxMinLengthLogitsProcessor,
    FlaxNoRepeatNGramLogitsProcessor,
    FlaxSuppressTokensAtBeginLogitsProcessor,
    FlaxSuppressTokensLogitsProcessor,
    FlaxTemperatureLogitsWarper,
@@ -534,6 +535,8 @@ def _get_logits_processor(
                [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
            ]
            processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
        if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
            processors.append(FlaxNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
        processors = self._merge_criteria_processor_list(processors, logits_processor)

        return processors
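
With this hook in place, passing `no_repeat_ngram_size` to a Flax `generate()` call is enough to enable the new processor. A hypothetical end-to-end example (checkpoint and prompt are placeholders; assumes Flax weights are available and a transformers build that includes this PR):

from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("summarize: The quick brown fox jumps over the lazy dog.", return_tensors="np")
# no_repeat_ngram_size > 0 now adds FlaxNoRepeatNGramLogitsProcessor to the Flax processor list
outputs = model.generate(**inputs, max_length=32, no_repeat_ngram_size=2)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))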
45 changes: 43 additions & 2 deletions tests/generation/test_flax_logits_process.py
@@ -33,6 +33,7 @@
        FlaxForcedEOSTokenLogitsProcessor,
        FlaxLogitsProcessorList,
        FlaxMinLengthLogitsProcessor,
        FlaxNoRepeatNGramLogitsProcessor,
        FlaxTemperatureLogitsWarper,
        FlaxTopKLogitsWarper,
        FlaxTopPLogitsWarper,
@@ -197,6 +198,26 @@ def test_forced_eos_token_logits_processor(self):
        scores = logits_processor(input_ids, scores, cur_len=cur_len)
        self.assertFalse(jnp.isinf(scores).any())

    def test_no_repeat_ngram_dist_processor(self):
        vocab_size = 3
        batch_size = 2

        cur_len = 4
        input_ids = np.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype="i4")
        scores = self._get_uniform_logits(batch_size, vocab_size)

        no_repeat_proc_2_gram = FlaxNoRepeatNGramLogitsProcessor(2)
        no_repeat_proc_3_gram = FlaxNoRepeatNGramLogitsProcessor(3)

        filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores, cur_len=cur_len)
        filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores, cur_len=cur_len)

        # 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
        self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])

        # 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
        self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])
Comment on lines +215 to +219
Collaborator

I realise these are copied from other parts of the library, but the structure here is really quite confusing.

I'm assuming that:

  • By batch we mean sample in a minibatch i.e. "1st batch" is [1, 1, 2, 1]
  • By tokens we mean token ids
  • The output values e.g. [False, True, True] are across the vocab
  • When the token ids are referred to e.g. "2nd and 3rd token at 1st batch" what we're referring to are the token ids in the vocabulary [0, 1, 2] and NOT the 2nd, 3rd token ids in [1, 1, 2, 1]. The comment makes this really confusing by 1) having the same positional values for the sample in the batch as in the vocab 2) saying "at 1st batch". I'd strongly recommend rewriting this to remove this ambiguity.

Contributor Author

The tests were copied with minimal modifications from the PyTorch ones, so I didn't change these comments. I agree that the descriptions could be clearer.

I think more comprehensive tests could also be a good idea. For example, I didn't see at first that my first code iteration didn't work in the case where an n-gram appears > 1 time, which is not a case that is tested, so my code passed the tests.

Collaborator

Good point. As these are so close to the PT & TF tests, let's leave them as-is for now. If you have a test in mind and are willing to open a follow-up PR to add it, I'd be very happy to review :)
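
For reference, one possible shape for such a follow-up test, covering an n-gram that occurs more than once (a sketch only, not part of this PR):

    def test_no_repeat_ngram_repeated_ngram(self):
        # hypothetical follow-up test: the banned 2-gram (1, 2) occurs twice in the prefix
        vocab_size = 3
        batch_size = 1

        cur_len = 6
        input_ids = np.array([[1, 2, 0, 1, 2, 1]], dtype="i4")
        scores = self._get_uniform_logits(batch_size, vocab_size)

        no_repeat_proc_2_gram = FlaxNoRepeatNGramLogitsProcessor(2)
        filtered_scores = no_repeat_proc_2_gram(input_ids, scores, cur_len=cur_len)

        # the last token is 1 and the only seen bigram starting with 1 is (1, 2), seen twice,
        # so only vocab id 2 should be banned
        self.assertListEqual(jnp.isinf(filtered_scores).tolist(), [[False, False, True]])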


    def test_processor_list(self):
        batch_size = 4
        sequence_length = 10
@@ -216,6 +237,7 @@ def test_processor_list(self):
        temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
        top_k_warp = FlaxTopKLogitsWarper(3)
        top_p_warp = FlaxTopPLogitsWarper(0.8)
        no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)

        # instantiate all logits processors
        min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
@@ -231,10 +253,19 @@
        scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
        scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
        scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
        scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)

        # with processor list
        processor = FlaxLogitsProcessorList(
            [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
            [
                temp_dist_warp,
                top_k_warp,
                top_p_warp,
                min_dist_proc,
                bos_dist_proc,
                eos_dist_proc,
                no_repeat_proc,
            ]
        )
        scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)

@@ -263,6 +294,7 @@ def test_processor_list_jitted(self):
        temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
        top_k_warp = FlaxTopKLogitsWarper(3)
        top_p_warp = FlaxTopPLogitsWarper(0.8)
        no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)

        # instantiate all logits processors
        min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
@@ -279,12 +311,21 @@ def run_no_processor_list(input_ids, scores, cur_len):
            scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
            scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
            scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
            scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
            return scores

        # with processor list
        def run_processor_list(input_ids, scores, cur_len):
            processor = FlaxLogitsProcessorList(
                [temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
                [
                    temp_dist_warp,
                    top_k_warp,
                    top_p_warp,
                    min_dist_proc,
                    bos_dist_proc,
                    eos_dist_proc,
                    no_repeat_proc,
                ]
            )
            scores = processor(input_ids, scores, cur_len=cur_len)
            return scores