
Incompatibility Between xformers FA3 Torch Custom op Wrapper and recent flashattn_hopper_cuda #1159

Open
ohwi opened this issue Nov 21, 2024 · 5 comments


ohwi commented Nov 21, 2024

There is an incompatibility between xformers FA3 Torch custom op wrapper logic and recent flashattn_hopper_cuda changes, resulting in a TypeError due to changes in required arguments for the fwd() function:

TypeError: fwd(): incompatible function arguments. The following argument types are supported:
  1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: Optional[torch.Tensor], arg4: float, arg5: Optional[torch.Tensor], arg6: Optional[torch.Tensor], arg7: Optional[torch.Tensor], arg8: bool, arg9: int, arg10: int) -> list[torch.Tensor]

The recent FA3 requires the positional arguments window_size_left and window_size_right (see source).

My colleague @antferdom and I encountered this issue while attempting to use our own FA3 installation, after FA3 failed to compile within xformers. As a workaround, we modified mha_fwd to fall back to the externally installed FA3:

from typing import Tuple

import torch

# Assumed here: the externally installed FA3 build is importable as
# flashattn_hopper_cuda; inside xformers' flash3.py this name refers to the
# bundled extension instead.
import flashattn_hopper_cuda as _C_flashattention3


def mha_fwd(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    seqused_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    p: float,
    softmax_scale: float,
    is_causal: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if cu_seqlens_q is None:
        assert cu_seqlens_k is None
        assert seqused_k is None
        # Recent FA3 requires explicit sliding-window bounds: -1 disables the
        # window, and a right bound of 0 reproduces causal masking.
        window_size_left = -1
        window_size_right = 0 if is_causal else -1
        (
            out,
            q_padded,
            k_padded,
            v_padded,
            out_padded,
            softmax_lse,
            p,
        ) = _C_flashattention3.fwd(
            query, key, value, None, softmax_scale, None, None, None, is_causal,
            window_size_left, window_size_right, False,
        )
        return out, softmax_lse
    # Varlen path (cu_seqlens_q is not None) unchanged and omitted here.

So, we suggest updating mha_fwd to pass window_size_left and window_size_right as required positional arguments, or adding a check around the _C_flashattention3.fwd call to determine whether these arguments are needed, so that all FA3 versions are supported.
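
For the second option, one rough sketch (assuming the extension is imported as _C_flashattention3 as above; the helper name _fa3_fwd_compat is hypothetical, and the exact argument lists depend on which FA3 commit is installed) is to try the newer call first and fall back to the call xformers makes today:

def _fa3_fwd_compat(query, key, value, softmax_scale, is_causal):
    # Hypothetical helper: dispatch between FA3 builds whose fwd() signatures
    # differ only in the trailing window-size (and related) arguments.
    window_size_left = -1
    window_size_right = 0 if is_causal else -1
    try:
        # Newer FA3 commits: window sizes are required positional arguments
        # (mirroring the call in the snippet above).
        return _C_flashattention3.fwd(
            query, key, value, None, softmax_scale, None, None, None,
            is_causal, window_size_left, window_size_right, False,
        )
    except TypeError:
        # Older FA3 commits: the shorter call that xformers currently makes.
        return _C_flashattention3.fwd(
            query, key, value, None, softmax_scale, None, None, None, is_causal,
        )

Probing the extension's signature once at import time would avoid catching TypeError on every call, but the try/except keeps the sketch independent of how the pybind11 binding names its arguments.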

If this solution is acceptable, we would be happy to submit a Pull Request to address this issue.

lw (Contributor) commented Nov 22, 2024

CC @bottler

bottler (Contributor) commented Nov 22, 2024

The library should work with the submodule commits it actually pins. It is not a bug in xformers if updating the flash submodule to a different commit produces an incompatibility; in general that is impossible to avoid.

At some point I expect xformers will be updated to a more recent flash version. If you want to submit a PR that updates the submodule to a more recent commit, updates flash.py and flash3.py consistently with it, and keeps the tests passing, that would be great!

ohwi (Author) commented Nov 28, 2024

Hi, sorry for the late response. I’ll do my best to get this done over the weekend, and make a PR. Thank you!

danthe3rd (Contributor) commented

"after FA3 failed to compile within xformers"

Could you elaborate on what the issue was there?

ohwi (Author) commented Dec 18, 2024

Sorry, I don’t remember exactly what the issue was, but I suspect it was related to my environment rather than the library itself. I’ll try re-installing the library and will share the specific error if it happens again.
