When I try to reproduce the original-model results for mistral-7b-v0.2 without flash-attn, I get this error:
Traceback (most recent call last):
  File "/home/yuanye/long_llm/InfLLM/benchmark/pred.py", line 330, in <module>
    preds = get_pred(
  File "/home/yuanye/long_llm/InfLLM/benchmark/pred.py", line 263, in get_pred
    output = searcher.generate(
  File "/home/yuanye/long_llm/InfLLM/inf_llm/utils/greedy_search.py", line 32, in generate
    result = self._decode(input_ids, **kwargs)
  File "/home/yuanye/long_llm/InfLLM/inf_llm/utils/greedy_search.py", line 54, in _decode
    out = self.model(
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1065, in forward
    outputs = self.model(
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yuanye/long_llm/InfLLM/inf_llm/utils/patch.py", line 102, in model_forward
    layer_outputs = decoder_layer(
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 528, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/infllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yuanye/long_llm/InfLLM/inf_llm/utils/patch.py", line 16, in hf_forward
    ret = forward(
  File "/home/yuanye/long_llm/InfLLM/inf_llm/attention/origin.py", line 49, in forward
    score = torch.matmul(h_q, h_k.transpose(-1, -2))
RuntimeError: The size of tensor a (32) must match the size of tensor b (8) at non-singleton dimension 1
It seems that inf_llm/attention/origin.py does not support GQA (grouped-query attention), which Mistral uses: the query tensor has 32 heads but the key/value tensors have only 8, so the matmul fails on the head dimension. How can this be fixed?
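For anyone hitting the same thing, here is a minimal sketch of the usual GQA workaround, mirroring the repeat_kv helper that transformers' modeling_mistral.py uses in its own eager-attention path: expand each of the 8 key/value heads 4 times so they line up with the 32 query heads before the score matmul. The tensor names h_q/h_k/h_v follow the traceback above, but the exact splice point inside inf_llm/attention/origin.py is an assumption, not the repo's actual code:

import torch

def repeat_kv(h_kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq_len, head_dim) to
    # (batch, num_kv_heads * n_rep, seq_len, head_dim) so grouped
    # key/value heads line up with the query heads.
    batch, num_kv_heads, seq_len, head_dim = h_kv.shape
    if n_rep == 1:
        return h_kv
    h_kv = h_kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return h_kv.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Hypothetical patch near line 49 of inf_llm/attention/origin.py:
# n_rep = h_q.size(1) // h_k.size(1)   # 32 // 8 = 4 for Mistral-7B
# h_k = repeat_kv(h_k, n_rep)
# h_v = repeat_kv(h_v, n_rep)
# score = torch.matmul(h_q, h_k.transpose(-1, -2))

Note that this materializes the repeated key/value heads, so it uses more memory than flash-attn, whose kernels handle GQA natively; that trade-off is inherent to the eager path.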