Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CI] Fix adaptation prompt CI on transformers main #1465

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 21 additions & 28 deletions src/peft/tuners/adaption_prompt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,18 @@ def llama_apply_rotary_pos_emb(q, cos, sin, position_ids):

def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
"""
Compute query states for Llama models specifically.

They need to be recomputed as the forward() method of the original LlamaModel in the transformers library does not
return them. See the related discussion in the PR: https://github.com/huggingface/peft/pull/268
Compute query states for Llama models specifically. They need to be recomputed as the forward() method of the
original LlamaModel in the transformers library does not return them. See the related discussion in the PR:
https://github.com/huggingface/peft/pull/268
"""
hidden_states = kwargs.get("hidden_states")
position_ids = kwargs.get("position_ids")
past_key_value = kwargs.get("past_key_value")
bsz, q_len, _ = hidden_states.size()
query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)

seq_len = q_len

if past_key_value is not None:
if isinstance(past_key_value, tuple):
# for transformers <= 4.35
Expand All @@ -80,30 +79,24 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
# since transformers 4.36, this is a DynamicCache instance
seq_len += past_key_value.get_seq_length(model.layer_idx)

# For transformers > 4.37.2 `position_ids` became a required arguments in the
# rotary embedding's forward pass. and cos/sin are indexed through the
# `rotary_emb` forward pass.
if "position_ids" in list(inspect.signature(model.rotary_emb.forward).parameters):
if position_ids is None:
kv_seq_len = value_states.shape[-2]
past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, model.layer_idx)
kv_seq_len += past_seen_tokens

new_cache_positions = torch.arange(
past_seen_tokens, past_seen_tokens + q_len, device=value_states.device
)
position_ids = new_cache_positions.unsqueeze(0)

cos, sin = model.rotary_emb(value_states, seq_len=kv_seq_len, position_ids=position_ids)

# Here cos and sin are are already indexed correctly, therefore to avoid adding
# boilerplate in `llama_apply_rotary_pos_emb` we just return here the correct query states
# embeddings
return (query_states * cos) + (llama_rotate_half(query_states) * sin)

cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
# For transformers > 4.37.2 `position_ids` became a required arguments in the rotary embedding's forward pass.
if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
# TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)

past_seen_tokens = 0
if position_ids is None:
# Compute position_ids, since they are required for transformers > 4.37.2
if past_key_value is None:
new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
else:
past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
position_ids = new_cache_positions.unsqueeze(0)

return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
cos, sin = model.rotary_emb(value_states, seq_len=q_len + past_seen_tokens, position_ids=position_ids)
return (query_states * cos) + (llama_rotate_half(query_states) * sin)


def is_adaption_prompt_trainable(params: str) -> bool:
Expand Down
Loading