
Adding FlaxNoRepeatNGramLogitsProcessor #29677

Merged
merged 14 commits into huggingface:main on Apr 2, 2024

Conversation

@giganttheo (Contributor) commented Mar 15, 2024

What does this PR do?

Adding the no repeat n-gram logits processor to Flax, compatible with jitting.
I also added the test test_no_repeat_ngram_dist_processor, adapted from the PyTorch one, and added FlaxNoRepeatNGramLogitsProcessor to the test_processor_list and test_processor_list_jitted tests.
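For context, a minimal usage sketch of how the processor would be triggered through generate (an illustration of my own with an assumed checkpoint and generation kwargs, not code from this PR):

# Hypothetical usage sketch (not part of this PR's diff); assumes no_repeat_ngram_size
# is forwarded to the Flax logits processor list, as this PR wires it up.
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("summarize: The quick brown fox jumps over the lazy dog.", return_tensors="np")
# no_repeat_ngram_size=2 should activate FlaxNoRepeatNGramLogitsProcessor once merged
outputs = model.generate(inputs["input_ids"], max_length=32, no_repeat_ngram_size=2)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))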

All the tests pass; RUN_SLOW=1 pytest -sv tests/generation/test_flax_logits_process.py prints:

======================================= test session starts ========================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.4.0 -- /usr/bin/python3
cachedir: .pytest_cache
rootdir: /content/transformers
configfile: pyproject.toml
plugins: anyio-3.7.1
collected 9 items                                                                                  

tests/generation/test_flax_logits_process.py::LogitsProcessorTest::test_forced_bos_token_logits_processor PASSED
tests/generation/test_flax_logits_process.py::LogitsProcessorTest::test_forced_eos_token_logits_processor PASSED
tests/generation/test_flax_logits_process.py::LogitsProcessorTest::test_min_length_dist_processor PASSED
tests/generation/test_flax_logits_process.py::LogitsProcessorTest::test_no_repeat_ngram_dist_processor PASSED
tests/generation/test_flax_logits_process.py::LogitsProcessorTest::test_processor_list PASSED
tests/generation/test_flax_logits_process.py::LogitsProcessorTest::test_processor_list_jitted PASSED
tests/generation/test_flax_logits_process.py::LogitsProcessorTest::test_temperature_dist_warper PASSED
tests/generation/test_flax_logits_process.py::LogitsProcessorTest::test_top_k_dist_warper PASSED
tests/generation/test_flax_logits_process.py::LogitsProcessorTest::test_top_p_dist_warper PASSED

========================================= warnings summary =========================================
../../usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1373
  /usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1373: PytestConfigWarning: Unknown config option: doctest_glob
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================== 9 passed, 1 warning in 14.73s ===================================

Note: in order to work properly within beam search, this processor needs the fix proposed in PR #29636 for the bug discussed in #29635

Let me know if you have any comments or questions regarding this feature.

Who can review?

@gante
@sanchit-gandhi

@gante (Member) left a comment

Thank you for opening the PR 🔥 In general looks good to me, I've added a few questions before approving.

And thank you for keeping the exact same tests as in our PT counterpart, it makes maintenance much simpler 🙌

src/transformers/generation/flax_utils.py (outdated conversation, resolved)
)

data = jnp.ones((all_update_indices.shape[0],), dtype=jnp.uint16)
data = data * (jnp.arange(data.shape[0]) < batch_size * (cur_len - (self.ngram_size - 1))) # ignore the n-grams not yet generated
Member

Perhaps we could slice input_ids before creating all_update_indices, i.e. input_ids = input_ids[:, :cur_len], and save some time/memory when creating all_update_indices.

Or is that slower, because cur_len changes each iteration?

Contributor Author

From my experience with JAX, for the code to work with jit you cannot use arrays whose shapes are not fixed, which is why I opted to pad the indices to a known size. There may be more efficient ways to do it, but I found that this works, as opposed to [:] slices or even dynamic slices, because cur_len changes each iteration.
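To illustrate the constraint, here is a standalone toy example of my own (not code from this PR): dynamic-length slicing fails under jit, while keeping a fixed shape and masking works.

# Toy illustration (not from this PR): dynamic-length slicing vs. fixed-shape masking under jit
import jax
import jax.numpy as jnp

@jax.jit
def sliced(input_ids, cur_len):
    # fails under jit: cur_len is traced, so the output shape would not be static
    return input_ids[:, :cur_len]

@jax.jit
def masked(input_ids, cur_len):
    # works under jit: the shape stays fixed and a mask zeroes out positions not yet generated
    mask = jnp.arange(input_ids.shape[1]) < cur_len
    return input_ids * mask

input_ids = jnp.array([[1, 1, 2, 1, 9, 9]])
print(masked(input_ids, 4))   # [[1 1 2 1 0 0]]
# sliced(input_ids, 4)        # raises IndexError: slice indices must be static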

Member

I see :) I made the suggestion based on my past XLA+TF experience -- slicing input_ids = input_ids[:, :cur_len] is allowed there (example)

But then again, JAX is usually stricter. Let's keep it as you suggested 🤗

Contributor Author

I think it also works in JAX in some situations, but in this case the error when using jit is fairly explicit:

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Anyway, I took some time to benchmark different ways of doing this kind of operation and found that, in this instance, using jax.lax.fori_loop to build all_update_indices is significantly faster than using dynamic slice updates (with this plus the removal of some useless operations, I measured a >10x speedup for the jitted function). There might still be room for improvement.
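As a rough illustration of that pattern (a simplified sketch of my own, not the PR's exact implementation; the helper name and index mapping are assumptions):

# Simplified sketch (not the PR's exact code): fill a statically shaped n-gram index
# table with jax.lax.fori_loop so the whole computation stays jit-compatible.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("ngram_size",))
def collect_ngram_indices(input_ids, cur_len, ngram_size):
    batch_size, seq_len = input_ids.shape
    num_rows = batch_size * (seq_len - (ngram_size - 1))  # static upper bound on n-gram count

    def body_fun(i, table):
        b = i % batch_size
        pos = i // batch_size
        ngram = jax.lax.dynamic_slice(input_ids[b], (pos,), (ngram_size,))
        row = jnp.concatenate([b[None].astype(input_ids.dtype), ngram])
        return table.at[i].set(row)

    # only the first batch_size * (cur_len - (ngram_size - 1)) rows are real;
    # the remaining rows stay zero and are masked out afterwards
    return jax.lax.fori_loop(
        0,
        batch_size * (cur_len - (ngram_size - 1)),
        body_fun,
        jnp.zeros((num_rows, ngram_size + 1), dtype=input_ids.dtype),
    )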

src/transformers/generation/flax_logits_process.py (outdated conversation, resolved)
@gante (Member) commented Mar 18, 2024

The red CI can be fixed by running make fixup in your terminal (in the transformers folder), then committing the changes.

@gante (Member) left a comment

Thank you for iterating 💛

@gante requested a review from amyeroberts on March 20, 2024 at 20:27
@amyeroberts (Collaborator) left a comment

Thanks for adding!

I just have one request to update a test. I realise this is inherited, but it really should be addressed, as the test is very confusing.

Comment on lines +487 to +494
return val.at[i].set(
jnp.array(
[
b,
]
+ [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
)
)
Collaborator

nit - can all be one line

Suggested change
return val.at[i].set(
jnp.array(
[
b,
]
+ [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
)
)
return val.at[i].set(jnp.array([b] + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]))

Contributor Author

I think this was formatted like this by the black / ruff formatter, but if the one-liner passes the checks, I agree that it is clearer this way.

Comment on lines 496 to 506
shape = (batch_size * (seq_len - (self.ngram_size - 1)), self.ngram_size + 1)
all_update_indices = jax.lax.fori_loop(
0, batch_size * (cur_len - (self.ngram_size - 1)), body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
)

# ignore the n-grams not yet generated
data = (
jnp.arange(batch_size * (seq_len - (self.ngram_size - 1))) < batch_size * (cur_len - (self.ngram_size - 1))
).astype("float32")

return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)
Collaborator

Rather than calculating (seq_len - (self.ngram_size - 1)) and (cur_len - (self.ngram_size - 1)) several times, it'll be easier to read and follow this code if they're assigned to variables and then used.
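Applied to the hunk above, the suggested refactor might read roughly as follows (a sketch only; the variable names are my own, not necessarily the PR's final choice):

# sketch of the suggested refactor of the hunk above (variable names are assumptions)
seq_ngrams = batch_size * (seq_len - (self.ngram_size - 1))
cur_ngrams = batch_size * (cur_len - (self.ngram_size - 1))

shape = (seq_ngrams, self.ngram_size + 1)
all_update_indices = jax.lax.fori_loop(
    0, cur_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
)

# ignore the n-grams not yet generated
data = (jnp.arange(seq_ngrams) < cur_ngrams).astype("float32")

return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)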

Comment on lines +215 to +219
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])

# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])
Collaborator

I realise these are copied from other parts of the library, but the structure here is really quite confusing.

I'm assuming that:

  • By batch we mean sample in a minibatch i.e. "1st batch" is [1, 1, 2, 1]
  • By tokens we mean token ids
  • The output values e.g. [False, True, True] are across the vocab
  • When the token ids are referred to, e.g. "2nd and 3rd token at 1st batch", what we're referring to are the token ids in the vocabulary [0, 1, 2] and NOT the 2nd and 3rd tokens in [1, 1, 2, 1]. The comment makes this really confusing by 1) using the same positional values for the sample in the batch as for the vocab and 2) saying "at 1st batch". I'd strongly recommend rewriting this to remove the ambiguity.

Contributor Author

The tests were copied with minimal modifications from the PyTorch ones, so I didn't change these comments. I agree that the descriptions could be clearer.

I think more comprehensive tests could also be a good idea. For example, I didn't see at first that my first code iteration didn't work when an n-gram appears more than once, which is not a case that is tested, so my code still passed the tests.

Collaborator

Good point. As these are so close to the PT & TF tests, lets leave as-is for now. If you have a test in mind and are willing to open a follow-up PR to add, I'd be very happy to review :)
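(For illustration, such a follow-up test might look roughly like the sketch below; the test name, input values, and expected outputs are my own assumptions, not part of this PR.)

# Hypothetical follow-up test sketch, to live in LogitsProcessorTest:
# an n-gram that occurs more than once in the prefix must still be banned.
def test_no_repeat_ngram_repeated_ngrams(self):
    vocab_size = 3
    batch_size = 1
    cur_len = 5

    # the 2-gram (0, 1) appears twice in the prefix
    input_ids = jnp.array([[0, 1, 0, 1, 0]], dtype="i4")
    scores = self._get_uniform_logits(batch_size, vocab_size)

    no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)
    filtered_scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)

    # the last token is 0 and (0, 1) was already generated, so token 1 must be banned
    self.assertListEqual(jnp.isinf(filtered_scores).tolist(), [[False, True, False]])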

@amyeroberts (Collaborator) left a comment

Thanks for adding this!

Thinking back on it - let's not block waiting for the test reworks; this can be done in follow-ups.

Just needs a make fixup run to resolve the quality checks and we should be good to merge!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator)

Merging as there is a green light and green CI! 🥳

@ArthurZucker merged commit fed27ff into huggingface:main on Apr 2, 2024
21 checks passed
@giganttheo deleted the feature/ngram_blocking_flax branch on April 21, 2024 at 22:32
ArthurZucker pushed a commit that referenced this pull request Apr 22, 2024
* fix issue with logit processor in beam search in Flax

* adding FlaxNoRepeatNGramLogitsProcessor class + unit test

* style correction and code verification

* add FlaxNoRepeatNGramLogitsProcessor to the test_processor_list and test_processor_list_jitted tests

* fix an issue where ngrams are banned only if they appear ==1 time + update description of get_previous_ngrams

* replace non-jit compatible masking of ngrams that are not yet generated with jittable version

* Revert "fix issue with logit processor in beam search in Flax"

This reverts commit 09b70d7.

* add FlaxNoRepeatNGramLogitsProcessor to _get_logits_processor

* change the method of casting to boolean of banned tokens indices

* fix code style

* remove some useless operations + significantly faster computation of update indices using jax.lax.fori_loop

* remove useless loop iterations

* set some variables that were calculated and used multiple times

* fix format
itazap pushed a commit that referenced this pull request May 14, 2024