llama model: causal_mask does not exist #29173
Comments
Seems to be a regression from #27931. @BowenBao, did we hit that during ONNX export?
At that time, no, but things might have changed, or the test may have run with a different set of arguments.
Just tried with the 4.38 release and ran into the error below when exporting with dynamo.
This is a full-graph dynamo export regression introduced by #29109. I can get dynamo export to work end-to-end by commenting out the two lines below. I guess on my end it's using a different set of inputs, so it's not reproducing the OP's error.
I'll open a fix right now; these should have been registered buffers.
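To illustrate what registered buffers mean here, below is a minimal, hypothetical sketch (not the actual fix from the PR): a mask registered via `register_buffer` travels with the module across devices and dtypes and is visible to tracing/export, unlike a tensor created ad hoc inside `forward`.

```python
import torch
from torch import nn


class MaskedAttentionStub(nn.Module):
    """Toy module showing a registered, non-persistent causal-mask buffer."""

    def __init__(self, max_positions: int = 16):
        super().__init__()
        # -inf strictly above the diagonal, 0 elsewhere. Registered as a buffer
        # so it follows .to(device)/.half() and is visible to dynamo/export;
        # persistent=False keeps it out of the state_dict.
        mask = torch.full((max_positions, max_positions), float("-inf")).triu(diagonal=1)
        self.register_buffer("causal_mask", mask, persistent=False)

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        seq_len = scores.shape[-1]
        return scores + self.causal_mask[:seq_len, :seq_len]


stub = MaskedAttentionStub()
print(stub(torch.zeros(1, 4, 4)))  # upper triangle is -inf, the rest unchanged
```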
Could you try with the updated PR? If this works, I think we can do a patch release.
Thank you for the quick patch @ArthurZucker! Validated locally that this unblocks full-graph dynamo export. I'm not sure if the OP's issue still persists. @xadupre, could you share the set of inputs you used?
Alright, I'll do an actual patch release on Monday, as I think this is important to have!
Thank you. Out of curiosity, how long does a transformers patch release take, in terms of hours?
Here is a short example failing with 4.38.1. It does not even need torch.compile to fail:

```python
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

config = LlamaConfig(
    hidden_size=16,
    num_hidden_layers=1,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
config._attn_implementation = "eager"


class LlamaAttentionWrapper(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = LlamaAttention(config, layer_idx=0)

    def forward(self, hidden_states, attention_mask, position_ids):
        attn_output, _, _ = self.attention(
            hidden_states, attention_mask, position_ids
        )
        return attn_output


def generate_example_inputs(batch: int, seq: int, hidden_size: int):
    hidden_state = torch.randn(batch, seq, hidden_size)
    attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float)
    position_ids = torch.arange(0, seq, dtype=torch.int64)
    position_ids = position_ids.unsqueeze(0).view(-1, seq)
    return hidden_state, attention_mask, position_ids


example_args = generate_example_inputs(2, 1024, 16)
model = LlamaAttentionWrapper(config)
model(*example_args)
```
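For reference, a hedged sketch of feeding this repro into the dynamo-based exporter discussed above (assuming PyTorch 2.1+, where `torch.onnx.dynamo_export` is available; the exact export arguments used in the original CI run are not shown in this thread):

```python
# Continues from the repro above: `model` and `example_args` are already defined.
# torch.onnx.dynamo_export performs the full-graph dynamo export referenced in
# the comments above and returns an ONNX program object that can be saved.
onnx_program = torch.onnx.dynamo_export(model, *example_args)
onnx_program.save("llama_attention_wrapper.onnx")
```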
@xadupre this fixes it:

```diff
  if attention_mask is not None:  # no matter the length, we just slice it
+     causal_mask = attention_mask
      if cache_position is not None:
          causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
```

which we do have for
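To make the effect of that guard concrete, here is a small standalone sketch with toy tensors (illustrative only, not the actual modeling code; the function name is made up, but the branching mirrors the snippet above):

```python
import torch


def select_causal_mask(attention_mask, cache_position, key_states):
    """Illustrative restatement of the guarded slicing from the patch above."""
    causal_mask = None
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask
        if cache_position is not None:
            causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
    return causal_mask


batch, q_len, kv_len = 2, 4, 4
attention_mask = torch.zeros(batch, 1, q_len, kv_len)
key_states = torch.randn(batch, 1, kv_len, 8)

# Mask given but no cache_position: previously this hit an undefined
# `causal_mask`; now the full mask is used as-is.
print(select_causal_mask(attention_mask, None, key_states).shape)             # torch.Size([2, 1, 4, 4])
# cache_position provided: only the active rows are sliced out.
print(select_causal_mask(attention_mask, torch.arange(2), key_states).shape)  # torch.Size([2, 1, 2, 4])
# No mask at all: nothing to slice.
print(select_causal_mask(None, None, key_states))                             # None
```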
System Info
Line: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L372C43-L372C54
`causal_mask` does not exist if `attention_mask` is None.
Who can help?
@onnx
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
Expected behavior
No exception.