Environment info
transformers version: 4.1.1
Who can help
Bart: @patrickvonplaten
Information
Model I am using (Bert, XLNet ...): Bart
The problem arises in the documentation of BartModel and BartForConditionalGeneration
To reproduce
Thank you for kindly answering my question #9298.
I'm now trying to use Bart in transformers v4.1.1.
I'd like to make use of past_key_values, which seems to have been the major change in the refactoring #8900,
but I am a bit confused about its type and shape.
About the input of the forward function, it is explained as:
past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers with each tuple having 2 tuples each of which has 2 tensors of shape (batch_size, num_heads, sequence_length - 1, embed_size_per_head)) –
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
About the output, it is explained as:
past_key_values (List[torch.FloatTensor], optional, returned when use_cache=True is passed or when config.use_cache=True) – List of torch.FloatTensor of length config.n_layers, with each tensor of shape (2, batch_size, num_heads, sequence_length, embed_size_per_head)).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be used (see past_key_values input) to speed up sequential decoding.
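To see what the model actually returns, I put together a minimal sketch (the facebook/bart-base checkpoint is just an example choice, and the single decoder start token is only for illustration):
import torch
from transformers import BartTokenizer, BartModel

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base")
model.eval()

inputs = tokenizer("Hello world", return_tensors="pt")
# Start decoding from the configured decoder start token.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        decoder_input_ids=decoder_input_ids,
        use_cache=True,
    )

past = outputs.past_key_values
print(type(past), len(past))        # outer container and number of decoder layers
print(type(past[0]), len(past[0]))  # structure of a single layer's cache entry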
I think it would be natural if the input past_key_values and the output past_key_values had the same format, so that the output could be used as the input at the next step.
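If the output can indeed be reused as the input (which is what I would expect), the second decoding step would look roughly like the sketch below, continuing from the snippet above; the choice of the next token is only a placeholder:
# Reuse the cache returned by the previous forward pass for the next step.
# Note: for simplicity the encoder is re-run here; passing encoder_outputs would avoid that.
next_token = torch.tensor([[tokenizer.eos_token_id]])  # placeholder for the newly generated token

with torch.no_grad():
    step_outputs = model(
        input_ids=inputs["input_ids"],
        decoder_input_ids=next_token,   # only the new token, not the full prefix
        past_key_values=past,           # cache from the previous step
        use_cache=True,
    )

past = step_outputs.past_key_values     # updated cache for the following step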
If my understanding is correct, the documentation of the input is generated from BART_INPUTS_DOCSTRING, and that of the output comes from Seq2SeqModelOutput.
I'm sorry if I'm wrong, but maybe the Seq2SeqModelOutput documentation hasn't been updated for the refactoring?
(When I look at the git log, I cannot find a related commit.)
I apologize if the difference in input/output format is due to some intention.
If you don't mind, I'd like to ask one more question.
In the refactoring of Bart, the BartDecoderLayer (renamed from DecoderLayer) seems to have been updated as below:
# make sure decoder uni-directional self-attn at 1st position and cross-attn at 2nd position.
present_key_value = (self_attn_present_key_value, cross_attn_present_key_value)
return (
hidden_states,
self_attn_weights,
present_key_value,
cross_attn_weights,
)
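If I read this correctly, a single layer's entry in the cache returned earlier could be unpacked as in the sketch below (just my assumption based on the return value above, please correct me if I'm wrong):
# Assuming `past` from the snippet above and the nested structure shown in BartDecoderLayer.
self_attn_kv, cross_attn_kv = past[0]         # cache entry of the first decoder layer
print(len(self_attn_kv), len(cross_attn_kv))  # expecting (key_states, value_states) in each
print(self_attn_kv[0].shape)                  # keys over the decoder tokens seen so far
print(cross_attn_kv[0].shape)                 # keys over the encoder sequence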
And in the BartDecoder, the cache is updated by collecting each layer's present_key_value.
Does it mean that Bart (and other Seq2Seq Language Models) have both self_attn_present_key_value and cross_attn_present_key_value in past_key_values?
Expected behavior
Maybe the documentation of Seq2SeqModelOutput needs to be updated.
I apologize if the difference in the input/output explanations is due to some intention.
Thanks for your issue! You're 100% correct. The docs need to be updated here! The output is actually never a list, it should always be a Tuple(Tuple(torch.FloatTensor)) - I'll make a PR afterward.
And in Bart, past_key_values always consists of self_attn_present_key_value and cross_attn_present_key_value.
Thank you for your quick response to this issue!
The update of the docs and your answer to my question about what past_key_values consists of are very helpful to me!
Excuse me for my frequent questions.
I created a new issue #9391, in which I ask for your help with the past_key_values in Bart (Seq2SeqLM) and GPT-2 (CausalLM).
I think it is not an error, but a feature request.
If you could check it out when you have time, it would be greatly appreciated.