 from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
 from peft.utils.constants import DUMMY_MODEL_CONFIG
 from peft.utils.integrations import init_empty_weights
-from peft.utils.other import set_additional_trainable_modules
+from peft.utils.other import create_attention_mask, set_additional_trainable_modules
 
 from . import __version__
 from .config import PeftConfig
@@ -765,7 +765,7 @@ def get_prompt(
                     dtype=past_key_values[0].dtype,
                     device=past_key_values[0].device,
                 )
-                cache_position = torch.arange(peft_config.num_virtual_tokens)
+                cache_position = torch.arange(peft_config.num_virtual_tokens, device=past_key_values[0].device)
                 for layer_idx in range(peft_config.num_layers):
                     key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx]
                     new_cache.update(
@@ -1959,6 +1959,10 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
             uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
         )
 
+        # heuristic to determine if we're in 'prefill stage' (when the KV cache is filled with the values from the
+        # initial input)
+        is_prefill = (model_kwargs.get("cache_position") is not None) and (model_kwargs["cache_position"][0] == 0)
+
         if peft_config.peft_type == PeftType.POLY:
             model_kwargs["task_ids"] = task_ids
         if peft_config.is_prompt_learning:
@@ -1975,6 +1979,16 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
                     model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]
 
             if (attention_mask := model_kwargs.get("attention_mask", None)) is not None:
+                if isinstance(attention_mask, dict):
+                    # see: https://github.com/huggingface/transformers/pull/37866
+                    # For now, just deal with the case of a single attention mask
+                    if len(attention_mask) != 1:
+                        raise ValueError(
+                            f"Expected a single attention mask, got {len(attention_mask)} instead, please open an "
+                            "issue (https://github.com/huggingface/peft/issues) and report the error."
+                        )
+                    attention_mask = list(attention_mask.values())[0]
+
                 size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
                 prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
                 if attention_mask.dim() == 4:
@@ -1984,7 +1998,26 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
                     # to [batch_size, total_sequence_length]
                     bs = attention_mask.shape[0]
                     total_seq_len = prefix_attention_mask.shape[1] + attention_mask.shape[2]
-                    model_kwargs["attention_mask"] = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
+                    attention_mask_2d = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
+
+                    if is_prefill and (peft_config.peft_type != PeftType.PREFIX_TUNING):
+                        # if in prefill stage, for prompt learning methods that are not prefix tuning, new tokens
+                        # (embeddings) are inserted, thus set cache_position to correspond to these tokens
+                        cache_position_ = torch.arange(total_seq_len, device=model_kwargs["input_ids"].device)
+                    else:
+                        # prefix tuning acts directly on the cache, no need to update cache_position
+                        cache_position_ = model_kwargs["cache_position"]
+
+                    attention_mask_new = create_attention_mask(
+                        self.get_base_model(),
+                        model_input=None,
+                        attention_mask=attention_mask_2d,
+                        past_key_values=model_kwargs.get("past_key_values"),
+                        cache_position=cache_position_,
+                        batch_size=bs,
+                        sequence_length=total_seq_len,
+                    )
+                    model_kwargs["attention_mask"] = attention_mask_new
                 else:
                     # 2d attention mask
                     model_kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
@@ -2020,11 +2053,16 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
                 model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
                 model_kwargs["input_ids"] = None
 
-            # For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is
-            # passed in the forward pass to keep track of the position ids of the cache. We have to
-            # pop that from `model_kwargs` as `cache_position` is properly created by the model, using the passed
-            # `inputs_embeds`: https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
-            _ = model_kwargs.pop("cache_position", None)
+            # if we're in the prefill stage
+            if is_prefill and (peft_config.peft_type == PeftType.PREFIX_TUNING):
+                # for prefix tuning, the past_key_values have been prefilled
+                model_kwargs["cache_position"] += peft_config.num_virtual_tokens
+            elif peft_config.peft_type != PeftType.PREFIX_TUNING:  # prefix tuning needs cache_position
+                # For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is passed in the forward
+                # pass to keep track of the position ids of the cache. We have to pop that from `model_kwargs` as
+                # `cache_position` is properly created by the model, using the passed `inputs_embeds`:
+                # https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
+                _ = model_kwargs.pop("cache_position", None)
 
         return model_kwargs
 
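The prefill bookkeeping in the hunks above is easier to follow with concrete numbers. Below is a minimal standalone sketch in plain PyTorch (no PEFT internals; the sizes are invented for illustration) of the two cases the patch distinguishes during the prefill stage: prefix tuning, where the KV cache already contains the virtual tokens and `cache_position` has to be shifted by `num_virtual_tokens`, and the other prompt learning methods, where virtual token embeddings are prepended and the attention mask and cache positions have to cover prefix plus input.

```python
import torch

# Invented sizes, purely for illustration.
num_virtual_tokens = 4                     # virtual tokens added by the prompt learning method
input_ids = torch.tensor([[11, 12, 13]])   # a 3-token prompt, batch size 1
cache_position = torch.arange(input_ids.shape[1])  # what transformers passes in: [0, 1, 2]

# Heuristic used in the patch: we are in the prefill stage iff cache_position starts at 0.
is_prefill = cache_position[0] == 0

# Prefix tuning: the KV cache is already filled with num_virtual_tokens entries, so the real
# tokens must be written after them -> shift cache_position.
prefix_tuning_positions = cache_position + num_virtual_tokens  # tensor([4, 5, 6])

# Other prompt learning methods (prompt tuning, p-tuning, ...): virtual token embeddings are
# prepended, so the effective sequence is longer and mask/positions must span prefix + input.
total_seq_len = num_virtual_tokens + input_ids.shape[1]
attention_mask_2d = torch.ones((input_ids.shape[0], total_seq_len), dtype=torch.long)
prompt_learning_positions = torch.arange(total_seq_len)  # tensor([0, 1, ..., 6])

print(bool(is_prefill))                    # True
print(prefix_tuning_positions.tolist())    # [4, 5, 6]
print(attention_mask_2d.shape)             # torch.Size([1, 7])
print(prompt_learning_positions.tolist())  # [0, 1, 2, 3, 4, 5, 6]
```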
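And a short usage sketch that exercises this code path end to end: generating with a prefix tuning adapter, where `PeftModel.prepare_inputs_for_generation` injects the prefilled cache and, with this change, shifts `cache_position` accordingly. The model id is only a placeholder; any causal LM supported by prefix tuning works.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, TaskType, get_peft_model

model_id = "meta-llama/Llama-3.2-1B"  # placeholder; substitute any causal LM you have access to
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id)

# Untrained adapter, enough to exercise the generation path.
peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)
model = get_peft_model(base_model, peft_config)

inputs = tokenizer("The answer is", return_tensors="pt")
with torch.no_grad():
    # generate() goes through PeftModel.prepare_inputs_for_generation, which prefills the
    # KV cache with the prefix and adjusts cache_position/attention_mask as in the diff above.
    output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```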