fix compatibility with jax transformations #7
can confirm that this error also appears under `jax.lax.scan`; example here:

```python
import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention

l, b, lq, lkv, h, d = 2, 3, 16, 17, 5, 19  # illustrative sizes (not given in the comment); l is the scan length
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (l, b, lq, h, d))
k = jax.random.normal(keys[1], (l, b, lkv, h, d))
v = jax.random.normal(keys[2], (l, b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (l, b, lkv))

def scan_fn(carry, qkv):
    out = flash_attention(*qkv)[0]
    carry += out
    return carry, out

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(
        lambda q, k, v, mask: jnp.sum(
            jax.lax.scan(
                scan_fn,
                jnp.zeros_like(q[0]),
                (q, k, v, mask),
            )[0],
        )
    )(q, k, v, mask)

bench_flash_bwd(q, k, v, mask)
```
Thanks for raising this! It looks most likely like a JAX core bug. Could you provide a self-contained runnable repro, in particular including the import or definition for `flash_attention`?
```python
from flash_attention_jax import flash_attention
```
I ran into this and failed to upstream the fix. The trick is basically to do this: …
@dlwh looks like you also ran an autoformatter, so there are a ton of other changes here - can you say a bit more about how you fixed it?
Yeah sorry, the line linked is the key one. Basically just rename the method called `causal_flash_attention` to `_causal_flash_attention` and make `causal_flash_attention` return just the first result. Then make …
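For concreteness, a minimal sketch of that refactor, with a toy dense attention standing in for the real blockwise kernel; the helper name `_attention_impl`, the residual layout, and the exact signatures are illustrative assumptions, not the repo's actual code:

```python
import jax
import jax.numpy as jnp

def _attention_impl(q, k, v):
    # toy (non-chunked) softmax attention, just so the sketch runs end to end
    scores = jnp.einsum('...qhd,...khd->...hqk', q, k)
    row_max = scores.max(axis=-1)
    exp_scores = jnp.exp(scores - row_max[..., None])
    row_sum = exp_scores.sum(axis=-1)
    out = jnp.einsum('...hqk,...khd->...qhd', exp_scores / row_sum[..., None], v)
    return out, row_sum, row_max

@jax.custom_vjp
def _causal_flash_attention(q, k, v):
    out, row_sum, row_max = _attention_impl(q, k, v)
    # the primal output structure must match the first output of the forward rule
    return out, (row_sum, row_max)

def _fwd(q, k, v):
    out, row_sum, row_max = _attention_impl(q, k, v)
    # (primal output, residuals saved for the backward pass)
    return (out, (row_sum, row_max)), (q, k, v, out, row_sum, row_max)

def _bwd(res, cts):
    q, k, v, out, row_sum, row_max = res
    d_out, _ = cts  # cotangents for (row_sum, row_max) are unused downstream
    # placeholder backward pass; the real kernel recomputes attention blockwise
    _, vjp_fn = jax.vjp(lambda q, k, v: _attention_impl(q, k, v)[0], q, k, v)
    return vjp_fn(d_out)

_causal_flash_attention.defvjp(_fwd, _bwd)

def causal_flash_attention(q, k, v):
    # public wrapper returns just the attention output, per the suggestion above
    return _causal_flash_attention(q, k, v)[0]

# works under grad and under gradient checkpointing:
q = k = v = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 4, 16))
jax.grad(jax.checkpoint(lambda q: causal_flash_attention(q, k, v).sum()))(q)
```

The point is that the `@jax.custom_vjp`-decorated function and its forward rule now agree on the primal output structure, while callers only ever see the plain attention output.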
won't that make …?
This is roughly repeating what @dlwh just said, but I just figured it out and came back to explain: the problem is this use of `custom_vjp`. There's a JAX bug in that this was a terrible error message to raise, but the fundamental bug is in that use of `custom_vjp`.
you'll need to make the analogous change to …
Shall I send a PR fix to this repo (maybe you both could review it), and then separately fix the JAX error message? Or @dlwh do you want to send the fix to this repo?

I can probably get to it tonight or tomorrow, but I'm about to go dark for several hours. Totally up to you!

I'll take the first stab, and cc you!
so the relevant fix would be to replace the forward rule's return with `return (out, (row_sum, row_max)), (q, k, v, key_mask, out, row_sum, row_max)`?
interesting that this works with …
@GallagherCommaJack Yes, that'd work! It's probably the simplest fix, though we could also look at the call sites of … What's a repro for the behavior you're describing? I tried removing … and ran this:

```python
import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention

b = 3
lq = 16
lkv = 17
h = 5
d = 19

keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]))(q)

bench_flash_bwd(q, k, v, mask)
```
Ah, I think it was just a shape bug; if I sent …

I think by adding the better JAX error message I described, we'll catch this much earlier and get an error in both cases. I'll be sure to test both with and without checkpoint/scan.
Actually, I think it would not work, just because the callers expect only a single output there. I think the issue here was that the …

with the fix it's working with …

the error with full backtrace: …

It looks like one of …
When defining a forward rule for `jax.custom_vjp(primal_fn)`, if `primal_fn` has output type `T` then we need the forward rule to have output type `(T, R)` for some `R`. That is, we need the first output of the forward rule to look like the full output of `primal_fn`. (Here the `R` values represent the 'residuals' computed on the forward pass to save for use on the backward pass.)

This PR fixes a disagreement between `custom_vjp`-decorated functions and their corresponding forward rules. The disagreement caused some interesting behavior! Discussed on lucidrains#7.

Separately, I'm going to try to get JAX to raise a better error message in this case; the error message was some really confusing JAX-internals thing.
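As a toy illustration of that contract (not code from this repo): the decorated function below has output type `T` (a single array), so its forward rule returns a pair whose first element is exactly that array and whose second element holds the residuals `R`.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)                # output type T: a single array

def f_fwd(x):
    # first element mirrors f(x) exactly; second element is the residuals R
    return jnp.sin(x), (jnp.cos(x),)

def f_bwd(residuals, g):
    (cos_x,) = residuals
    return (cos_x * g,)              # one cotangent per primal input

f.defvjp(f_fwd, f_bwd)

jax.grad(f)(0.3)                     # equals jnp.cos(0.3)
```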
debugging a bit, it looks like the issue is that …

@lucidrains looks like there's an implicit assumption somewhere in here that …
@GallagherCommaJack the fix I proposed in #8 is different from the commit you sent, just FYI.

does that work with …?

looks like it does not

Indeed I think the shape issue is unrelated.
In particular:

* add function names so it's clear what decorated functions and rules are causing the error;
* when possible (because the functions were run), check for agreement of pytree structure and leaf shapes/dtypes between the primal function and rules.

Context: lucidrains/flash-attention-jax#7
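A toy sketch (again, not code from this thread) of the kind of disagreement those checks are aimed at, where the forward rule's first output does not mirror the primal output:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def g(x):
    return jnp.sin(x)      # primal output: one array

def g_fwd(x):
    res = (x,)
    return res, res        # wrong: the first element should look like g(x), not like the residuals

def g_bwd(res, ct):
    (x,) = res
    return (jnp.cos(x) * ct,)

g.defvjp(g_fwd, g_bwd)

# with the stricter checks, differentiating g surfaces a structure-mismatch error
# that names g and g_fwd, rather than a confusing internals failure:
# jax.grad(g)(0.3)
```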
jax-ml/jax#12611 should improve the error message we got here! With the same repro (i.e. before the fix for #7 was merged here), the error will be: …
currently impossible to use `flash_attention` within a function that will use gradient checkpointing.

minimal example to reproduce:
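A minimal sketch of such a repro (the exact original snippet is not shown here), assuming `flash_attention` wrapped in `jax.checkpoint` under `jax.grad`, with illustrative shapes:

```python
import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention

b, lq, lkv, h, d = 3, 16, 16, 5, 19   # illustrative sizes
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

@jax.checkpoint
def attend(q, k, v, mask):
    return flash_attention(q, k, v, mask)[0]

jax.grad(lambda q: jnp.sum(attend(q, k, v, mask)))(q)
```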
fails with error: …