@@ -348,7 +348,7 @@ def forward(
         position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[tuple[torch.Tensor]] = None,
+        past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> torch.FloatTensor:
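The type change above means callers now pass a `Cache` object instead of a tuple of tensors. A minimal sketch of the new calling convention, assuming an already-loaded model (`model` and `input_ids` are illustrative placeholders, not taken from this diff):

```python
# Minimal sketch; `model` and `input_ids` are illustrative placeholders.
from transformers.cache_utils import DynamicCache

cache = DynamicCache()  # a Cache instance replaces the legacy tuple[torch.Tensor]
outputs = model(input_ids=input_ids, past_key_values=cache, use_cache=True)
# The cache is filled in place during the forward pass and handed back on the output.
print(outputs.past_key_values.get_seq_length())  # number of tokens cached so far
```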
@@ -366,7 +366,7 @@ def forward(
             use_cache (`bool`, *optional*):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
-            past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            past_key_values (`Cache`, *optional*): cached past key and value projection states
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence.
             position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
@@ -1011,8 +1011,7 @@ def _deepstack_process(
 class Qwen3VLMoeModelOutputWithPast(ModelOutput):
     r"""
     past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
         Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
         `past_key_values` input) to speed up sequential decoding.
@@ -1021,7 +1020,7 @@ class Qwen3VLMoeModelOutputWithPast(ModelOutput):
     """
 
     last_hidden_state: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[list[torch.FloatTensor]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[tuple[torch.FloatTensor]] = None
     attentions: Optional[tuple[torch.FloatTensor]] = None
     rope_deltas: Optional[torch.LongTensor] = None
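For downstream code that still holds caches in the old tuple-of-tuples format, `DynamicCache` ships converters in both directions; a hedged sketch (`legacy` is a placeholder for the old per-layer `(key, value)` tensor tuples, not something defined in this diff):

```python
# Hedged sketch: bridging the legacy tuple format and the Cache API.
from transformers.cache_utils import DynamicCache

cache = DynamicCache.from_legacy_cache(legacy)   # old tuple-of-tuples -> Cache
legacy_again = cache.to_legacy_cache()           # Cache -> old tuple-of-tuples
```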
@@ -1398,8 +1397,7 @@ class Qwen3VLMoeCausalLMOutputWithPast(ModelOutput):
     logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
         Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
     past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
 
         Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
         `past_key_values` input) to speed up sequential decoding.
@@ -1409,7 +1407,7 @@ class Qwen3VLMoeCausalLMOutputWithPast(ModelOutput):
 
     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[list[torch.FloatTensor]] = None
+    past_key_values: Optional[Cache] = None
     hidden_states: Optional[tuple[torch.FloatTensor]] = None
     attentions: Optional[tuple[torch.FloatTensor]] = None
     rope_deltas: Optional[torch.LongTensor] = None
@@ -1465,7 +1463,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[list[torch.FloatTensor]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         pixel_values: Optional[torch.Tensor] = None,
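With both the forward signatures and the output dataclasses migrated, the invariant callers can rely on is that `past_key_values` is either `None` or a `Cache` instance; a short sketch of that check, again with placeholder names:

```python
# Sketch of the post-migration invariant; `model`/`input_ids` are placeholders.
from transformers.cache_utils import Cache

outputs = model(input_ids=input_ids, use_cache=True)
assert outputs.past_key_values is None or isinstance(outputs.past_key_values, Cache)
```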