Adding FlaxNoRepeatNGramLogitsProcessor #29677
Conversation
Thank you for opening the PR 🔥 In general looks good to me, I've added a few questions before approving.
And thank you for keeping the exact same tests as in our PT counterpart, it makes maintenance much simpler 🙌
data = jnp.ones((all_update_indices.shape[0],), dtype=jnp.uint16)
data = data * (jnp.arange(data.shape[0]) < batch_size * (cur_len - (self.ngram_size - 1)))  # ignore the n-grams not yet generated
Perhaps we could slice `input_ids` before creating `all_update_indices`, i.e. `input_ids = input_ids[:, :cur_len]`, and save some time/memory when creating `all_update_indices`. Or is the result slower, because `cur_len` changes each iteration?
From my experience with JAX, in order for the code to work with jit, you cannot use arrays whose shapes are not fixed. This is why I opted to pad the indices to a known size. There may be more efficient ways to do it, but I found that this works, as opposed to `[]` slices, or even dynamic slices, because `cur_len` changes.
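A minimal sketch of the constraint being described (illustrative only, not the PR's code): under jit, a boolean mask over the full, statically shaped array works, while slicing with a traced `cur_len` does not.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_sum(input_ids, cur_len):
    # Works under jit: the array keeps its static (batch, seq_len) shape and
    # positions at or beyond `cur_len` are zeroed out with a boolean mask.
    mask = jnp.arange(input_ids.shape[1]) < cur_len
    return (input_ids * mask).sum()

@jax.jit
def sliced_sum(input_ids, cur_len):
    # Fails under jit: `cur_len` is traced, so the slice has no static stop value.
    return input_ids[:, :cur_len].sum()

ids = jnp.array([[1, 1, 2, 1], [0, 1, 0, 1]])
print(masked_sum(ids, 3))   # fine
# sliced_sum(ids, 3)        # raises IndexError: slice indices must be static
```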
I see :) I made the suggestion based on my past XLA+TF experience -- slicing `input_ids = input_ids[:, :cur_len]` is allowed there (example). But then again, JAX is usually stricter. Let's keep as you suggested 🤗
I think it also works in JAX in some situations, but in this case the error when using jit is quite explicit:

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Anyway, I took some time to benchmark different ways of doing this kind of operation and found that, in this instance, using `jax.lax.fori_loop` to update `all_update_indices` is significantly faster than using dynamic slice updates (with this plus the removal of some useless operations, I measured a >10x speedup for the jitted function). There might still be room for improvement.
The red CI can be fixed by running
Thank you for iterating 💛
Thanks for adding!
I have just one request: to update a test. I realise this is inherited, but it really should be addressed, as the test is very confusing.
return val.at[i].set(
    jnp.array(
        [
            b,
        ]
        + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
    )
)
nit - can all be one line
Suggested change:
return val.at[i].set(jnp.array([b] + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]))
I think this was formatted like this by the black/ruff formatter. But if the one-liner passes the checks, I agree that it is clearer this way.
shape = (batch_size * (seq_len - (self.ngram_size - 1)), self.ngram_size + 1)
all_update_indices = jax.lax.fori_loop(
    0, batch_size * (cur_len - (self.ngram_size - 1)), body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
)

# ignore the n-grams not yet generated
data = (
    jnp.arange(batch_size * (seq_len - (self.ngram_size - 1))) < batch_size * (cur_len - (self.ngram_size - 1))
).astype("float32")

return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)
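For context, here is a toy illustration of `jax.experimental.sparse.BCOO` (not the PR's data): the `(data, all_update_indices)` pair builds a sparse count array indexed by `(batch, token_1, ..., token_n)`.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Two bigrams observed in batch 0 over a 3-token vocabulary: index rows are
# (batch, first_token, second_token) and each data entry contributes a count.
indices = jnp.array([[0, 1, 1], [0, 1, 2]])
data = jnp.array([1.0, 1.0])
ngram_counts = sparse.BCOO((data, indices), shape=(1, 3, 3))
print(ngram_counts.todense()[0, 1])  # [0. 1. 1.] -> after token 1, tokens 1 and 2 were seen
```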
Rather than calculate `seq_len - (self.ngram_size - 1)` and `cur_len - (self.ngram_size - 1)` several times, it'll be easier to read and follow this code if they're set to variables and then reused.
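A possible shape of that refactor (the variable names are only suggestions, not the PR's):

```python
# Hoist the repeated expressions into named variables (names are illustrative).
ngrams_per_seq = seq_len - (self.ngram_size - 1)     # n-grams per padded sequence
generated_ngrams = cur_len - (self.ngram_size - 1)   # n-grams generated so far

shape = (batch_size * ngrams_per_seq, self.ngram_size + 1)
all_update_indices = jax.lax.fori_loop(
    0, batch_size * generated_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
)

# ignore the n-grams not yet generated
data = (jnp.arange(batch_size * ngrams_per_seq) < batch_size * generated_ngrams).astype("float32")
```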
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])

# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])
I realise these are copied from other parts of the library, but the structure here is really quite confusing. I'm assuming that:

- By batch we mean sample in a minibatch, i.e. "1st batch" is `[1, 1, 2, 1]`
- By tokens we mean token ids
- The output values, e.g. `[False, True, True]`, are across the vocab
- When the token ids are referred to, e.g. "2nd and 3rd token at 1st batch", what we're referring to are the token ids in the vocabulary `[0, 1, 2]` and NOT the 2nd and 3rd token ids in `[1, 1, 2, 1]`

The comment makes this really confusing by 1) having the same positional values for the sample in the batch as in the vocab and 2) saying "at 1st batch". I'd strongly recommend rewriting this to remove this ambiguity.
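To make the mapping concrete, here is a small pure-Python sketch of the intended semantics for the first sample, using only the values quoted above:

```python
# Sample [1, 1, 2, 1] over vocab {0, 1, 2}: the observed bigrams are (1, 1),
# (1, 2), (2, 1); the last generated token is 1, so the banned *vocabulary ids*
# are {1, 2}, which is exactly the [False, True, True] row in the assertion.
sample = [1, 1, 2, 1]
ngram_size, vocab_size = 2, 3
ngrams = {tuple(sample[i:i + ngram_size]) for i in range(len(sample) - ngram_size + 1)}
prefix = tuple(sample[-(ngram_size - 1):])
banned = {ng[-1] for ng in ngrams if ng[:-1] == prefix}
print([tok in banned for tok in range(vocab_size)])  # [False, True, True]
```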
The tests were copied with minimal modifications from the PyTorch ones, so I didn't change these comments. I agree that the descriptions could be clearer.
I think more comprehensive tests could also be a good idea. For example, I didn't see at first that my first code iteration didn't work when an n-gram appears more than once, which is not a case that is tested, so my code passed the tests.
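For what it's worth, a sketch of such a follow-up test might look like the following (the import path and the `(input_ids, scores, cur_len)` call signature are assumed to match the other Flax logits-processor tests):

```python
import jax.numpy as jnp
from transformers.generation import FlaxNoRepeatNGramLogitsProcessor

def test_repeated_ngram_still_banned():
    # The bigram (1, 2) appears twice; the token that follows it (2) must still
    # be banned after the final 1, exactly as if the bigram had appeared once.
    vocab_size = 4
    input_ids = jnp.array([[1, 2, 3, 1, 2, 3, 1]])
    scores = jnp.zeros((1, vocab_size))
    processor = FlaxNoRepeatNGramLogitsProcessor(ngram_size=2)
    filtered = processor(input_ids, scores, cur_len=input_ids.shape[1])
    assert bool(jnp.isinf(filtered[0, 2]))
```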
Good point. As these are so close to the PT & TF tests, let's leave them as-is for now. If you have a test in mind and are willing to open a follow-up PR to add it, I'd be very happy to review :)
Thanks for adding this!
Thinking back on it - let's not block waiting for the test reworks, this can be done in follow-ups.
Just needs a `make fixup` run to resolve the quality checks and we should be good to merge!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Merging as there is a green light and green CI! 🥳
* fix issue with logit processor in beam search in Flax
* adding FlaxNoRepeatNGramLogitsProcessor class + unit test
* style correction and code verification
* add FlaxNoRepeatNGramLogitsProcessor to the test_processor_list and test_processor_list_jitted tests
* fix an issue where ngrams are banned only if they appear ==1 time + update description of get_previous_ngrams
* replace non-jit compatible masking of ngrams that are not yet generated with jittable version
* Revert "fix issue with logit processor in beam search in Flax" (this reverts commit 09b70d7)
* add FlaxNoRepeatNGramLogitsProcessor to _get_logits_processor
* change the method of casting to boolean of banned tokens indices
* fix code style
* remove some useless operations + significantly faster computation of update indices using jax.lax.fori_loop
* remove useless loop iterations
* set some variables that were calculated and used multiple times
* fix format
What does this PR do?
Adding the no repeat n-gram logits processor to Flax, compatible with jitting.
I also added the test `test_no_repeat_ngram_dist_processor`, adapted from the torch one, and added the `FlaxNoRepeatNGramLogitsProcessor` to `test_processor_list` and `test_processor_list_jitted`. All the tests are passing, as `RUN_SLOW=1 pytest -sv tests/generation/test_flax_logits_process.py` shows.

Note: in order to work properly within beam search, this processor needs the fix proposed in PR #29636 for the bug discussed in #29635.
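Once merged, the intended end-user path would presumably be the usual `no_repeat_ngram_size` generation argument; a rough sketch (the checkpoint is just an example):

```python
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("summarize: JAX requires static shapes under jit.", return_tensors="np")
# no_repeat_ngram_size should now route through FlaxNoRepeatNGramLogitsProcessor
outputs = model.generate(**inputs, max_length=32, num_beams=4, no_repeat_ngram_size=2)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
```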
Let me know if you have any comments or questions regarding this feature.
Who can review?
@gante
@sanchit-gandhi