
llama model: causal_mask does not exist #29173

Closed
2 of 4 tasks
xadupre opened this issue Feb 21, 2024 · 11 comments · Fixed by #29198

@xadupre
Contributor

xadupre commented Feb 21, 2024

System Info

Line: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L372C43-L372C54

causal_mask does not exist when cache_position is None, even though it is used right after the inner if:

  if attention_mask is not None:  # no matter the length, we just slice it
      if cache_position is not None:
          causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
      attn_weights = attn_weights + causal_mask
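
Put differently, the failure is a plain Python scoping issue: the inner branch is the only place causal_mask is bound, so when it is skipped the addition reads an unbound local. A minimal, self-contained sketch (illustrative function, not the actual modeling code):

def add_mask(attn_weights, attention_mask, cache_position):
    if attention_mask is not None:  # outer branch taken
        if cache_position is not None:
            causal_mask = attention_mask  # only binding of the name
        attn_weights = attn_weights + causal_mask  # unbound if cache_position is None
    return attn_weights

add_mask(1.0, 0.0, None)  # raises UnboundLocalError for 'causal_mask'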

Who can help?

@onnx

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

config = LlamaConfig(
    num_hidden_layers=num_hidden_layers,
    vocab_size=vocab_size,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    max_position_embeddings=max_position_embeddings,
    num_attention_heads=num_attention_heads,
)
config._attn_implementation = _attn_implementation
model = LlamaAttention(config)
optimized_model = torch.compile(model)
optimized_model(*inputs)
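
For reference, a self-contained variant of the snippet above with assumed small hyperparameter values and inputs (the values, layer_idx, and input shapes are illustrative, not the reporter's originals):

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

# Assumed small configuration; "eager" selects the attention path in question.
config = LlamaConfig(
    num_hidden_layers=1,
    vocab_size=1024,
    hidden_size=16,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
config._attn_implementation = "eager"
model = LlamaAttention(config, layer_idx=0)

# Assumed inputs: a 4D additive mask and default (None) cache_position.
batch, seq = 2, 8
hidden_states = torch.randn(batch, seq, config.hidden_size)
attention_mask = torch.zeros(batch, 1, seq, seq)
position_ids = torch.arange(seq).unsqueeze(0)

optimized_model = torch.compile(model)
optimized_model(hidden_states, attention_mask, position_ids)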

Expected behavior

No exception.

@amyeroberts
Collaborator

cc @ArthurZucker @younesbelkada

@thiagocrepaldi

thiagocrepaldi commented Feb 21, 2024

Seems to be a regression from #27931

@BowenBao did we hit that during ONNX Export?

@BowenBao
Contributor

@BowenBao did we hit that during ONNX Export?

At that time no, but things might have changed, or the test may have run with a different set of arguments.

@BowenBao
Contributor

BowenBao commented Feb 21, 2024

Just tried with the 4.38 release and ran into the error below when exporting with dynamo:

AssertionError: Mutating module attribute _cos_cached during export.

This is a full-graph dynamo export regression introduced by #29109.

I can get dynamo export to work e2e by commenting out the two lines below. I guess on my end it's using a different set of inputs, so it isn't reproducing OP's error.

        self._cos_cached = cos
        self._sin_cached = sin

@ArthurZucker
Collaborator

I'll open a fix right now; these should have been registered as buffers.
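
For readers following along, a minimal sketch of the distinction (illustrative module names and shapes, not the actual rotary-embedding code): assigning a fresh attribute inside forward is the mutation that full-graph export flags, while a tensor registered as a buffer in __init__ is module state that forward only reads.

import torch

class CacheByAttribute(torch.nn.Module):
    # Assigning a new module attribute inside forward is the pattern behind
    # "Mutating module attribute ... during export".
    def forward(self, positions):
        cos = torch.cos(positions)
        self._cos_cached = cos
        return cos

class CacheByBuffer(torch.nn.Module):
    # The cache is created once in __init__ and registered as a non-persistent
    # buffer, so forward only reads module state instead of mutating it.
    def __init__(self, max_len: int = 16):
        super().__init__()
        cos = torch.cos(torch.arange(max_len, dtype=torch.float32))
        self.register_buffer("_cos_cached", cos, persistent=False)

    def forward(self, positions):
        return self._cos_cached[: positions.shape[-1]]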

@ArthurZucker
Collaborator

ArthurZucker commented Feb 22, 2024

Could you try with the updated PR? If this works, I think we can do a patch release.

@BowenBao
Contributor

Thank you for the quick patch @ArthurZucker! Validated locally that this unblocks full-graph dynamo export.

I'm not sure if OP's issue still persists. @xadupre could you share the set of inputs in your repro?

@ArthurZucker
Collaborator

Alright, I'll do an actual patch release on Monday, as I think this is important to have!

@thiagocrepaldi

Thank you.

Out of curiosity, how long does a transformers patch release take in terms of hours?

@xadupre
Contributor Author

xadupre commented Feb 23, 2024

Here is a short example failing with 4.38.1. It does not even need torch.compile and returns UnboundLocalError: local variable 'causal_mask' referenced before assignment.

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

config = LlamaConfig(
    hidden_size=16,
    num_hidden_layers=1,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
config._attn_implementation = "eager"

class LlamaAttentionWrapper(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = LlamaAttention(config, layer_idx=0)

    def forward(self, hidden_states, attention_mask, position_ids):
        attn_output, _, _ = self.attention(
            hidden_states, attention_mask, position_ids
        )
        return attn_output

def generate_example_inputs(batch: int, seq: int, hidden_size: int):
    hidden_state = torch.randn(batch, seq, hidden_size)
    attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float)
    position_ids = torch.arange(0, seq, dtype=torch.int64)
    position_ids = position_ids.unsqueeze(0).view(-1, seq)
    return hidden_state, attention_mask, position_ids

example_args = generate_example_inputs(2, 1024, 16)
model = LlamaAttentionWrapper(config)
model(*example_args)

@ArthurZucker
Collaborator

ArthurZucker commented Feb 27, 2024

@xadupre this fixes it:

         if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask
             if cache_position is not None:
                 causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]

which we already have for SDPA. This will be included in the patch as well.
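
Merged into the block quoted at the top of the issue, the patched eager path reads roughly as follows (sketch only, indentation trimmed):

if attention_mask is not None:  # no matter the length, we just slice it
    causal_mask = attention_mask  # default, so the name is always bound
    if cache_position is not None:
        causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
    attn_weights = attn_weights + causal_mask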
