Commit 171da8e
FIX Attention mask dict issue, generate w/ gemma (#2579)
Resolves CI errors such as this one: https://github.com/huggingface/peft/actions/runs/15481482956/job/43588020111#step:5:53182. After resolving that error, other errors can occur, but they're unrelated and are investigated independently.

After the transformers change in huggingface/transformers#37866, it can happen that:

> Models using different types of attention in different layers (i.e. gemma3) will now have a dict returned by prepare_inputs_for_generation (one dict entry per attention type)

As PEFT operates on the attention mask for prompt learning methods, we need to adjust the code for the possibility of attention_mask being a dict. Right now, I simply extract the single value if the dict has just one element. For other sizes, I raise an error, as I don't know how to deal with that case. For our tests, this is enough, but we might need to find a better solution in the future.
1 parent bbc9f5d commit 171da8e
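For context, here is a minimal sketch of the dict handling described above (the dict key and variable names are illustrative assumptions; the actual change lives in prepare_inputs_for_generation, see the diff below):

```python
import torch

# Assumed shape of what newer transformers may hand to prepare_inputs_for_generation
# for models that mix attention types (the key name is illustrative):
attention_mask = {"full_attention": torch.ones(1, 8, dtype=torch.long)}

if isinstance(attention_mask, dict):
    if len(attention_mask) != 1:
        # More than one mask type: PEFT doesn't know which one to patch, so it errors out.
        raise ValueError(f"Expected a single attention mask, got {len(attention_mask)} instead")
    # Collapse the one-entry dict back to a plain tensor so the existing prompt learning logic applies.
    attention_mask = list(attention_mask.values())[0]

print(attention_mask.shape)  # torch.Size([1, 8])
```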

File tree: 2 files changed (+95, -8 lines changed)

src/peft/peft_model.py
src/peft/utils/other.py

src/peft/peft_model.py

Lines changed: 46 additions & 8 deletions
@@ -41,7 +41,7 @@
 from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
 from peft.utils.constants import DUMMY_MODEL_CONFIG
 from peft.utils.integrations import init_empty_weights
-from peft.utils.other import set_additional_trainable_modules
+from peft.utils.other import create_attention_mask, set_additional_trainable_modules
 
 from . import __version__
 from .config import PeftConfig
@@ -765,7 +765,7 @@ def get_prompt(
                 dtype=past_key_values[0].dtype,
                 device=past_key_values[0].device,
             )
-            cache_position = torch.arange(peft_config.num_virtual_tokens)
+            cache_position = torch.arange(peft_config.num_virtual_tokens, device=past_key_values[0].device)
             for layer_idx in range(peft_config.num_layers):
                 key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx]
                 new_cache.update(
@@ -1959,6 +1959,10 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
             uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
         )
 
+        # heuristic to determine if we're in 'prefill stage' (when the KV cache is filled with the values from the
+        # initial input)
+        is_prefill = (model_kwargs.get("cache_position") is not None) and (model_kwargs["cache_position"][0] == 0)
+
         if peft_config.peft_type == PeftType.POLY:
             model_kwargs["task_ids"] = task_ids
         if peft_config.is_prompt_learning:
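The `is_prefill` heuristic relies on `cache_position` starting at 0 only on the first forward pass of generation. A small illustration with made-up values:

```python
import torch

# prefill: cache_position covers the whole prompt and starts at 0
cache_position = torch.arange(6)     # tensor([0, 1, 2, 3, 4, 5])
print(bool(cache_position[0] == 0))  # True -> treated as prefill stage

# decoding: only the newly generated token's position is passed
cache_position = torch.tensor([6])
print(bool(cache_position[0] == 0))  # False -> not prefill
```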
@@ -1975,6 +1979,16 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
                     model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]
 
             if (attention_mask := model_kwargs.get("attention_mask", None)) is not None:
+                if isinstance(attention_mask, dict):
+                    # see: https://github.com/huggingface/transformers/pull/37866
+                    # For now, just deal with the case of a single attention mask
+                    if len(attention_mask) != 1:
+                        raise ValueError(
+                            f"Expected a single attention mask, got {len(attention_mask)} instead, please open an "
+                            "issue (https://github.com/huggingface/peft/issues) and report the error."
+                        )
+                    attention_mask = list(attention_mask.values())[0]
+
                 size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
                 prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
                 if attention_mask.dim() == 4:
@@ -1984,7 +1998,26 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
                     # to [batch_size, total_sequence_length]
                     bs = attention_mask.shape[0]
                     total_seq_len = prefix_attention_mask.shape[1] + attention_mask.shape[2]
-                    model_kwargs["attention_mask"] = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
+                    attention_mask_2d = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
+
+                    if is_prefill and (peft_config.peft_type != PeftType.PREFIX_TUNING):
+                        # if in prefill stage, for prompt learning methods that are not prefix tuning, new tokens
+                        # (embeddings) are inserted, thus set cache_position to correspond to these tokens
+                        cache_position_ = torch.arange(total_seq_len, device=model_kwargs["input_ids"].device)
+                    else:
+                        # prefix tuning acts directly on the cache, no need to update cache_position
+                        cache_position_ = model_kwargs["cache_position"]
+
+                    attention_mask_new = create_attention_mask(
+                        self.get_base_model(),
+                        model_input=None,
+                        attention_mask=attention_mask_2d,
+                        past_key_values=model_kwargs.get("past_key_values"),
+                        cache_position=cache_position_,
+                        batch_size=bs,
+                        sequence_length=total_seq_len,
+                    )
+                    model_kwargs["attention_mask"] = attention_mask_new
                 else:
                     # 2d attention mask
                     model_kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
@@ -2020,11 +2053,16 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
                 model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
                 model_kwargs["input_ids"] = None
 
-            # For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is
-            # passed in the forward pass to keep track of the position ids of the cache. We have to
-            # pop that from `model_kwargs` as `cache_position` is properly created by the model, using the passed
-            # `inputs_embeds`: https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
-            _ = model_kwargs.pop("cache_position", None)
+            # if we're in the prefill stage
+            if is_prefill and (peft_config.peft_type == PeftType.PREFIX_TUNING):
+                # for prefix tuning, the past_key_values have been prefilled
+                model_kwargs["cache_position"] += peft_config.num_virtual_tokens
+            elif peft_config.peft_type != PeftType.PREFIX_TUNING:  # prefix tuning needs cache_position
+                # For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is passed in the forward
+                # pass to keep track of the position ids of the cache. We have to pop that from `model_kwargs` as
+                # `cache_position` is properly created by the model, using the passed `inputs_embeds`:
+                # https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
+                _ = model_kwargs.pop("cache_position", None)
 
         return model_kwargs
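Taken together, these changes let generation with prompt learning adapters run on models whose attention mask arrives as a dict. A hedged end-to-end sketch of that scenario (the model id, adapter settings, and prompt are assumptions, not taken from the commit):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PromptTuningConfig, TaskType, get_peft_model

model_id = "google/gemma-3-1b-it"  # assumption: any gemma3-style model mixing attention types
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

peft_config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10)
model = get_peft_model(base_model, peft_config)

inputs = tokenizer("The weather today is", return_tensors="pt")
# Before this fix, prepare_inputs_for_generation could receive attention_mask as a dict
# (one entry per attention type) and fail; with the fix, generation should run through.
output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```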

src/peft/utils/other.py

Lines changed: 49 additions & 0 deletions
@@ -1294,3 +1294,52 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
             token_indices=token_adapter.token_indices[adapter_name],
             tied_adapter=model.get_input_embeddings().token_adapter,
         )
+
+
+def create_attention_mask(
+    model, *, model_input, attention_mask, past_key_values, cache_position, batch_size, sequence_length
+):
+    # adapted from:
+    # https://github.com/huggingface/transformers/blob/cb4c56ce0dfa1350267ed28e57760986a58a9ba4/src/transformers/generation/utils.py#L644-L680
+    # In PEFT, we sometimes need to re-create the attention mask. This is because some prompt learning methods insert
+    # new items into the sequence, which results in the attention mask needing an update. We re-use transformers code
+    # for this as much as possible.
+    try:
+        from transformers.masking_utils import create_masks_for_generate
+    except ImportError as exc:
+        raise ImportError("Your transformers version is too old, please upgrade it to > 4.52") from exc
+
+    # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
+    # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
+    base_model = getattr(model, model.base_model_prefix, model)
+    decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None
+    causal_mask_creation_function = getattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None)
+    if causal_mask_creation_function is None and decoder is not None:  # it may be in the decoder
+        causal_mask_creation_function = getattr(decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None)
+
+    # If it's not defined, it means the model uses the new general mask API
+    if causal_mask_creation_function is None:  # can't be found
+        token_type_ids = getattr(model_input, "token_type_ids", None)
+        # Some models may overwrite the general one
+        causal_mask_creation_function = getattr(model, "create_masks_for_generate", create_masks_for_generate)
+        attention_mask = causal_mask_creation_function(
+            config=model.config,
+            # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
+            input_embeds=torch.empty((batch_size, sequence_length), dtype=model.dtype),
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            token_type_ids=token_type_ids,
+        )
+    else:
+        attention_mask = causal_mask_creation_function(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=past_key_values.get_max_cache_shape(),
+            dtype=model.dtype,
+            cache_position=cache_position,
+            batch_size=batch_size,
+            config=model.config,
+            past_key_values=past_key_values,
+        )
+    return attention_mask
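As a quick diagnostic, the dispatch logic above can be mirrored to see which code path a given model would take (a sketch under the assumption that `model` is a loaded decoder-only transformers model; this helper is not part of the commit):

```python
def describe_mask_path(model):
    # Mirrors the lookup order in create_attention_mask above.
    base_model = getattr(model, model.base_model_prefix, model)
    decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None
    fn = getattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None)
    if fn is None and decoder is not None:
        fn = getattr(decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None)
    if fn is None:
        return "new general mask API (create_masks_for_generate)"
    return "legacy _prepare_4d_causal_attention_mask_with_cache_position"
```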
