
fix compatibility with jax transformations #7

Open
GallagherCommaJack opened this issue Sep 23, 2022 · 28 comments

@GallagherCommaJack

It is currently impossible to use flash_attention inside a function that uses gradient checkpointing (jax.checkpoint).

minimal example to reproduce:

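import jax
import jax.numpy as jnp

from flash_attention_jax import flash_attention

keys = jax.random.split(jax.random.PRNGKey(0), 4)
b = 3
lq = 16
lkv = 17
h = 5
d = 19
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    # differentiate a checkpointed (rematerialized) wrapper around flash_attention
    return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)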

fails with error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in <cell line: 1>()
----> 1 get_ipython().run_line_magic('timeit', 'bench_flash_bwd(q, k, v, mask).block_until_ready()')

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2305, in InteractiveShell.run_line_magic(self, magic_name, line, _stack_depth)
   2303     kwargs['local_ns'] = self.get_local_scope(stack_depth)
   2304 with self.builtin_trap:
-> 2305     result = fn(*args, **kwargs)
   2306 return result

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:1162, in ExecutionMagics.timeit(self, line, cell, local_ns)
   1160 for index in range(0, 10):
   1161     number = 10 ** index
-> 1162     time_number = timer.timeit(number)
   1163     if time_number >= 0.2:
   1164         break

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:156, in Timer.timeit(self, number)
    154 gc.disable()
    155 try:
--> 156     timing = self.inner(it, self.timer)
    157 finally:
    158     if gcold:

File <magic-timeit>:1, in inner(_it, _timer)

    [... skipping hidden 14 frame]

/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in bench_flash_bwd(q, k, v, mask)
      1 @jax.jit
      2 def bench_flash_bwd(q, k, v, mask):
----> 3     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]), policy=jax.checkpoint_policies.everything_saveable))(q)

    [... skipping hidden 25 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/util.py:48, in safe_map(f, *args)
     46 n = len(args[0])
     47 for arg in args[1:]:
---> 48   assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
     49 return list(map(f, *args))

AssertionError: length mismatch: [3, 1]
@GallagherCommaJack
Author

can confirm that this error also appears under jax.lax.scan

example here:

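# reuses the imports, keys and b, lq, lkv, h, d from the first snippet
l = 4  # number of stacked layers to scan over (arbitrary value; not given in the original)
q = jax.random.normal(keys[0], (l, b, lq, h, d))
k = jax.random.normal(keys[1], (l, b, lkv, h, d))
v = jax.random.normal(keys[2], (l, b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (l, b, lkv))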


def scan_fn(carry, qkv):
    out = flash_attention(*qkv)[0]
    carry += out
    return carry, out


@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(
        lambda q, k, v, mask: jnp.sum(
            jax.lax.scan(
                scan_fn,
                jnp.zeros_like(q[0]),
                (q, k, v, mask),
            )[0],
        )
    )(q, k, v, mask)


bench_flash_bwd(q, k, v, mask)

@GallagherCommaJack GallagherCommaJack changed the title "fix compatibility with jax.checkpoint" to "fix compatibility with jax transformations" Sep 29, 2022
@mattjj
Contributor

mattjj commented Sep 29, 2022

Thanks for raising this! Most likely it's a JAX core bug.

Could you provide a self-contained runnable repro, in particular including the import or definition for flash_attention? (Sorry, I'm not the developer of this repo, so I'm not familiar with that function.)

@GallagherCommaJack
Author

from flash_attention_jax import flash_attention

@dlwh

dlwh commented Sep 29, 2022

I ran into this too and never upstreamed the fix. The trick is basically this:

stanford-crfm/levanter@a2828ce#diff-658abe908dd5cd256efe9370e7ec2ae9fa2dcdca586a5f886940331e7b56dd09R129-R132

@GallagherCommaJack
Author

@dlwh looks like you also ran an autoformatter, so there are a ton of other changes here. Can you say a bit more about how you fixed it?

@dlwh

dlwh commented Sep 29, 2022

Yeah, sorry, the linked line is the key one. Basically, rename the function called causal_flash_attention to _causal_flash_attention and make causal_flash_attention return just the first result. Then make flash_attention_forward call _causal_flash_attention instead, and you're done.

@custom_vjp
def causal_flash_attention(q, k, v):
+    return _causal_flash_attention(q, k, v)[0]
+
+
+def _causal_flash_attention(q, k, v):
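Here is a minimal sketch of the whole restructuring @dlwh describes, applied to a toy single-head attention rather than the repo's kernel (no mask, batch, heads, chunking or causal logic, so the pattern itself does not imply causal masking). All names (`_attention_with_stats`, `attention`, `attention_fwd`, `attention_bwd`, `reference`) are hypothetical. The private helper returns the output plus the softmax statistics, the custom_vjp-decorated function returns only the output, and the forward rule calls the helper so that its first element has the same type as the decorated function's output.

```python
import jax
import jax.numpy as jnp


def _attention_with_stats(q, k, v):
    # Hypothetical stand-in for the full forward computation: returns the output
    # together with the softmax statistics a flash-style kernel keeps around.
    s = q @ k.T / jnp.sqrt(q.shape[-1])            # (lq, lkv) scores
    row_max = jnp.max(s, axis=-1, keepdims=True)   # (lq, 1)
    e = jnp.exp(s - row_max)
    row_sum = jnp.sum(e, axis=-1, keepdims=True)   # (lq, 1)
    out = (e / row_sum) @ v                        # (lq, d)
    return out, row_sum, row_max


@jax.custom_vjp
def attention(q, k, v):
    # Primal function: returns only the output array.
    return _attention_with_stats(q, k, v)[0]


def attention_fwd(q, k, v):
    out, row_sum, row_max = _attention_with_stats(q, k, v)
    # First element matches attention's output type; second holds the residuals.
    return out, (q, k, v, row_sum, row_max)


def attention_bwd(res, do):
    q, k, v, row_sum, row_max = res
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    # Rebuild the softmax probabilities from the saved statistics.
    p = jnp.exp(q @ k.T * scale - row_max) / row_sum
    dv = p.T @ do
    dp = do @ v.T
    ds = p * (dp - jnp.sum(dp * p, axis=-1, keepdims=True))  # softmax backward
    dq = ds @ k * scale
    dk = ds.T @ q * scale
    return dq, dk, dv


attention.defvjp(attention_fwd, attention_bwd)


def reference(q, k, v):
    # Plain attention differentiated by JAX itself, used as a check.
    s = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(s, axis=-1) @ v


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (16, 19))
k = jax.random.normal(kk, (17, 19))
v = jax.random.normal(kv, (17, 19))

# Same shape of computation as the original report: grad of a checkpointed loss.
g_ckpt = jax.jit(jax.grad(jax.checkpoint(lambda x: jnp.sum(attention(x, k, v)))))(q)
g_ref = jax.grad(lambda x: jnp.sum(reference(x, k, v)))(q)
```

With the primal function and the forward rule agreeing, the checkpointed gradient should match plain autodiff of the reference. The backward rule rebuilds the probabilities from `row_sum` and `row_max` instead of saving the full attention matrix as a residual.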

@GallagherCommaJack
Author

Won't that make flash_attention always do causal masking? I'm using this in a context where that's not appropriate.

@mattjj
Contributor

mattjj commented Sep 29, 2022

This roughly repeats what @dlwh just said, but I just figured it out and came back to explain: this use of custom_vjp is buggy because the output of flash_attention_forward needs to be a pair whose first element has the same type as the output of flash_attention. Yet flash_attention returns three arrays, while the first element of flash_attention_forward's return value contains only one.

There's a JAX bug in that this produced a terrible error message, but the fundamental bug is in this use of custom_vjp.

@dlwh

dlwh commented Sep 29, 2022

You'll need to make the analogous change to flash_attention, then. As @mattjj said, it's really just a buggy use of custom_vjp. (Though apart from not running under these transformations, the code was otherwise correct according to my gradient testing!)

@mattjj
Contributor

mattjj commented Sep 29, 2022

Shall I send a PR fix to this repo (maybe you both could review it), and then separately fix the JAX error message? Or @dlwh do you want to send the fix to this repo?

@dlwh

dlwh commented Sep 29, 2022

I can probably get to it tonight or tomorrow, but I'm about to go dark for several hours. Totally up to you!

@mattjj
Contributor

mattjj commented Sep 29, 2022

I'll take the first stab, and cc you!

@GallagherCommaJack
Author

So would the relevant fix be to replace

return out, (q, k, v, key_mask, out, row_sum, row_max)

with

return (out, (row_sum, row_max)), (q, k, v, key_mask, out, row_sum, row_max)

@GallagherCommaJack
Author

Interesting that this works with grad outside of scan and remat. Presumably it should also fail under grad alone, without either of those?

@mattjj
Contributor

mattjj commented Sep 29, 2022

@GallagherCommaJack Yes, that'd work! It's probably the simplest fix, though we could also look at the call sites of flash_attention to see if some other organization would be more natural.

What's a repro for the behavior you're describing? I tried removing jax.checkpoint from the repro in the OP and I still got an error. That is, this still errors for me:

import jax
import jax.numpy as jnp

from flash_attention_jax import flash_attention


b = 3
lq = 16
lkv = 17
h = 5
d = 19
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]))(q)


bench_flash_bwd(q, k, v, mask)

@mattjj
Contributor

mattjj commented Sep 29, 2022

Ah, I think it was just a shape bug; if I set lq = lkv = 16, then I see what you mean.

I think by adding the better JAX error message I described, we'll catch this much earlier and get an error in both cases. I'll be sure to test both with and without checkpoint/scan.

@mattjj
Contributor

mattjj commented Sep 29, 2022

> Yes, that'd work!

Actually, I think it would not work, because the callers expect only a single output there.

I think the issue here was that the custom_vjp-decorated function (i.e. the "primal function") didn't agree with the custom_vjp rule (i.e. their output types didn't agree in the way they should). But when we only use grad (possibly together with jit), we never actually run the primal function; we only run its forward rule. Under grad, the primal function is only actually run when it sits inside a jax.checkpoint or jax.lax.scan (or jax.lax.cond, etc.). That's just a JAX implementation detail (these are "initial-style higher-order primitives") which is usually invisible, except, apparently, when there's a type error in a custom_vjp rule!

@GallagherCommaJack
Author

With the fix, it works with lq = lkv under jax.checkpoint!
It still fails with lq != lkv, which I'm trying to debug now.

@GallagherCommaJack
Author

midjourney@f690412

@GallagherCommaJack
Author

GallagherCommaJack commented Sep 29, 2022

The error with lq = 16, lkv = 17 is: TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).

Full backtrace:

TypeError                                 Traceback (most recent call last)
Cell In [5], line 22
     18 @jax.jit
     19 def bench_flash_bwd(q, k, v, mask):
     20     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)
---> 22 bench_flash_bwd(q, k, v, mask)

    [... skipping hidden 14 frame]

Cell In [5], line 20, in bench_flash_bwd(q, k, v, mask)
     18 @jax.jit
     19 def bench_flash_bwd(q, k, v, mask):
---> 20     return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)

    [... skipping hidden 30 frame]

File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:172, in flash_attention_backward(res, do)
    169     dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
    170     return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk
--> 172 (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))
    174 dq = rearrange(dq, 'c n b h d -> b h (c n) d')
    175 dk, dv = map(lambda t: rearrange(t, 'n b h d -> b h n d'), (dk, dv))

    [... skipping hidden 11 frame]

File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:170, in flash_attention_backward.<locals>.chunk_scanner(carries, _)
    167 do_chunk = lax.dynamic_slice(do, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, do.shape[-1]))
    169 dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
--> 170 return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk

    [... skipping hidden 1 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4658, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
   4656 args = (other, self) if swap else (self, other)
   4657 if isinstance(other, _accepted_binop_types):
-> 4658   return binary_op(*args)
   4659 if isinstance(other, _rejected_binop_types):
   4660   raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4661                   f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")

    [... skipping hidden 7 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:84, in _maybe_bool_binop.<locals>.fn(x1, x2)
     82 def fn(x1, x2):
     83   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
---> 84   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

    [... skipping hidden 7 frame]

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/lax/lax.py:1537, in broadcasting_shape_rule(name, *avals)
   1535       result_shape.append(non_1s[0])
   1536     else:
-> 1537       raise TypeError(f'{name} got incompatible shapes for broadcasting: '
   1538                       f'{", ".join(map(str, map(tuple, shapes)))}.')
   1540 return tuple(result_shape)

TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).

@mattjj
Contributor

mattjj commented Sep 29, 2022

It looks like one of chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk has a shape error, in flash_attention_backward. (EDIT: I don't feel comfortable debugging that without learning what this code is actually doing, so hopefully someone who knows the code/algorithm can help!)

mattjj added a commit to mattjj/flash-attention-jax that referenced this issue Sep 29, 2022
When defining a forward rule for `jax.custom_vjp(primal_fn)`, if `primal_fn`
has output type `T` then we need the forward rule to have output type `(T, R)`
for some `R`. That is, we need the first output of the forward rule to look
like the full output of `primal_fn`. (Here the `R` values represent the
'residuals' computed on the forward pass to save for use on the backward pass.)

This PR fixes a disagreement between `custom_vjp`-decorated functions and their
corresponding forward rules.

The disagreement caused some interesting behavior! Discussed on
lucidrains#7

Separately, I'm going to try to get JAX to raise a better error message in this
case; the error message was some really confusing JAX-internals thing.
@GallagherCommaJack
Author

After debugging a bit, it looks like the issue is that dk has shape (h, b, lkv, d) while dk_chunk has shape (h, b, lq, d).

@GallagherCommaJack
Author

GallagherCommaJack commented Sep 29, 2022

@lucidrains it looks like there's an implicit assumption that lq == lkv in the backward pass, in _query_chunk_flash_attention_backward.
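To make the shape expectation concrete, here is a minimal sketch of a query-chunked attention backward pass for a single head with no mask or batch dimension. It is not the repo's `_query_chunk_flash_attention_backward`; `query_chunk_backward` and `chunked_backward` are hypothetical names. For a query chunk of length `cq`, the query gradient has the chunk's length, but the key and value contributions span all `lkv` keys, so the scan carry for `dk`/`dv` has to be shaped like `k`/`v` rather than like `q`. When lq == lkv the two shapes coincide, which is how such an assumption can go unnoticed.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v):
    # q: (lq, d); k, v: (lkv, d). Plain softmax attention, no mask.
    s = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(s, axis=-1) @ v


def query_chunk_backward(q_chunk, k, v, do_chunk):
    # Gradients contributed by one query chunk of length cq.
    scale = 1.0 / jnp.sqrt(q_chunk.shape[-1])
    p = jax.nn.softmax(q_chunk @ k.T * scale, axis=-1)        # (cq, lkv)
    dv_chunk = p.T @ do_chunk                                  # (lkv, d)
    dp = do_chunk @ v.T                                        # (cq, lkv)
    ds = p * (dp - jnp.sum(dp * p, axis=-1, keepdims=True))   # softmax backward
    dq_chunk = ds @ k * scale                                  # (cq, d)
    dk_chunk = ds.T @ q_chunk * scale                          # (lkv, d)
    return dq_chunk, dk_chunk, dv_chunk


def chunked_backward(q, k, v, do, chunk_size):
    lq = q.shape[0]
    assert lq % chunk_size == 0, "sketch assumes lq is a multiple of chunk_size"

    def step(carry, idx):
        dk, dv = carry
        start = idx * chunk_size
        q_chunk = jax.lax.dynamic_slice_in_dim(q, start, chunk_size, axis=0)
        do_chunk = jax.lax.dynamic_slice_in_dim(do, start, chunk_size, axis=0)
        dq_chunk, dk_chunk, dv_chunk = query_chunk_backward(q_chunk, k, v, do_chunk)
        # dk/dv accumulate over query chunks and keep the key length lkv.
        return (dk + dk_chunk, dv + dv_chunk), dq_chunk

    init = (jnp.zeros_like(k), jnp.zeros_like(v))  # shaped like k/v, not q
    (dk, dv), dq_chunks = jax.lax.scan(step, init, jnp.arange(lq // chunk_size))
    return dq_chunks.reshape(q.shape), dk, dv


kq, kk, kv, kdo = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(kq, (16, 19))   # lq = 16
k = jax.random.normal(kk, (17, 19))   # lkv = 17
v = jax.random.normal(kv, (17, 19))
do = jax.random.normal(kdo, (16, 19))

dq, dk, dv = chunked_backward(q, k, v, do, chunk_size=4)
_, attention_vjp = jax.vjp(attention, q, k, v)
dq_ref, dk_ref, dv_ref = attention_vjp(do)
```

Comparing against `jax.vjp` of the unchunked function is a cheap way to catch this kind of shape assumption with lq != lkv.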

@mattjj
Contributor

mattjj commented Sep 29, 2022

@GallagherCommaJack the fix I proposed in #8 is different from the commit you sent, just FYI.

@GallagherCommaJack
Author

does that work with lq != lkv?

@GallagherCommaJack
Author

looks like it does not

@mattjj
Contributor

mattjj commented Sep 29, 2022

Indeed I think the shape issue is unrelated.

mattjj added a commit to mattjj/jax that referenced this issue Oct 1, 2022
In particular:
* add function names so it's clear what decorated functions and rules
  are causing the error;
* when possible (because the functions were run), check for agreement of pytree
  structure and leaf shapes/dtypes between the primal function and rules

context: lucidrains/flash-attention-jax#7
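A similar structure check can be approximated from user code while debugging a custom_vjp rule. Below is a minimal sketch using `jax.eval_shape` and `jax.tree_util`. `check_fwd_rule` is a hypothetical helper, not the check JAX itself added, and the toy `primal`, `buggy_fwd` and `fixed_fwd` functions only mimic the structure from this issue (an output plus `(row_sum, row_max)` statistics).

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def check_fwd_rule(primal_fn, fwd_fn, *args):
    """Hypothetical helper: return None if fwd_fn(*args)[0] matches primal_fn(*args)
    in pytree structure and leaf shapes/dtypes, else a description of the mismatch."""
    primal_out = jax.eval_shape(primal_fn, *args)   # abstract evaluation, no FLOPs
    fwd_out = jax.eval_shape(fwd_fn, *args)
    if not (isinstance(fwd_out, (tuple, list)) and len(fwd_out) == 2):
        return "fwd rule must return a pair (primal_out, residuals)"
    want = tree_util.tree_structure(primal_out)
    got = tree_util.tree_structure(fwd_out[0])
    if want != got:
        return f"structure mismatch: fwd gives {got}, primal gives {want}"
    for a, b in zip(tree_util.tree_leaves(fwd_out[0]), tree_util.tree_leaves(primal_out)):
        if (a.shape, a.dtype) != (b.shape, b.dtype):
            return f"leaf mismatch: {a.shape} {a.dtype} vs {b.shape} {b.dtype}"
    return None


def primal(q):
    # Toy primal returning (out, (row_sum, row_max)), like flash_attention here.
    out = jnp.tanh(q)
    return out, (jnp.sum(out, axis=-1), jnp.max(out, axis=-1))


def buggy_fwd(q):
    out, (row_sum, row_max) = primal(q)
    return out, (q, out, row_sum, row_max)        # first element is only `out`


def fixed_fwd(q):
    out, stats = primal(q)
    return (out, stats), (q, out) + stats         # first element matches primal


x = jnp.ones((3, 16, 5, 19))
print(check_fwd_rule(primal, buggy_fwd, x))  # reports a structure mismatch
print(check_fwd_rule(primal, fixed_fwd, x))  # None
```

The buggy pairing mirrors the original flash_attention_forward, whose first output was a single array while flash_attention returned an array plus two statistics.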
@mattjj
Contributor

mattjj commented Oct 1, 2022

jax-ml/jax#12611 should improve the error message we got here! With the same repro (i.e. before the fix in #8 was merged here), the error will be:

TypeError: Custom VJP fwd rule flash_attention_forward for function
flash_attention must produce a pair (list or tuple of length two) where the
first element represents the primal output (equal to the output of the
custom_vjp-decorated function flash_attention) and the second element
represents residuals (i.e. values stored from the forward pass for use on the
backward pass), but instead the fwd rule output's first element had
container/pytree structure:
    float32[3,16,5,19]
while the custom_vjp-decorated function flash_attention had output
container/pytree structure:
    (float32[3,16,5,19], (float32[3,16,5], float32[3,16,5])).
