diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py index 4a2050a8e..4a4780dff 100644 --- a/llava/model/language_model/llava_llama.py +++ b/llava/model/language_model/llava_llama.py @@ -63,6 +63,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, @@ -94,6 +95,7 @@ def forward( inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, + cache_position=cache_position, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict @@ -134,7 +136,7 @@ def generate( **kwargs ) - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, cache_position=None, **kwargs): images = kwargs.pop("images", None) _inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs