[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers #9098
Conversation
@@ -535,7 +535,6 @@ def config_and_inputs(self):
     n_docs=self.n_docs,
     retrieval_vector_size=self.retrieval_vector_size,
     max_combined_length=self.max_combined_length,
-    use_cache=False,
use_cache was not tested previously because there was a discrepancy between Bart and T5 -> it should work now.
use_cache is now correctly tested in RAG.
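For context, a minimal sketch of what testing use_cache amounts to (this is not the PR's actual test; it uses plain Bart with the facebook/bart-base checkpoint for brevity): incremental decoding with past_key_values must reproduce the logits of a full forward pass.

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").eval()
tok = BartTokenizer.from_pretrained("facebook/bart-base")

enc = tok("My friends are cool", return_tensors="pt")
decoder_input_ids = tok("but they eat too many carbs", return_tensors="pt").input_ids

with torch.no_grad():
    # full forward pass over all decoder tokens at once
    full_logits = model(**enc, decoder_input_ids=decoder_input_ids, use_cache=False).logits

    # incremental pass: feed all but the last token, then reuse the cache for the last one
    past = model(**enc, decoder_input_ids=decoder_input_ids[:, :-1], use_cache=True).past_key_values
    last_logits = model(
        **enc, decoder_input_ids=decoder_input_ids[:, -1:], past_key_values=past, use_cache=True
    ).logits

# the logits for the final position must match between the two decoding modes
assert torch.allclose(full_logits[:, -1, :], last_logits[:, -1, :], atol=1e-3)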
@@ -758,18 +756,18 @@ def test_rag_sequence_generate_beam(self):
         generator_tokenizer=rag_decoder_tokenizer,
     )

-    rag_token = self.sequence_model
-    rag_token.set_retriever(rag_retriever)
+    rag_sequence = self.sequence_model
This is a sequence test, not a token test => so the expected results change slightly here.
@@ -407,7 +407,7 @@ def forward(
     hidden_states: torch.Tensor,
     encoder_hidden_states: torch.Tensor,
     encoder_attn_mask: Optional[torch.Tensor] = None,
-    past_key_value: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
The past_key_value cache should have exactly one level for each layer, no matter whether the model is decoder-only (a.k.a. GPT2) or BART. This was not correctly refactored in BART (it should have been implemented 1-to-1 as in T5). No breaking changes here though.
- For GPT2, the tuple for each layer contains 2 tensors: the key and value states.
- For BART/T5, the tuple for each layer contains 4 tensors: the key and value states of the uni-directional self-attention, and the saved key and value states for the cross-attention.
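To make the layout concrete, here is a small illustrative helper (not part of the PR) that checks the flat per-layer cache structure described above:

def check_cache_structure(past_key_values, is_encoder_decoder):
    """Sanity-check the flat per-layer cache layout: a tuple of tensors, not a tuple of tuples."""
    # decoder-only (GPT2-like):       (self_attn_key, self_attn_value)
    # encoder-decoder (BART/T5-like): (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
    expected_len = 4 if is_encoder_decoder else 2
    for layer_past in past_key_values:          # exactly one entry per layer
        assert isinstance(layer_past, tuple)
        assert len(layer_past) == expected_len
        for past_state in layer_past:
            # each cached state is a plain tensor of shape (batch_size, num_heads, seq_len, head_dim)
            assert past_state.dim() == 4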
cc @patil-suraj for information
@@ -1284,12 +1285,9 @@ def _force_token_id_to_be_generated(scores, token_id) -> None:

    @staticmethod
    def _reorder_cache(past, beam_idx):
This makes cache re-ordering easier.
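As a sketch of why the flat layout helps (simplified, not the exact source of this PR): with one flat tuple of tensors per layer, re-ordering the cache for beam search reduces to a single index_select per tensor.

def reorder_cache(past, beam_idx):
    # `past` has one entry per decoder layer; each entry is a flat tuple of
    # tensors whose first dimension is the batch/beam dimension, so picking
    # the surviving beams is one index_select per tensor.
    reordered_past = ()
    for layer_past in past:
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
        )
    return reordered_past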
@@ -1057,23 +1061,17 @@ def question_encoder(self):
    def _reorder_cache(past, beam_idx):
        """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""

-       def _reorder_stacked(hidden_states):
-           n_docs = hidden_states.shape[0] // beam_idx.shape[0]
+       def _reorder_stacked(hidden_states, new_order):
RAG's cache re-ordering is refactored to follow Bart.
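A sketch of the refactored idea (the standalone function name and the explicit n_docs argument are illustrative; the actual method lives on the RAG model and infers n_docs from the tensor shapes): RAG's cache has an effective batch dimension of batch_size * n_docs, so each cached tensor is reshaped to expose the docs dimension, re-ordered along the beam dimension, and flattened back.

def rag_reorder_cache(past, beam_idx, n_docs):
    def _reorder_stacked(hidden_states, new_order):
        # (batch_size * n_docs, ...) -> (batch_size, n_docs, ...)
        hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
        # pick the beams selected by beam search
        hidden_states = hidden_states.index_select(0, new_order)
        # flatten back to (batch_size * n_docs, ...)
        return hidden_states.view(-1, *hidden_states.shape[2:])

    reordered_past = ()
    for layer_past in past:
        # same flat, Bart-style per-layer tuple of tensors as above
        reordered_past += (
            tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),
        )
    return reordered_past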
What does this PR do?
In Transformers, the cache should always have the same structure. This becomes especially important for composite models like RAG and EncoderDecoder that expect all sub-models to share the same cache format. Bart and T5 had different caches, with Bart deviating the most from the library's standard cache.

This PR aligns the past_key_values cache of Bart/RAG with all other models in the library. In general, the philosophy should be:

- The past_key_value cache should have exactly one level for each layer, no matter whether the model is decoder-only (a.k.a. GPT2) or BART. This was not correctly refactored in BART (it should have been implemented 1-to-1 as in T5). No breaking changes here though.
- The past_key_value tuple for each layer should always be a tuple of tensors, not a tuple of tuples.

This doesn't break any backward compatibility and should fix some RAG problems (@ratthachat). All RAG and Bart slow tests are passing, and the changes concern only the tuple structure.
This PR is blocking the TFBart refactor, so I will merge it already.
cc @LysandreJik, @sgugger, @patil-suraj for info.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.