Add 5D support for flash_attention
qihqi committed Feb 9, 2025
1 parent c0afda3 commit 9a76e5e
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions torch_xla/experimental/custom_kernel.py
@@ -257,6 +257,8 @@ def fa_custom_forward(
ab = xs.enable_manual_sharding(
ab, partition_spec, mesh=mesh).global_tensor



# It computes the shape and type of o, l, m.
shapes = [q.shape]
dtypes = [q.dtype]
@@ -279,6 +281,14 @@ def fa_custom_forward(
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)

# Support 5D inputs: collapse the two leading batch dims into one,
# [b0, b1, heads, seq, head_dim] -> [b0 * b1, heads, seq, head_dim].
if len(q_full_shape) == 5:
  q = q.reshape(-1, *q.shape[2:])
  k = k.reshape(-1, *k.shape[2:])
  v = v.reshape(-1, *v.shape[2:])
  q_segment_ids = q_segment_ids.reshape(-1, *q_segment_ids.shape[2:])
  kv_segment_ids = kv_segment_ids.reshape(-1, *kv_segment_ids.shape[2:])

# We can't directly use flash_attention as we need to override the save_residuals flag which returns
# l and m that is needed for the backward. Then we lose all the shape checks.
# TODO: replicate the shape checks on flash_attention.
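For reference, a minimal standalone sketch of the flatten applied above, assuming 5D inputs laid out as [batch, num_groups, num_heads, seq_len, head_dim]; the helper name and shapes are illustrative, not part of the kernel:

import torch

def _flatten_leading_dims(t):
  # Collapse the two leading batch dims: [b0, b1, ...] -> [b0 * b1, ...].
  return t.reshape(-1, *t.shape[2:]) if t is not None and t.dim() == 5 else t

q = torch.randn(2, 3, 8, 128, 64)
k = torch.randn(2, 3, 8, 128, 64)
v = torch.randn(2, 3, 8, 128, 64)
q4, k4, v4 = (_flatten_leading_dims(t) for t in (q, k, v))
assert q4.shape == (6, 8, 128, 64)  # now a valid 4D input for the kernel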
@@ -322,6 +332,8 @@ def fa_custom_forward(
o, *aux = o
l, m = (v[..., 0] for v in aux[-2:])

# Restore the caller's original 5D output shape.
if len(q_full_shape) == 5:
  o = o.reshape(q_full_shape)
# SPMD integration
if partition_spec is not None:
o = xs.disable_manual_sharding(
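The output of the 4D kernel is reshaped back to the original 5D query shape before manual sharding is disabled; a hedged sketch of that inverse step with illustrative shapes:

import torch

q_full_shape = (2, 3, 8, 128, 64)   # original 5D query shape
o = torch.randn(6, 8, 128, 64)      # 4D output from the kernel
o = o.reshape(q_full_shape)         # [b0 * b1, ...] -> [b0, b1, ...]
assert o.shape == q_full_shape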
@@ -397,6 +409,17 @@ def fa_custom_backward(
if ab is not None:
ab = xs.enable_manual_sharding(
ab, partition_spec, mesh=mesh).global_tensor

# Support 5D input: collapse the two leading batch dims, mirroring the forward pass.
if len(q.shape) == 5:
  q = q.reshape(-1, *q.shape[2:])
  k = k.reshape(-1, *k.shape[2:])
  v = v.reshape(-1, *v.shape[2:])
  expanded_l = expanded_l.reshape(-1, *expanded_l.shape[2:])
  expanded_m = expanded_m.reshape(-1, *expanded_m.shape[2:])
  grad_output = grad_output.reshape(-1, *grad_output.shape[2:])
  expanded_grad_i = expanded_grad_i.reshape(-1, *expanded_grad_i.shape[2:])

if q_segment_ids is not None and kv_segment_ids is not None:
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)
@@ -490,6 +513,16 @@ def fa_custom_backward(
if require_grad_v:
grad_v = grads[1]


# Support 5D input: restore the grads to the caller's original 5D shapes.
# Note that q was already flattened to 4D above, so key the check off the
# saved full shape rather than q.shape.
if len(q_full_shape) == 5:
  grad_q = grad_q.reshape(q_full_shape)
  grad_k = grad_k.reshape(kv_full_shape)
  grad_v = grad_v.reshape(kv_full_shape)
  if ab is not None:
    grad_ab = grad_ab.reshape(ab_full_shape)

# SPMD integration
if partition_spec is not None:
grad_q = xs.disable_manual_sharding(
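In the backward pass q has already been collapsed to 4D by the time the kernel grads come back, so the restore is keyed off the saved full shapes rather than q.shape; a minimal sketch of that pattern, with illustrative shapes standing in for the kernel results:

import torch

q_full_shape = (2, 3, 8, 128, 64)
kv_full_shape = (2, 3, 8, 128, 64)
grad_q = torch.randn(6, 8, 128, 64)   # 4D grads as returned by the kernel
grad_k = torch.randn(6, 8, 128, 64)
grad_v = torch.randn(6, 8, 128, 64)
if len(q_full_shape) == 5:
  grad_q = grad_q.reshape(q_full_shape)
  grad_k = grad_k.reshape(kv_full_shape)
  grad_v = grad_v.reshape(kv_full_shape)
assert grad_q.shape == q_full_shape and grad_k.shape == kv_full_shape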
