[CI] Fix adaptation prompt CI on transformers main #1465

Merged

Changes from 1 commit
14 changes: 13 additions & 1 deletion src/peft/tuners/adaption_prompt/utils.py
@@ -72,6 +72,8 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
    value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)

    seq_len = q_len
    kv_seq_len = value_states.shape[-2]

    if past_key_value is not None:
        if isinstance(past_key_value, tuple):
            # for transformers <= 4.35
@@ -85,7 +87,6 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
        # `rotary_emb` forward pass.
        if "position_ids" in list(inspect.signature(model.rotary_emb.forward).parameters):
            if position_ids is None:
                kv_seq_len = value_states.shape[-2]
                past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, model.layer_idx)
                kv_seq_len += past_seen_tokens

@@ -101,6 +102,17 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
            # embeddings
            return (query_states * cos) + (llama_rotate_half(query_states) * sin)

    # For transformers > 4.37.2, `position_ids` became a required argument in the
    # rotary embedding's forward pass, and cos/sin are indexed through the
    # `rotary_emb` forward pass.
    if "position_ids" in list(inspect.signature(model.rotary_emb.forward).parameters):
        if position_ids is None:
            new_cache_positions = torch.arange(kv_seq_len, kv_seq_len + q_len, device=value_states.device)
            position_ids = new_cache_positions.unsqueeze(0)

        cos, sin = model.rotary_emb(value_states, seq_len=kv_seq_len, position_ids=position_ids)
        return (query_states * cos) + (llama_rotate_half(query_states) * sin)

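For context on the last two added lines: applying the rotary embedding to the query states is an element-wise combination of the queries with cos/sin and a half-rotated copy. Below is a standalone sketch, not code from this PR; `rotate_half` stands in for the module's `llama_rotate_half` (assumed here to be the standard Llama half-rotation), and all shapes are dummy values.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension into two halves and swap them with a sign flip,
    # following the standard Llama rotary-embedding helper.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, num_heads, q_len, head_dim = 1, 2, 4, 8
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
# Dummy cos/sin; a trailing (q_len, head_dim) shape broadcasts over the batch
# and head dimensions of query_states.
cos = torch.randn(q_len, head_dim)
sin = torch.randn(q_len, head_dim)
rotated_queries = (query_states * cos) + (rotate_half(query_states) * sin)
print(rotated_queries.shape)  # torch.Size([1, 2, 4, 8])
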
Member

I tried to simplify this function. What was especially confusing for me were the values of q_len, seq_len, and kv_seq_len.

E.g. we have:

    bsz, q_len, _ = hidden_states.size()
    value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
    seq_len = q_len
    kv_seq_len = value_states.shape[-2]

From this, it follows that q_len, seq_len, and kv_seq_len are all the same value, right?

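A standalone shape check with dummy dimensions (not from the PR) confirms this: after the .transpose(1, 2), the sequence length sits in dimension -2 of value_states, so value_states.shape[-2] equals q_len at this point.

import torch
import torch.nn as nn

# Dummy projection with hidden_size = num_heads * head_dim, just to inspect the shapes.
bsz, q_len, num_heads, head_dim = 2, 5, 4, 8
hidden_states = torch.randn(bsz, q_len, num_heads * head_dim)
v_proj = nn.Linear(num_heads * head_dim, num_heads * head_dim, bias=False)

value_states = v_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print(value_states.shape)               # torch.Size([2, 4, 5, 8])
assert value_states.shape[-2] == q_len  # kv_seq_len == seq_len == q_len here
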
Also, I saw that there was a bit of code duplication when it came to calculating position_ids. My simplified version tries to take this into account:

def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
    hidden_states = kwargs.get("hidden_states")
    position_ids = kwargs.get("position_ids")
    past_key_value = kwargs.get("past_key_value")
    bsz, q_len, _ = hidden_states.size()
    query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
    value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
    seq_len = q_len

    if past_key_value is not None:
        if isinstance(past_key_value, tuple):
            # for transformers <= 4.35
            seq_len += past_key_value[0].shape[-2]
        else:
            # since transformers 4.36, this is a DynamicCache instance
            seq_len += past_key_value.get_seq_length(model.layer_idx)

    # For transformers > 4.37.2, `position_ids` became a required argument in the rotary embedding's forward pass.
    if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
        # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
        cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
        return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)

    # XXX
    past_seen_tokens = 0
    if position_ids is None:
        # Compute position_ids, since they are required for transformers > 4.37.2
        if past_key_value is None:
            new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
        else:
            past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
            new_cache_positions = torch.arange(
                past_seen_tokens, past_seen_tokens + q_len, device=value_states.device
            )
        position_ids = new_cache_positions.unsqueeze(0)

    cos, sin = model.rotary_emb(value_states, seq_len=q_len + past_seen_tokens, position_ids=position_ids)
    return (query_states * cos) + (llama_rotate_half(query_states) * sin)

Note that the code until # XXX is basically the old state of this function, from before the changes to transformers. After that comes the code that takes the new changes into account. I tried to keep the logic identical, but please double-check that this is true.

I tested the code with transformers 4.35.0, 4.37.2, and an install from main (4.38.0.dev0), and the tests passed.
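As a side note, the version dispatch that both the current fix and the simplified version rely on is just an inspect.signature check on the installed transformers' rotary embedding. A minimal standalone sketch; the helper name and the stand-in classes are illustrative and not from the PR or from transformers.

import inspect

def rotary_emb_takes_position_ids(rotary_emb) -> bool:
    # True when the rotary embedding expects `position_ids` in its forward()
    # (transformers > 4.37.2); False for the older seq_len-only signature.
    return "position_ids" in inspect.signature(rotary_emb.forward).parameters

# Stand-in classes mimicking the old and new forward() signatures:
class OldRotaryEmbedding:  # transformers <= 4.37.2 style
    def forward(self, x, seq_len=None):
        pass

class NewRotaryEmbedding:  # transformers > 4.37.2 style
    def forward(self, x, position_ids, seq_len=None):
        pass

print(rotary_emb_takes_position_ids(OldRotaryEmbedding()))  # False
print(rotary_emb_takes_position_ids(NewRotaryEmbedding()))  # True
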

Contributor Author

Thanks very much for the investigation, the proposed solution sounds great!

    cos, sin = model.rotary_emb(value_states, seq_len=seq_len)

    return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)