fix flash_attention gradient api issues
dlwh committed Sep 7, 2022
1 parent f80c46f commit a2828ce
Showing 2 changed files with 48 additions and 29 deletions.
73 changes: 46 additions & 27 deletions src/levanter/flash_attention.py
@@ -1,5 +1,10 @@
+# flake8: noqa
+# type: ignore
+import math
+
+import jax.numpy as jnp


# https://arxiv.org/pdf/2205.14135.pdf

# assume we have 2GB to work with
@@ -45,11 +50,12 @@


import math
-import jax
-from functools import partial
-from jax import nn
-from jax import custom_vjp
-from jax import numpy as jnp, lax, jit
+
+import jax
+from jax import custom_vjp, jit, lax, nn
+from jax import numpy as jnp


# constants

@@ -61,10 +67,11 @@

# flash attention


def _query_chunk_flash_attention(q_range_chunk, k_range, q, k, v):
q_len, k_len, dim, v_dim = q.shape[-2], *k.shape, v.shape[-1]
scale = 1 / jnp.sqrt(dim)
q_scaled = q * scale

def chunk_scanner(carries, _):
key_chunk_idx, out, row_sum, row_max = carries
@@ -81,13 +88,13 @@ def chunk_scanner(carries, _):

attn_weights = jnp.where(causal_mask, MASK_VALUE, attn_weights)

-block_row_max = jnp.max(attn_weights, axis = -1, keepdims = True)
+block_row_max = jnp.max(attn_weights, axis=-1, keepdims=True)

exp_weights = jnp.exp(attn_weights - block_row_max)

-exp_weights = jnp.where(causal_mask, 0., exp_weights)
+exp_weights = jnp.where(causal_mask, 0.0, exp_weights)

-block_row_sum = jnp.sum(exp_weights, axis = -1, keepdims = True) + EPSILON
+block_row_sum = jnp.sum(exp_weights, axis=-1, keepdims=True) + EPSILON

exp_values = exp_weights @ v_chunk

@@ -98,25 +105,31 @@ def chunk_scanner(carries, _):

new_row_sum = exp_row_max_diff * row_sum + exp_block_row_max_diff * block_row_sum

-out = (row_sum / new_row_sum) * exp_row_max_diff * out + \
-(exp_block_row_max_diff / new_row_sum) * exp_values
+out = (row_sum / new_row_sum) * exp_row_max_diff * out + (exp_block_row_max_diff / new_row_sum) * exp_values

return (key_chunk_idx + k_chunk_sizes, out, new_row_sum, new_row_max), None

out = jnp.zeros((q_len, dim))
row_sum = jnp.zeros((q_len, 1))
row_max = jnp.ones((q_len, 1)) * -1e6

-(_, out, row_sum, row_max), _ = lax.scan(chunk_scanner, init = (0, out, row_sum, row_max), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))
+(_, out, row_sum, row_max), _ = lax.scan(
+chunk_scanner, init=(0, out, row_sum, row_max), xs=None, length=math.ceil(k_len / K_CHUNK_SIZE)
+)

out = out.reshape(q_len, v_dim)
row_sum = row_sum.reshape(q_len)
row_max = row_max.reshape(q_len)

return out, row_sum, row_max


@custom_vjp
def causal_flash_attention(q, k, v):
+return _causal_flash_attention(q, k, v)[0]
+
+
+def _causal_flash_attention(q, k, v):
q_len, dim, k_len, v_dim = *q.shape, *v.shape

q_range = jnp.arange(q_len).reshape(q_len, 1) + (k_len - q_len)
@@ -125,24 +138,26 @@ def causal_flash_attention(q, k, v):
def chunk_scanner(chunk_idx, _):
chunk_sizes = min(Q_CHUNK_SIZE, q_len)

-q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes = (chunk_sizes, dim))
-q_range_chunk = lax.dynamic_slice(q_range, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
+q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes=(chunk_sizes, dim))
+q_range_chunk = lax.dynamic_slice(q_range, (chunk_idx, 0), slice_sizes=(chunk_sizes, 1))

return (chunk_idx + chunk_sizes, _query_chunk_flash_attention(q_range_chunk, k_range, q_chunk, k, v))

-_, (out, row_sum, row_max) = lax.scan(chunk_scanner, init = 0, xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))
+_, (out, row_sum, row_max) = lax.scan(chunk_scanner, init=0, xs=None, length=math.ceil(q_len / Q_CHUNK_SIZE))

out = out.reshape(q_len, v_dim)
row_sum = row_sum.reshape(q_len)
row_max = row_max.reshape(q_len)

return out, (row_sum, row_max)


@jit
def flash_attention_forward(q, k, v):
-out, (row_sum, row_max) = causal_flash_attention(q, k, v)
+out, (row_sum, row_max) = _causal_flash_attention(q, k, v)
return out, (q, k, v, out, row_sum, row_max)


def _query_chunk_flash_attention_backward(query_range_chunk, key_range, q, k, v, o, do, l, m):
q_len, dim, k_len, v_dim = *q.shape, *v.shape

@@ -166,14 +181,14 @@ def chunk_scanner(carries, _):

exp_attn_weights = jnp.exp(attn_weights - m)

-exp_attn_weights = jnp.where(causal_mask, 0., exp_attn_weights)
+exp_attn_weights = jnp.where(causal_mask, 0.0, exp_attn_weights)

p = exp_attn_weights / l

dv_chunk = p.transpose() @ do
dp = do @ v_chunk.transpose()

-D = jnp.sum(do * o, axis = -1, keepdims = True)
+D = jnp.sum(do * o, axis=-1, keepdims=True)
ds = p * scale * (dp - D)

dq_chunk = ds @ k_chunk
@@ -183,14 +198,15 @@ def chunk_scanner(carries, _):

dq = jnp.zeros_like(q)

-(_, dq), (dk, dv) = lax.scan(chunk_scanner, init = (0, dq), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))
+(_, dq), (dk, dv) = lax.scan(chunk_scanner, init=(0, dq), xs=None, length=math.ceil(k_len / K_CHUNK_SIZE))

dq = dq.reshape(q_len, dim)
dk = dk.reshape(k_len, v_dim)
dv = dv.reshape(k_len, v_dim)

return dq, dk, dv


@jit
def flash_attention_backward(res, do):
q, k, v, o, l, m = res
@@ -211,24 +227,27 @@ def chunk_scanner(carries, _):

chunk_sizes = min(Q_CHUNK_SIZE, q_len)

-q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes = (chunk_sizes, q.shape[-1]))
-q_range_chunk = lax.dynamic_slice(q_range, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
+q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes=(chunk_sizes, q.shape[-1]))
+q_range_chunk = lax.dynamic_slice(q_range, (chunk_idx, 0), slice_sizes=(chunk_sizes, 1))

-m_chunk = lax.dynamic_slice(m, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
-l_chunk = lax.dynamic_slice(l, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
-o_chunk = lax.dynamic_slice(o, (chunk_idx, 0), slice_sizes = (chunk_sizes, o.shape[-1]))
-do_chunk = lax.dynamic_slice(do, (chunk_idx, 0), slice_sizes = (chunk_sizes, do.shape[-1]))
+m_chunk = lax.dynamic_slice(m, (chunk_idx, 0), slice_sizes=(chunk_sizes, 1))
+l_chunk = lax.dynamic_slice(l, (chunk_idx, 0), slice_sizes=(chunk_sizes, 1))
+o_chunk = lax.dynamic_slice(o, (chunk_idx, 0), slice_sizes=(chunk_sizes, o.shape[-1]))
+do_chunk = lax.dynamic_slice(do, (chunk_idx, 0), slice_sizes=(chunk_sizes, do.shape[-1]))

-dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_range_chunk, k_range, q_chunk, k, v, o_chunk, do_chunk, l_chunk, m_chunk)
+dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(
+q_range_chunk, k_range, q_chunk, k, v, o_chunk, do_chunk, l_chunk, m_chunk
+)
return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk

-(_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))
+(_, dk, dv), dq = lax.scan(chunk_scanner, init=(0, dk, dv), xs=None, length=math.ceil(q_len / Q_CHUNK_SIZE))

dq = dq.reshape(q_len, dim)

return dq, dk, dv


causal_flash_attention.defvjp(flash_attention_forward, flash_attention_backward)


-multiheaded_causal_flash_attention = jax.vmap(causal_flash_attention, in_axes = (0, 0, 0), out_axes = 0)
+multiheaded_causal_flash_attention = jax.vmap(causal_flash_attention, in_axes=(0, 0, 0), out_axes=0)
4 changes: 2 additions & 2 deletions src/levanter/models/gpt2.py
@@ -20,7 +20,7 @@
from haliax.partitioning import logically_sharded
from levanter import jax_utils
from levanter.compat.torch_serialization import StateDict, TorchSerializationMixin, apply_prefix, reshape_linear_layer
-from levanter.flash_attention import multiheaded_causal_flash_attention
+from levanter.flash_attention import multiheaded_causal_flash_attention  # type: ignore
from levanter.jax_utils import named_call
from levanter.modeling_utils import ACT2FN

@@ -135,7 +135,7 @@ def __call__(self, hidden_states: NamedArray, layer_idx, inference: bool = True,
query, key, value = map(
lambda x: x.rearrange((self.Heads, self.SeqLen, self.HeadDim)).array, (query, key, value)
)
-attn_output = multiheaded_causal_flash_attention(query, key, value)[0]
+attn_output = multiheaded_causal_flash_attention(query, key, value)
attn_output = NamedArray(attn_output, (self.Heads, self.SeqLen, self.HeadDim))

# KeySeqLen = self.SeqLen.alias("KeySeqLen") # haliax doesn't support unnamed axes or duplicate axes

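For context on the gradient API being fixed here: jax.custom_vjp expects the decorated primal function, the forward rule, and the backward rule to agree on structure. The primal returns only the output, the forward rule returns (output, residuals), and the backward rule returns one cotangent per primal argument. The commit follows that contract by making causal_flash_attention a thin wrapper over _causal_flash_attention, which keeps returning the row statistics needed as residuals. Below is a minimal, self-contained sketch of the same pattern on a toy function; the scaled_dot names are hypothetical and not part of Levanter.

import jax
import jax.numpy as jnp
from jax import custom_vjp


@custom_vjp
def scaled_dot(q, k):
    # primal: returns only the output, like causal_flash_attention above
    return q @ k.T


def scaled_dot_fwd(q, k):
    # forward rule: returns (output, residuals for the backward pass)
    out = q @ k.T
    return out, (q, k)


def scaled_dot_bwd(res, g):
    # backward rule: returns one cotangent per primal argument
    q, k = res
    return g @ k, g.T @ q


scaled_dot.defvjp(scaled_dot_fwd, scaled_dot_bwd)

# jax.grad now routes through the custom rules
dq, dk = jax.grad(lambda q, k: scaled_dot(q, k).sum(), argnums=(0, 1))(
    jnp.ones((4, 8)), jnp.ones((4, 8))
)

Because the primal now returns only the attention output, the call site in gpt2.py drops the [0] indexing, as shown in the second file of the diff.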