
Conversation

@oscarkey
Contributor

Port https://github.com/PriorLabs/TabPFN-private/pull/46 to public:

  • remove flash attention package use
  • chunk scaled_dot_product_attention

This might mean that people who installed their own flash attention package get slower inference. The difference should be negligible on recent PyTorch versions, but would be large for anyone on PyTorch 2.1, where FlashAttention-2 isn't built into scaled_dot_product_attention. But probably no one is doing this?
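For context, a minimal sketch of what chunking scaled_dot_product_attention over the batch dimension could look like; the function name, chunk size, and tensor layout are illustrative assumptions, not the actual TabPFN implementation:

```python
import torch
import torch.nn.functional as F


def chunked_sdpa(q, k, v, chunk_size=2**15):
    """Hypothetical helper: apply scaled_dot_product_attention in slices along
    the batch dimension, to stay below CUDA kernel limits on very large batches.

    Assumes q, k, v are shaped (batch, heads, seq, head_dim).
    """
    batch_size = q.shape[0]
    if batch_size <= chunk_size:
        # Small (or empty) batches go straight through to the fused kernel.
        return F.scaled_dot_product_attention(q, k, v)
    chunks = [
        F.scaled_dot_product_attention(
            q[i : i + chunk_size], k[i : i + chunk_size], v[i : i + chunk_size]
        )
        for i in range(0, batch_size, chunk_size)
    ]
    return torch.cat(chunks, dim=0)
```

Since torch 2.2, scaled_dot_product_attention dispatches to a built-in FlashAttention-2 backend on supported GPUs, which is why dropping the external flash-attn package should be close to a no-op there.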

@oscarkey oscarkey requested a review from LeoGrin October 28, 2025 13:33
@oscarkey oscarkey requested a review from a team as a code owner October 28, 2025 13:33
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request removes the dependency on the flash-attn package and introduces a chunking mechanism for scaled_dot_product_attention to work around a CUDA kernel limitation with large batch sizes. This is a good simplification that improves robustness. The changes are well-implemented and include new tests. I've identified one edge case where an empty batch would cause a runtime error and have suggested a fix along with an additional test case to cover it. Otherwise, the changes look solid.
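If useful, the empty-batch edge case could be covered by a test along these lines (a sketch only: `chunked_sdpa` is the illustrative helper from the sketch above, and the actual function and fix in this PR may look different):

```python
import pytest
import torch
import torch.nn.functional as F


@pytest.mark.parametrize("batch_size", [0, 1, 5])
def test_chunked_sdpa_matches_unchunked(batch_size):
    # Hypothetical test: empty and small batches should yield outputs of the
    # expected shape, and chunking should not change the result.
    q = torch.randn(batch_size, 4, 8, 16)  # (batch, heads, seq, head_dim)
    out = chunked_sdpa(q, q, q, chunk_size=2)
    assert out.shape == q.shape
    expected = F.scaled_dot_product_attention(q, q, q)
    assert torch.allclose(out, expected, atol=1e-6)
```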

Collaborator

@LeoGrin LeoGrin left a comment


LGTM thanks!

> This might mean that people who installed their own flash attention package get slower inference. The difference should be negligible on recent PyTorch versions, but would be large for anyone on PyTorch 2.1, where FlashAttention-2 isn't built into scaled_dot_product_attention. But probably no one is doing this?

Yes, that seems fine to me. I'm wondering if we want to force torch >= 2.2. On the one hand it's quite recent (Jan 2024); on the other hand, people using torch 2.1 will have a very bad experience 🤔

@oscarkey
Contributor Author

I would be in favour of forcing 2.2; for now I've created https://linear.app/priorlabs/issue/RES-813/drop-torch-21warn-if-used-on-big-datasets to track it.
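For reference, a minimal sketch of the "warn if used on big datasets" option from that ticket; the helper name, threshold, and the choice to warn rather than raise are assumptions for illustration:

```python
import warnings

import torch
from packaging.version import Version


def warn_if_slow_attention(n_samples: int, threshold: int = 10_000) -> None:
    """Hypothetical check: torch < 2.2 lacks the built-in FlashAttention-2
    backend for scaled_dot_product_attention, so large datasets may be slow."""
    if Version(torch.__version__) < Version("2.2") and n_samples > threshold:
        warnings.warn(
            "Running on torch < 2.2 without FlashAttention-2: inference on "
            "large datasets may be significantly slower. Consider upgrading "
            "to torch >= 2.2.",
            stacklevel=2,
        )
```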

@oscarkey oscarkey merged commit 93f8673 into main Oct 28, 2025
10 checks passed
@oscarkey oscarkey deleted the ok-sync-attention branch October 28, 2025 14:18
oscarkey pushed a commit that referenced this pull request Nov 12, 2025
Co-authored-by: mirror-bot <mirror-bot@users.noreply.github.com>