Support SegmentID when doing data parallel SPMD (#8425)
JackCaoG authored Nov 28, 2024
1 parent 20f5166 commit 1c91219
Showing 2 changed files with 146 additions and 7 deletions.
130 changes: 129 additions & 1 deletion test/test_pallas_spmd.py
@@ -3,6 +3,7 @@
import unittest

import torch
import numpy as np
from torch import nn as nn

import torch_xla
@@ -22,8 +23,24 @@

class PallasTest(unittest.TestCase):

def _attention(self, q, k, v):
# This creates a block-diagonal attention pattern where only elements within the
# same segment can attend to each other. Since the mask marks the positions to
# mask out (the irrelevant parts), we use != instead of ==.
def _make_attention_mask_from_segment_ids(self, q_segment_ids,
kv_segment_ids):
return q_segment_ids.view(q_segment_ids.shape[0], 1,
q_segment_ids.shape[1], 1) != kv_segment_ids.view(
kv_segment_ids.shape[0], 1, 1,
kv_segment_ids.shape[1])

def _attention(self, q, k, v, *, attn_mask=None, ab=None):
attn_weight = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Mask out the irrelevant parts.
attn_weight = attn_weight.masked_fill(attn_mask,
torch.finfo(attn_weight.dtype).min)
if ab is not None:
attn_weight = attn_weight + ab
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output
@@ -98,6 +115,117 @@ def test_flash_attention_backward_spmd_data_parallel(self):
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_wrapper_segment_ids_spmd(self):
from torch_xla.experimental.custom_kernel import flash_attention
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds
xs.set_global_mesh(xs.get_1d_mesh("data"))

q = torch.randn(3, 2, 128, 4)
k = torch.randn(3, 2, 128, 4)
v = torch.randn(3, 2, 128, 4)
zeros = torch.zeros(3, 32)
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
segment_ids_xla = segment_ids.to("xla")
# Only shard the data (batch) dimension.
o = flash_attention(
q.to("xla"),
k.to("xla"),
v.to("xla"),
False,
segment_ids_xla,
segment_ids.to("xla"),
partition_spec=("data", None, None, None))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")

jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
jax_v = jnp.array(v.numpy(), dtype=jnp.float32)
jax_segment_ids = jnp.array(segment_ids.numpy(), dtype=jnp.float32)
expected_o = torch.from_numpy(
np.array(
jax_flash_attention(
jax_q,
jax_k,
jax_v,
segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids),
)))

self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', "default")

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_backward_segment_ids_spmd(self):
jax.config.update("jax_default_matmul_precision", "highest")
from torch_xla.experimental.custom_kernel import flash_attention
n_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.get_1d_mesh("data"))

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
zeros = torch.zeros(4, 32).to("xla")
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = flash_attention(
q,
k,
v,
False,
segment_ids,
segment_ids,
partition_spec=("data", None, None, None))
loss = o.sum()
loss.backward()
q_grad = q.grad
k_grad = k.grad
v_grad = v.grad
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(q_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(k_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v_grad),
f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
torch_xla.sync()

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
zeros = torch.zeros(4, 32).to("xla")
segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = self._attention(
q,
k,
v,
attn_mask=self._make_attention_mask_from_segment_ids(
segment_ids, segment_ids))
loss = o.sum()
loss.backward()
xm.mark_step()

for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update("jax_default_matmul_precision", "default")


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
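
Side note for reviewers: a minimal, self-contained sketch (toy shapes, plain PyTorch on CPU, not part of the diff) of how the _make_attention_mask_from_segment_ids helper above builds its mask — positions whose segment ids differ compare as True and are filled with the dtype minimum before the softmax:

    import torch

    # One batch element with 4 tokens split into segments [0, 0, 1, 1].
    segment_ids = torch.tensor([[0, 0, 1, 1]])

    # Same broadcast as the test helper: [B, 1, S, 1] != [B, 1, 1, S].
    mask = segment_ids.view(1, 1, 4, 1) != segment_ids.view(1, 1, 1, 4)

    print(mask[0, 0].int())
    # tensor([[0, 0, 1, 1],
    #         [0, 0, 1, 1],
    #         [1, 1, 0, 0],
    #         [1, 1, 0, 0]])
    # True (1) positions are masked out, so tokens only attend within their own segment.
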
23 changes: 17 additions & 6 deletions torch_xla/experimental/custom_kernel.py
@@ -266,7 +266,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
dtypes.append(torch.float32)

with torch.no_grad():
segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids(
if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None:
# partition_spec is for q, k, v with shape [batch, num_head, seq_len, head_dim], while
# segment ids have shape [batch, seq_len], so the spec needs to be adjusted accordingly.
segment_id_partition_spec = (partition_spec[0], partition_spec[2])
q_segment_ids = xs.enable_manual_sharding(
q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
kv_segment_ids = xs.enable_manual_sharding(
kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor
segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
q_segment_ids, kv_segment_ids)
ctx.segment_ids = segment_ids
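
To make the spec adjustment above concrete, here is an illustrative sketch (hypothetical values, not part of the diff) of how the rank-4 partition_spec used for q, k and v maps onto the rank-2 segment-id tensors:

    # q, k, v have shape [batch, num_head, seq_len, head_dim] and are sharded
    # with, e.g., a data-parallel spec:
    partition_spec = ("data", None, None, None)

    # Segment ids have shape [batch, seq_len], so only the batch (index 0) and
    # seq_len (index 2) entries of the spec apply to them.
    segment_id_partition_spec = (partition_spec[0], partition_spec[2])
    assert segment_id_partition_spec == ("data", None)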

@@ -297,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
if ab is not None:
args += [ab]
if segment_ids is not None:
args += [q_segment_ids, kv_segment_ids]
args += [q_segment_ids_fa, kv_segment_ids_fa]
o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

if not save_residuals:
@@ -319,20 +327,23 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
m = xs.disable_manual_sharding(
m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor

ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
kv_segment_ids, full_ab)
# q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided,
# but that is fine because the backward pass uses the same partition_spec.
ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa,
kv_segment_ids_fa, full_ab)
return o

@staticmethod
def backward(ctx, grad_output):
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors
causal = ctx.causal
sm_scale = ctx.sm_scale
partition_spec = ctx.partition_spec
mesh = ctx.mesh
full_shape = ctx.full_shape
# This segment_ids reflects only the local (per-shard) shape of segment_ids.
segment_ids = ctx.segment_ids
grad_q = grad_k = grad_v = grad_ab = None

@@ -398,7 +409,7 @@ def backward(ctx, grad_output):
if ab is not None:
args += [ab]
if segment_ids is not None:
args += [q_segment_ids, kv_segment_ids]
args += [q_segment_ids_fa, kv_segment_ids_fa]
args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

outputs = [q]
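
For readers skimming the diff, a condensed usage sketch mirroring the new test above (assuming a TPU SPMD runtime and the usual torch_xla SPMD imports; not part of the diff):

    import torch
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs
    from torch_xla.experimental.custom_kernel import flash_attention

    xr.use_spmd()
    xs.set_global_mesh(xs.get_1d_mesh("data"))

    q = torch.randn(4, 2, 128, 8).to("xla")
    k = torch.randn(4, 2, 128, 8).to("xla")
    v = torch.randn(4, 2, 128, 8).to("xla")
    # One segment id per token: four equal-length segments of 32 tokens each.
    zeros = torch.zeros(4, 32).to("xla")
    segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1)

    # Shard only the batch dimension of q/k/v; the kernel derives the matching
    # (batch, seq_len) spec for the segment ids internally.
    o = flash_attention(
        q, k, v, False, segment_ids, segment_ids,
        partition_spec=("data", None, None, None))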
