Commit d6256ab

Enable Ulysses in the native attention path
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 9f3c0fd commit d6256ab

File tree

1 file changed (+42, -12 lines)
src/diffusers/models/attention_dispatch.py

Lines changed: 42 additions & 12 deletions
@@ -1538,18 +1538,48 @@ def _native_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Native attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    out = torch.nn.functional.scaled_dot_product_attention(
-        query=query,
-        key=key,
-        value=value,
-        attn_mask=attn_mask,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        scale=scale,
-        enable_gqa=enable_gqa,
-    )
-    out = out.permute(0, 2, 1, 3)
+    if _parallel_config is None:
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+        out = out.permute(0, 2, 1, 3)
+    elif _parallel_config.context_parallel_config.ring_degree == 1:
+        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+        world_size = _parallel_config.context_parallel_config.ulysses_degree
+        group = ulysses_mesh.get_group()
+
+        B, S_Q_LOCAL, H, D = query.shape
+        _, S_KV_LOCAL, _, _ = key.shape
+        H_LOCAL = H // world_size
+        query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
+        query, key, value = (x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+        out = out.reshape(B, H_LOCAL, world_size, S_Q_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        out = _all_to_all_single(out, group)
+        out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
+        return out
+    else:
+        raise ValueError("Native attention backend does not support context parallelism with `ring_degree > 1`; try Ulysses attention instead.")
     return out
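The new branch relies on the existing `_all_to_all_single` helper in attention_dispatch.py to swap the sharding axis: before SDPA each rank trades its local sequence shard of every head group for the full sequence of its own `H_LOCAL` heads, and after SDPA the exchange is reversed. The helper itself is not part of this diff; the following is only a minimal sketch of what such a wrapper plausibly does, assuming (from the surrounding reshapes) that it exchanges equal chunks along dim 0 of a contiguous tensor:

import torch
import torch.distributed as dist

def _all_to_all_single_sketch(x: torch.Tensor, group) -> torch.Tensor:
    # Hypothetical stand-in for `_all_to_all_single`; the real helper lives
    # elsewhere in attention_dispatch.py and may differ in detail.
    # `x` arrives with dim 0 == world_size (one chunk per destination rank),
    # e.g. [world_size, S_LOCAL, B, H_LOCAL, D] for q/k/v in the Ulysses branch.
    x = x.contiguous()
    out = torch.empty_like(x)
    # Chunk i of dim 0 is sent to rank i and one chunk is received from every
    # rank, so the axis that was sharded across ranks (sequence for q/k/v,
    # heads for the output) becomes fully local after the later flatten(0, 1).
    dist.all_to_all_single(out, x, group=group)
    return out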

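A quick single-process shape trace makes the permutes easier to review; the all-to-all is shape-preserving, so it is simply skipped here, and the sizes (2 ranks, batch 1, local sequence 8, 4 heads, head dim 32) are made-up example values:

import torch

world_size, B, S_LOCAL, H, D = 2, 1, 8, 4, 32
H_LOCAL = H // world_size  # the Ulysses path assumes H is divisible by ulysses_degree

q = torch.randn(B, S_LOCAL, H, D)  # each rank holds a sequence shard with all heads

# Pre-SDPA redistribution (the actual code runs _all_to_all_single between these steps).
q = q.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
print(q.shape)  # torch.Size([2, 8, 1, 2, 32])  -> [world_size, S_LOCAL, B, H_LOCAL, D]
q = q.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
print(q.shape)  # torch.Size([1, 2, 16, 32])    -> [B, H_LOCAL, S_FULL, D]: full sequence, local heads

# SDPA output keeps the [B, H_LOCAL, S_FULL, D] layout; reverse the exchange.
out = torch.randn(B, H_LOCAL, world_size * S_LOCAL, D)
out = out.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
print(out.shape)  # torch.Size([2, 2, 1, 8, 32]) -> [world_size, H_LOCAL, B, S_LOCAL, D]
out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
print(out.shape)  # torch.Size([1, 8, 4, 32])    -> [B, S_LOCAL, H, D]: the layout the caller expects

After the first exchange each rank runs plain scaled_dot_product_attention over the full gathered sequence but only H // world_size heads, which is why this branch can reuse the native SDPA kernel unchanged.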

0 commit comments