Commit 6c0c059

Fix Docstring and Make Style.
1 parent fbf26b7 commit 6c0c059

3 files changed: +106 −30 lines

src/diffusers/models/attention_dispatch.py

Lines changed: 102 additions & 27 deletions
@@ -18,7 +18,17 @@
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)

 import torch

@@ -68,7 +78,10 @@

 if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+    from flash_attn.flash_attn_interface import (
+        _wrapped_flash_attn_backward,
+        _wrapped_flash_attn_forward,
+    )
 else:
     flash_attn_func = None
     flash_attn_varlen_func = None

@@ -77,9 +90,9 @@


 if _CAN_USE_FLASH_ATTN_3:
+    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
-    from flash_attn_interface import _flash_attn_forward as flash_attn_3_forward
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

@@ -122,7 +135,9 @@


 if _CAN_USE_XLA_ATTN:
-    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+    from torch_xla.experimental.custom_kernel import (
+        flash_attention as xla_flash_attention,
+    )
 else:
     xla_flash_attention = None

@@ -265,13 +280,17 @@ class _HubKernelConfig:
 _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+        repo_id="kernels-community/flash-attn3",
+        function_attr="flash_attn_func",
+        revision="fake-ops-return-probs",
     )
 }


 @contextlib.contextmanager
-def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
+def attention_backend(
+    backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE,
+):
     """
     Context manager to set the active attention backend.
     """
@@ -416,7 +435,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
             )

-    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
+    elif backend in [
+        AttentionBackendName._FLASH_3,
+        AttentionBackendName._FLASH_VARLEN_3,
+    ]:
         if not _CAN_USE_FLASH_ATTN_3:
             raise RuntimeError(
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."

@@ -488,7 +510,11 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (
+        (seqlens_q, seqlens_k),
+        (cu_seqlens_q, cu_seqlens_k),
+        (max_seqlen_q, max_seqlen_k),
+    )


 def _prepare_for_flash_attn_or_sage_varlen_with_mask(

@@ -505,7 +531,11 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (
+        (seqlens_q, seqlens_k),
+        (cu_seqlens_q, cu_seqlens_k),
+        (max_seqlen_q, max_seqlen_k),
+    )


 def _prepare_for_flash_attn_or_sage_varlen(
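
For reference, a standalone sketch (an illustration, not code from this commit) of the cumulative-sequence-length layout that the two varlen helpers above return, which flash-attn's varlen kernels expect:

    import torch

    seqlens_k = torch.tensor([3, 5, 2])  # per-sample key/value lengths
    cu_seqlens_k = torch.zeros(seqlens_k.numel() + 1, dtype=torch.int32)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)  # tensor([0, 3, 8, 10])
    max_seqlen_k = seqlens_k.max().item()              # 5
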
@@ -625,7 +655,7 @@ def _wrapped_flash_attn_3(
     window_size = (-1, -1)
     max_seqlen_q = q.shape[2]
     max_seqlen_k = k.shape[2]
-
+
     out, lse, *_ = flash_attn_3_forward(
         q=q,
         k=k,

@@ -764,7 +794,10 @@ def _native_attention_backward_op(

     grad_out_t = grad_out.permute(0, 2, 1, 3)
     grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
-        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
+        outputs=out,
+        inputs=[query_t, key_t, value_t],
+        grad_outputs=grad_out_t,
+        retain_graph=False,
     )

     grad_query = grad_query_t.permute(0, 2, 1, 3)
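
As context for the call reformatted above, a toy sketch (not from this commit) of torch.autograd.grad with the same keyword arguments; grad_outputs supplies the upstream gradient when the output is not a scalar:

    import torch

    x = torch.randn(2, 3, requires_grad=True)
    y = x * 2
    (grad_x,) = torch.autograd.grad(
        outputs=y, inputs=[x], grad_outputs=torch.ones_like(y), retain_graph=False
    )
    # grad_x == 2 * torch.ones_like(x)
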
@@ -803,18 +836,26 @@ def _cudnn_attention_forward_op(
     value = value.transpose(1, 2).contiguous()
     tensors_to_save += (query, key, value)

-    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
-        torch.ops.aten._scaled_dot_product_cudnn_attention(
-            query=query,
-            key=key,
-            value=value,
-            attn_bias=attn_mask,
-            compute_log_sumexp=return_lse,
-            dropout_p=dropout_p,
-            is_causal=is_causal,
-            return_debug_mask=False,
-            scale=scale,
-        )
+    (
+        out,
+        lse,
+        cum_seq_q,
+        cum_seq_k,
+        max_q,
+        max_k,
+        philox_seed,
+        philox_offset,
+        debug_attn_mask,
+    ) = torch.ops.aten._scaled_dot_product_cudnn_attention(
+        query=query,
+        key=key,
+        value=value,
+        attn_bias=attn_mask,
+        compute_log_sumexp=return_lse,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        return_debug_mask=False,
+        scale=scale,
     )

     tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

@@ -941,7 +982,11 @@ def _flash_attention_backward_op(
     **kwargs,
 ):
     query, key, value, out, lse, rng_state = ctx.saved_tensors
-    grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+    grad_query, grad_key, grad_value = (
+        torch.empty_like(query),
+        torch.empty_like(key),
+        torch.empty_like(value),
+    )

     lse_d = _wrapped_flash_attn_backward(  # noqa: F841
         grad_out,

@@ -1165,7 +1210,19 @@ def backward(

         grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return (
+            grad_query,
+            grad_key,
+            grad_value,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )


 class TemplatedUlyssesAttention(torch.autograd.Function):
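
The long return tuples above follow the torch.autograd.Function contract: backward must return one gradient (or None) per argument of forward. A toy sketch (an illustration, not code from this commit):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_out):
            # One entry per forward input; None for the non-differentiable `factor`.
            return grad_out * ctx.factor, None

    x = torch.randn(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
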
@@ -1260,7 +1317,19 @@ def backward(
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )

-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return (
+            grad_query,
+            grad_key,
+            grad_value,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )


 def _templated_context_parallel_attention(

@@ -1608,7 +1677,12 @@ def _native_flex_attention(
         block_mask = attn_mask
     elif is_causal:
         block_mask = flex_attention.create_block_mask(
-            _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
+            _flex_attention_causal_mask_mod,
+            batch_size,
+            num_heads,
+            seq_len_q,
+            seq_len_kv,
+            query.device,
         )
     elif torch.is_tensor(attn_mask):
         if attn_mask.ndim == 2:

@@ -1628,6 +1702,7 @@ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):

             def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                 return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+
     else:
         raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")

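The two hunks above touch the torch flex-attention path. As background, a small sketch (an assumption about typical usage, not from this commit) of building a causal BlockMask and calling flex_attention; it assumes a recent PyTorch whose flex-attention implementation supports the chosen device:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # Keep only positions where the query index is at or after the key index.
        return q_idx >= kv_idx

    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = k = v = torch.randn(1, 4, 128, 64, device=device)
    block_mask = create_block_mask(causal_mask_mod, B=1, H=4, Q_LEN=128, KV_LEN=128, device=device)
    out = flex_attention(q, k, v, block_mask=block_mask)
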
src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 3 additions & 1 deletion
@@ -638,7 +638,9 @@ def forward(

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.layers:
-                unified = self._gradient_checkpointing_func(layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input)
+                unified = self._gradient_checkpointing_func(
+                    layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
+                )
         else:
             for layer in self.layers:
                 unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
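
For background, a toy sketch (not from this commit) of the activation-checkpointing pattern that _gradient_checkpointing_func typically wraps, i.e. torch.utils.checkpoint.checkpoint, which trades compute for memory by recomputing the layer during the backward pass:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    layer = nn.Linear(16, 16)
    x = torch.randn(4, 16, requires_grad=True)
    # Activations inside `layer` are recomputed in backward instead of being stored.
    y = checkpoint(layer, x, use_reentrant=False)
    y.sum().backward()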

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 1 addition & 2 deletions
@@ -45,8 +45,7 @@
         >>> # pipe.transformer.set_attention_backend("flash")
         >>> # (2) Use flash attention 3
         >>> # pipe.transformer.set_attention_backend("_flash_3")
-
-        >>> prompt = "一幅为名为"造相「Z-IMAGE-TURBO」"的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
+        >>> prompt = '一幅为名为"造相「Z-IMAGE-TURBO」"的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。'
         >>> image = pipe(
         ...     prompt,
         ...     height=1024,
