
BartModel's past_key_values seems to have different explanations in input_doc and output_doc #9380

Closed
forest1988 opened this issue Jan 2, 2021 · 3 comments · Fixed by #9381

Comments

@forest1988
Contributor

Environment info

  • transformers version: 4.1.1
  • Platform: Linux-4.15.0-123-generic-x86_64-with-glibc2.10
  • Python version: 3.8.3
  • PyTorch version (GPU?): 1.7.0 (True)
  • Tensorflow version (GPU?): 2.3.1 (True)
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

Bart: @patrickvonplaten

Information

Model I am using (Bert, XLNet ...): Bart

The problem arises in the documentation of BartModel and BartForConditionalGeneration.

To reproduce

Thank you for kindly answering my question #9298.
I'm now trying to use Bart in transformers v4.1.1.

I'd like to make use of past_key_values, which seems to have been a major change in the refactoring #8900,
but I am a bit confused about its type and shape.

For the input of the forward function, it is documented as:

past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers with each tuple having 2 tuples each of which has 2 tensors of shape (batch_size, num_heads, sequence_length - 1, embed_size_per_head)) –

Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.

For the output, it is documented as:

past_key_values (List[torch.FloatTensor], optional, returned when use_cache=True is passed or when config.use_cache=True) – List of torch.FloatTensor of length config.n_layers, with each tensor of shape (2, batch_size, num_heads, sequence_length, embed_size_per_head)).

Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be used (see past_key_values input) to speed up sequential decoding.

I think it would be natural for the input past_key_values and the output past_key_values to have the same format, so that the output can be used directly as the input at the next step.
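
For example, I would expect usage along these lines to work (just a sketch of how I intend to use it; the token choices here are arbitrary and for illustration only):

    import torch
    from transformers import BartTokenizer, BartModel

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = BartModel.from_pretrained("facebook/bart-large")

    encoder_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    # First step: no cache yet, so ask the model to return one.
    outputs = model(
        input_ids=encoder_inputs["input_ids"],
        decoder_input_ids=decoder_input_ids,
        use_cache=True,
    )
    past = outputs.past_key_values

    # Next step: feed only the newly chosen decoder token together with the cache.
    # (Re-passing input_ids re-runs the encoder; encoder_outputs could be reused instead.)
    next_token = torch.tensor([[tokenizer.eos_token_id]])  # arbitrary token, for illustration
    outputs = model(
        input_ids=encoder_inputs["input_ids"],
        decoder_input_ids=next_token,
        past_key_values=past,
        use_cache=True,
    )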

If my understanding is correct, the input documentation is generated from BART_INPUTS_DOCSTRING, and the output documentation comes from Seq2SeqModelOutput.

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="facebook/bart-large",
        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,

I'm sorry if I'm wrong, but perhaps the Seq2SeqModelOutput documentation hasn't been updated for the refactoring?
(Looking at the git log, I cannot find a related commit.)

I apologize if the difference in input/output format is intentional.

If you don't mind, I'd like to ask one more question.
In the refactoring of Bart, the BartDecoderLayer (renamed from DecoderLayer) seems to have been updated as below:

        # make sure decoder uni-directional self-attn at 1st position and cross-attn at 2nd position.
        present_key_value = (self_attn_present_key_value, cross_attn_present_key_value)

        return (
            hidden_states,
            self_attn_weights,
            present_key_value,
            cross_attn_weights,
        )

And in BartDecoder, the cache is updated as below:

            if use_cache:
                next_decoder_cache += (present_key_value,)
...

        next_cache = next_decoder_cache if use_cache else None

Does this mean that Bart (and other Seq2Seq language models) has both self_attn_present_key_value and cross_attn_present_key_value in past_key_values?
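
If my reading of the code above is right, each layer's entry in the cache should unpack roughly like this (a sketch of my understanding; the shape comments are my assumption based on the input docstring):

    # One entry of past_key_values per decoder layer, e.g. layer 0:
    self_attn_cache, cross_attn_cache = outputs.past_key_values[0]
    self_attn_key, self_attn_value = self_attn_cache      # each: (batch_size, num_heads, tgt_len, head_dim)
    cross_attn_key, cross_attn_value = cross_attn_cache   # each: (batch_size, num_heads, src_len, head_dim)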

Expected behavior

The documentation of Seq2SeqModelOutput may need to be updated.
I apologize if the difference between the input and output explanations is intentional.

@patrickvonplaten
Contributor

Hey @forest1988,

Thanks for your issue! You're 100% correct. The docs need to be updated here! The output is actually never a list; it should always be a Tuple(Tuple(torch.FloatTensor)). I'll make a PR afterward.
And in Bart, past_key_values always consists of self_attn_present_key_value and cross_attn_present_key_value.
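
As a quick sanity check, something along these lines should hold (a rough sketch; the encoder input ids are arbitrary and only for illustration):

    import torch
    from transformers import BartModel

    model = BartModel.from_pretrained("facebook/bart-large")
    encoder_input_ids = torch.tensor([[0, 31414, 232, 2]])  # arbitrary token ids
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    outputs = model(
        input_ids=encoder_input_ids,
        decoder_input_ids=decoder_input_ids,
        use_cache=True,
    )
    past = outputs.past_key_values

    assert isinstance(past, tuple)                    # a tuple of tuples, never a list
    assert len(past) == model.config.decoder_layers   # one entry per decoder layer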

@forest1988
Contributor Author

Hi @patrickvonplaten,

Thank you for your quick response to this issue!
The update of the docs and your answer to my question -- what past_key_values consists of -- are very helpful for me!

@forest1988
Contributor Author

Hi @patrickvonplaten,

Excuse me for my frequent questions.
I created a new issue, #9391, in which I ask for your help with past_key_values in Bart (Seq2SeqLM) and GPT-2 (CausalLM).

I think it is not a bug report but a feature request.

If you could check it out when you have time, it would be greatly appreciated.
