Update ptxla training #9864

Open · wants to merge 6 commits into main

Conversation

entrpn (Contributor) commented Nov 4, 2024

  • Updates TPU benchmark numbers.
  • Updates the ptxla training example code.
  • Adds flash attention to ptxla code running on TPUs.

@sayakpaul can you please review? This PR supersedes the one I had opened a while back, which I have now closed. Thank you.

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

sayakpaul (Member)

Cc: @yiyixuxu could you review the changes made to attention_processor.py?

yiyixuxu (Collaborator) commented Nov 5, 2024

@entrpn can you use a custom attention processor instead (without updating our default attention processor)?

zpcore commented Nov 5, 2024

> @entrpn can you use a custom attention processor instead (without updating our default attention processor)?

Hi @yiyixuxu, we wrapped the flash attention kernel call behind the XLA_AVAILABLE condition, so this shouldn't change the default attention processor behavior. Can you give more details about using a custom attention processor? Thanks.
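
For context, a minimal sketch of the gating pattern being described, assuming the torch_xla Pallas kernel at torch_xla.experimental.custom_kernel.flash_attention; this is an illustration, not the exact diff in this PR:

```python
import torch.nn.functional as F

# Import the TPU flash-attention kernel only when torch_xla is installed,
# so the stock SDPA path is untouched for non-XLA users.
try:
    from torch_xla.experimental.custom_kernel import flash_attention

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def sdpa_or_xla_flash(query, key, value):
    # query/key/value: [batch, heads, seq_len, head_dim]
    if XLA_AVAILABLE:
        # Pre-scale queries, assuming the Pallas kernel does not apply
        # SDPA's default 1/sqrt(head_dim) softmax scaling.
        query = query * query.shape[-1] ** -0.5
        return flash_attention(query, key, value, causal=False)
    return F.scaled_dot_product_attention(query, key, value)
```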

yiyixuxu (Collaborator) commented Nov 5, 2024

I'm just wondering whether it makes sense for flash attention to have its own attention processor, since this one is meant for SDPA.

cc @DN6 here too

entrpn (Contributor, Author) commented Nov 5, 2024

@yiyixuxu this makes sense.

@zpcore do you think you can implement it?

zpcore commented Nov 5, 2024

> @yiyixuxu this makes sense.
>
> @zpcore do you think you can implement it?

Yes, I can follow up with the code change.

zpcore commented Nov 5, 2024

Hi @yiyixuxu, what if we create another attention processor with flash attention, in parallel with AttnProcessor2_0? My concern is that the majority of the code will be the same as AttnProcessor2_0.

yiyixuxu (Collaborator) commented Nov 6, 2024

@zpcore that should not be a problem. A lot of our attention processors share the majority of their code, e.g. https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L732 and https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L2443.

This way users can explicitly choose to use flash attention if they want to.
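
For illustration, a minimal sketch of what such a parallel processor might look like; the class name FlashAttnProcessor is hypothetical, and it mirrors the projection/reshape structure of AttnProcessor2_0 while swapping in the torch_xla kernel (assuming the kernel's softmax scale defaults to 1.0):

```python
from torch_xla.experimental.custom_kernel import flash_attention


class FlashAttnProcessor:
    # Hypothetical processor that runs attention with the torch_xla
    # Pallas flash-attention kernel instead of torch SDPA.
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size = hidden_states.shape[0]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
        head_dim = query.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Pre-scale queries since SDPA's default 1/sqrt(head_dim) scaling
        # is assumed not to be applied inside the kernel.
        query = query * head_dim**-0.5
        hidden_states = flash_attention(query, key, value, causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states
```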

miladm commented Nov 6, 2024

@yiyixuxu - to help me better understand, can you explain why wrapping the flash attention kernel call behind the XLA_AVAILABLE condition causes trouble? Do you want this functionality to be more generalized?

yiyixuxu (Collaborator) commented Nov 6, 2024

Is it not possible that XLA is available but the user does not want to use flash attention?

Our attention processors are designed to be very easy to switch, and each one corresponds to a very specific method: it could be xFormers, SDPA, or even a special method like fused attention, which has its own processor.
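
For example, with a dedicated processor the user opts in explicitly via the existing set_attn_processor hook (the FlashAttnProcessor name here is hypothetical, following the sketch above):

```python
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")

# Default path: PyTorch SDPA.
pipe.unet.set_attn_processor(AttnProcessor2_0())

# Explicit opt-in to TPU flash attention (hypothetical processor from the sketch above).
pipe.unet.set_attn_processor(FlashAttnProcessor())
```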
