54 changes: 46 additions & 8 deletions src/peft/peft_model.py
@@ -41,7 +41,7 @@
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils.constants import DUMMY_MODEL_CONFIG
from peft.utils.integrations import init_empty_weights
from peft.utils.other import set_additional_trainable_modules
from peft.utils.other import create_attention_mask, set_additional_trainable_modules

from . import __version__
from .config import PeftConfig
@@ -739,7 +739,7 @@ def get_prompt(
dtype=past_key_values[0].dtype,
device=past_key_values[0].device,
)
cache_position = torch.arange(peft_config.num_virtual_tokens)
cache_position = torch.arange(peft_config.num_virtual_tokens, device=past_key_values[0].device)
for layer_idx in range(peft_config.num_layers):
key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx]
new_cache.update(
@@ -1926,6 +1926,10 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
)

# heuristic to determine if we're in the 'prefill stage' (when the KV cache is filled with the values from the
# initial input)
is_prefill = (model_kwargs.get("cache_position") is not None) and (model_kwargs["cache_position"][0] == 0)

if peft_config.peft_type == PeftType.POLY:
model_kwargs["task_ids"] = task_ids
if peft_config.is_prompt_learning:
@@ -1942,6 +1946,16 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]

if (attention_mask := model_kwargs.get("attention_mask", None)) is not None:
if isinstance(attention_mask, dict):
# see: https://github.com/huggingface/transformers/pull/37866
# For now, just deal with the case of a single attention mask
if len(attention_mask) != 1:
raise ValueError(
f"Expected a single attention mask, got {len(attention_mask)} instead, please open an "
"issue (https://github.com/huggingface/peft/issues) and report the error."
)
attention_mask = list(attention_mask.values())[0]

size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
if attention_mask.dim() == 4:
@@ -1951,7 +1965,26 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
# to [batch_size, total_sequence_length]
bs = attention_mask.shape[0]
total_seq_len = prefix_attention_mask.shape[1] + attention_mask.shape[2]
model_kwargs["attention_mask"] = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)
attention_mask_2d = torch.ones((bs, total_seq_len), dtype=attention_mask.dtype)

if is_prefill and (peft_config.peft_type != PeftType.PREFIX_TUNING):
# if in prefill stage, for prompt learning methods that are not prefix tuning, new tokens
# (embeddings) are inserted, thus set cache_position to correspond to these tokens
cache_position_ = torch.arange(total_seq_len, device=model_kwargs["input_ids"].device)
else:
# prefix tuning acts directly on the cache, no need to update cache_position
cache_position_ = model_kwargs["cache_position"]

attention_mask_new = create_attention_mask(
self.get_base_model(),
model_input=None,
attention_mask=attention_mask_2d,
past_key_values=model_kwargs.get("past_key_values"),
cache_position=cache_position_,
batch_size=bs,
sequence_length=total_seq_len,
)
model_kwargs["attention_mask"] = attention_mask_new
else:
# 2d attention mask
model_kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
@@ -1987,11 +2020,16 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
model_kwargs["input_ids"] = None

# For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is
# passed in the forward pass to keep track of the position ids of the cache. We have to
# pop that from `model_kwargs` as `cache_position` is properly created by the model, using the passed
# `inputs_embeds`: https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
_ = model_kwargs.pop("cache_position", None)
# if we're in the prefill stage and using prefix tuning
if is_prefill and (peft_config.peft_type == PeftType.PREFIX_TUNING):
# for prefix tuning, the past_key_values have been prefilled
model_kwargs["cache_position"] += peft_config.num_virtual_tokens
elif peft_config.peft_type != PeftType.PREFIX_TUNING: # prefix tuning needs cache_position
# For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is passed in the forward
# pass to keep track of the position ids of the cache. We have to pop that from `model_kwargs` as
# `cache_position` is properly created by the model, using the passed `inputs_embeds`:
# https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
_ = model_kwargs.pop("cache_position", None)

return model_kwargs

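For context, here is a minimal, hedged sketch of the generation path this hunk changes, assuming a small causal LM ("gpt2" is only a placeholder) and a prefix-tuning adapter; the prompt and token counts are likewise illustrative and not part of this PR. On the first generate() step, cache_position[0] == 0, so the heuristic above treats the call as the prefill stage and, for prefix tuning, shifts cache_position by num_virtual_tokens, since the KV cache already contains the virtual-token entries.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, TaskType, get_peft_model

# Placeholder base model/tokenizer; any causal LM supported by PEFT should behave similarly.
base = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Prefix tuning prefills the KV cache with num_virtual_tokens entries per layer.
peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)
model = get_peft_model(base, peft_config)

inputs = tokenizer("Prefix tuning prefills the KV cache with", return_tensors="pt")
with torch.no_grad():
    # On the first step cache_position starts at 0 (prefill); prepare_inputs_for_generation
    # then offsets it by num_virtual_tokens so positions line up with the prefilled cache.
    output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))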
49 changes: 49 additions & 0 deletions src/peft/utils/other.py
@@ -1272,3 +1272,52 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
token_indices=token_adapter.token_indices[adapter_name],
tied_adapter=model.get_input_embeddings().token_adapter,
)


def create_attention_mask(
model, *, model_input, attention_mask, past_key_values, cache_position, batch_size, sequence_length
):
# adapted from:
# https://github.com/huggingface/transformers/blob/cb4c56ce0dfa1350267ed28e57760986a58a9ba4/src/transformers/generation/utils.py#L644-L680
# In PEFT, we sometimes need to re-create the attention mask. This is because some prompt learning methods insert
# new items into the sequence, which results in the attention mask needing an update. We re-use transformers code
# for this as much as possible.
try:
from transformers.masking_utils import create_masks_for_generate
except ImportError as exc:
raise ImportError("Your transformers version is too old, please upgrade it to > 4.52") from exc

# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
# the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
base_model = getattr(model, model.base_model_prefix, model)
decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None
causal_mask_creation_function = getattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None)
if causal_mask_creation_function is None and decoder is not None: # it may be in the decoder
causal_mask_creation_function = getattr(decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None)

# If it's not defined, it means the model uses the new general mask API
if causal_mask_creation_function is None: # can't be found
token_type_ids = getattr(model_input, "token_type_ids", None)
# Some models may overwrite the general one
causal_mask_creation_function = getattr(model, "create_masks_for_generate", create_masks_for_generate)
attention_mask = causal_mask_creation_function(
config=model.config,
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
input_embeds=torch.empty((batch_size, sequence_length), dtype=model.dtype),
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
token_type_ids=token_type_ids,
)
else:
attention_mask = causal_mask_creation_function(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=model.dtype,
cache_position=cache_position,
batch_size=batch_size,
config=model.config,
past_key_values=past_key_values,
)
return attention_mask
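
A hedged usage sketch of the new helper, exercised indirectly through generate(); it assumes transformers > 4.52 and uses "gpt2" with a prompt-tuning adapter purely as placeholders. Unlike prefix tuning, prompt tuning inserts num_virtual_tokens extra embeddings in front of the input, so prepare_inputs_for_generation rebuilds the attention mask to cover the longer sequence; create_attention_mask itself is only reached when the base model hands PEFT a 4D mask (the 2D concatenation branch is the fallback otherwise).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PromptTuningConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model name
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Prompt tuning prepends num_virtual_tokens embeddings, so the attention mask (and, during
# prefill, cache_position) must be extended to cover input length + num_virtual_tokens.
peft_config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8)
model = get_peft_model(base, peft_config)

inputs = tokenizer("The attention mask is rebuilt to cover", return_tensors="pt")
with torch.no_grad():
    # Whether the 4D branch (and therefore create_attention_mask) is hit depends on the model,
    # the cache implementation, and the transformers version; the 2D branch covers the rest.
    output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))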