Commit de4c6f1

Support latest FA3 API changes

1 parent 8e391b7 commit de4c6f1

File tree: 3 files changed, +53 −31 lines

src/diffusers/models/attention_dispatch.py

Lines changed: 49 additions & 29 deletions
```diff
@@ -621,24 +621,32 @@ def _wrapped_flash_attn_3(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
-        q=q,
-        k=k,
-        v=v,
-        softmax_scale=softmax_scale,
-        causal=causal,
-        qv=qv,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-        window_size=window_size,
-        attention_chunk=attention_chunk,
-        softcap=softcap,
-        num_splits=num_splits,
-        pack_gqa=pack_gqa,
-        deterministic=deterministic,
-        sm_margin=sm_margin,
-    )
+
+    kwargs = {
+        "q": q,
+        "k": k,
+        "v": v,
+        "softmax_scale": softmax_scale,
+        "causal": causal,
+        "qv": qv,
+        "q_descale": q_descale,
+        "k_descale": k_descale,
+        "v_descale": v_descale,
+        "window_size": window_size,
+        "attention_chunk": attention_chunk,
+        "softcap": softcap,
+        "num_splits": num_splits,
+        "pack_gqa": pack_gqa,
+        "deterministic": deterministic,
+        "sm_margin": sm_margin,
+    }
+
+    # For backward compatibility with early flash-attn-3 APIs.
+    if "return_attn_probs" in inspect.signature(flash_attn_3_func).parameters:
+        kwargs["return_attn_probs"] = True
+
+    out, lse, *_ = flash_attn_3_func(**kwargs)
+
     lse = lse.permute(0, 2, 1)
     return out, lse
```

```diff
@@ -1504,17 +1512,29 @@ def _flash_varlen_attention_3(
     key_packed = torch.cat(key_valid, dim=0)
     value_packed = torch.cat(value_valid, dim=0)
 
-    out, lse, *_ = flash_attn_3_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        softmax_scale=scale,
-        causal=is_causal,
-    )
+    kwargs = {
+        "q": query_packed,
+        "k": key_packed,
+        "v": value_packed,
+        "cu_seqlens_q": cu_seqlens_q,
+        "cu_seqlens_k": cu_seqlens_k,
+        "max_seqlen_q": max_seqlen_q,
+        "max_seqlen_k": max_seqlen_k,
+        "softmax_scale": scale,
+        "causal": is_causal,
+    }
+
+    if "return_attn_probs" in inspect.signature(flash_attn_3_varlen_func).parameters:
+        kwargs["return_attn_probs"] = return_lse
+        out = flash_attn_3_varlen_func(**kwargs)
+        if return_lse:
+            out, lse = out[0], out[1]
+        else:
+            lse = None
+    else:
+        # For backward compatibility with early flash-attn-3 APIs.
+        out, lse, *_ = flash_attn_3_varlen_func(**kwargs)
+
     out = out.unflatten(0, (batch_size, -1))
 
     return (out, lse) if return_lse else out
```
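For readers new to the varlen path: `cu_seqlens_q` and `cu_seqlens_k` are cumulative offsets into the packed tensors produced by the `torch.cat` calls above, and `max_seqlen_q`/`max_seqlen_k` bound the longest sample. A short sketch of how such offsets are typically derived from per-sample lengths (variable names here are illustrative, not the surrounding function's):

```python
import torch

# Hypothetical per-sample sequence lengths for a batch of 3.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Cumulative offsets with a leading zero: sample i occupies rows
# cu_seqlens[i]:cu_seqlens[i + 1] of the packed (flattened) tensor.
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())

print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
print(max_seqlen)  # 7
```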

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -638,7 +638,9 @@ def forward(
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.layers:
-                unified = self._gradient_checkpointing_func(layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input)
+                unified = self._gradient_checkpointing_func(
+                    layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
+                )
         else:
             for layer in self.layers:
                 unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
```
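For context on the branch being reformatted: `_gradient_checkpointing_func` typically wraps `torch.utils.checkpoint.checkpoint`, which drops intermediate activations during the forward pass and recomputes them during backward, trading compute for memory. A minimal sketch of that idea on a toy layer stack (not the transformer's actual wiring):

```python
import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])
x = torch.randn(2, 64, requires_grad=True)

for layer in layers:
    # Recompute this layer's activations during backward instead of
    # caching them in the autograd graph.
    x = checkpoint(layer, x, use_reentrant=False)

x.sum().backward()
```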

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -45,7 +45,7 @@
 >>> # pipe.transformer.set_attention_backend("flash")
 >>> # (2) Use flash attention 3
 >>> # pipe.transformer.set_attention_backend("_flash_3")
-
+
 >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
 >>> image = pipe(
 ...     prompt,
```
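The deleted/re-added line here appears to be a whitespace-only docstring change; the surrounding example shows how the FA3 path is selected. (The Chinese prompt translates roughly to: "A creative poster for the project 造相「Z-IMAGE-TURBO」. The image cleverly visualizes the name: a retro steam locomotive becomes a giant zipper pull, unzipping thick winter snow to reveal a vibrant spring.") A hedged end-to-end usage sketch, assuming a `ZImagePipeline` class exported by this module; the checkpoint id and prompt are illustrative:

```python
import torch
from diffusers import ZImagePipeline  # assumed export of pipeline_z_image.py

pipe = ZImagePipeline.from_pretrained(
    "some-org/z-image-turbo",  # illustrative checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

# Route attention through flash-attn 3 (requires a flash-attn-3 build).
pipe.transformer.set_attention_backend("_flash_3")

image = pipe("A retro steam train unzipping winter snow to reveal spring").images[0]
image.save("z_image_turbo.png")
```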
