Skip to content

Commit 5b4c907

Browse files
committed
styling
1 parent 7adaae8 commit 5b4c907

File tree

1 file changed

+5
-21
lines changed

1 file changed

+5
-21
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import itertools
1615
import math
1716
from typing import List, Optional, Tuple
1817

@@ -23,11 +22,11 @@
2322

2423
from ...configuration_utils import ConfigMixin, register_to_config
2524
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
26-
from ..attention_dispatch import dispatch_attention_fn
2725
from ...models.attention_processor import Attention
2826
from ...models.modeling_utils import ModelMixin
2927
from ...utils.import_utils import is_flash_attn_available
3028
from ...utils.torch_utils import maybe_allow_in_graph
29+
from ..attention_dispatch import dispatch_attention_fn
3130

3231

3332
if is_flash_attn_available():
@@ -99,7 +98,6 @@ def __call__(
9998
attention_mask: Optional[torch.Tensor] = None,
10099
freqs_cis: Optional[torch.Tensor] = None,
101100
) -> torch.Tensor:
102-
103101
query = attn.to_q(hidden_states)
104102
key = attn.to_k(hidden_states)
105103
value = attn.to_v(hidden_states)
@@ -586,11 +584,7 @@ def forward(
586584
dtype=x_freqs_cis[0].dtype,
587585
device=device,
588586
)
589-
x_attn_mask = torch.ones(
590-
(bsz, x_max_item_seqlen),
591-
dtype=torch.bool,
592-
device=device
593-
)
587+
x_attn_mask = torch.ones((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
594588
for i, (item, freqs_item) in enumerate(zip(x, x_freqs_cis)):
595589
seq_len = x_item_seqlens[i]
596590
pad_len = x_max_item_seqlen - seq_len
@@ -629,11 +623,7 @@ def forward(
629623
dtype=cap_freqs_cis[0].dtype,
630624
device=device,
631625
)
632-
cap_attn_mask = torch.ones(
633-
(bsz, cap_max_item_seqlen),
634-
dtype=torch.bool,
635-
device=device
636-
)
626+
cap_attn_mask = torch.ones((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
637627
for i, (item, freqs_item) in enumerate(zip(cap_feats, cap_freqs_cis)):
638628
seq_len = cap_item_seqlens[i]
639629
pad_len = cap_max_item_seqlen - seq_len
@@ -672,11 +662,7 @@ def forward(
672662
dtype=unified_freqs_cis[0].dtype,
673663
device=device,
674664
)
675-
unified_attn_mask = torch.ones(
676-
(bsz, unified_max_item_seqlen),
677-
dtype=torch.bool,
678-
device=device
679-
)
665+
unified_attn_mask = torch.ones((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
680666
for i, (item, freqs_item) in enumerate(zip(unified, unified_freqs_cis)):
681667
seq_len = unified_item_seqlens[i]
682668
pad_len = unified_max_item_seqlen - seq_len
@@ -694,9 +680,7 @@ def forward(
694680
adaln_input,
695681
)
696682

697-
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
698-
unified, adaln_input
699-
)
683+
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
700684
unified = list(unified.unbind(dim=0))
701685
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
702686

0 commit comments

Comments
 (0)