
Batched inference for OS-Atlas-Base-7B is broken with attn_implementation="flash_attention_2" #17

Open
jasonlee-sf opened this issue Nov 19, 2024 · 0 comments

@jasonlee-sf

Hi, it seems that batched inference is broken when flash attention is used. When running inference on the first example of the ScreenSpot test set with attn_implementation="flash_attention_2", the output changes depending on the batch size:

  • Batch_size = 1: <|object_ref_start|>close button<|object_ref_end|><|box_start|>(954,148),(988,196)<|box_end|><|im_end|>

  • Batch_size = 4: 降序<|im_end|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>

When I disable flash_attention_2, the results look fine regardless of batch size:

  • Batch_size = 1: <|object_ref_start|>close button<|object_ref_end|><|box_start|>(954,148),(988,196)<|box_end|><|im_end|>

  • Batch_size = 4: <|object_ref_start|>close button<|object_ref_end|><|box_start|>(954,148),(988,196)<|box_end|><|im_end|>

Flash attention with batch_size=1 is fast enough that this bug is not a deal breaker for me, but it would be nice if it were addressed.
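
For reference, a minimal sketch of the kind of batched-inference setup that reproduces this, assuming the standard Qwen2-VL-style usage for OS-Atlas-Base-7B (this is not the exact script used above; the image path and prompt text are placeholders):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "OS-Copilot/OS-Atlas-Base-7B",            # OS-Atlas-Base-7B is Qwen2-VL based
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # batched outputs are correct without this flag
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("OS-Copilot/OS-Atlas-Base-7B")
processor.tokenizer.padding_side = "left"     # left-pad for batched generation

image = Image.open("screenspot_example_0.png")  # placeholder: 1st ScreenSpot test screenshot
instruction = 'In this UI screenshot, what is the position of the "close button"?'  # placeholder prompt

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": instruction},
        ],
    }
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

batch_size = 4                                # 1 works; 4 produces garbage with flash_attention_2
inputs = processor(
    text=[text] * batch_size,
    images=[image] * batch_size,
    padding=True,
    return_tensors="pt",
).to(model.device)

out = model.generate(**inputs, max_new_tokens=128)
generated = out[:, inputs["input_ids"].shape[1]:]  # strip the prompt tokens
print(processor.batch_decode(generated, skip_special_tokens=False))
```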
