Skip to content

Commit 1dd8f3c

Browse files
committed
Remove einop dependency.
1 parent 549ad57 commit 1dd8f3c

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 from torch.nn.utils.rnn import pad_sequence

 from ...configuration_utils import ConfigMixin, register_to_config
@@ -429,9 +428,12 @@ def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size):
         for i in range(bsz):
             F, H, W = size[i]
             ori_len = (F // pF) * (H // pH) * (W // pW)
-            x[i] = rearrange(
-                x[i][:ori_len].view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels),
-                "f h w pf ph pw c -> c (f pf) (h ph) (w pw)",
+            # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
+            x[i] = (
+                x[i][:ori_len]
+                .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
+                .permute(6, 0, 3, 1, 4, 2, 5)
+                .reshape(self.out_channels, F, H, W)
             )
         return x
@@ -497,7 +499,8 @@ def patchify_and_embed(
             F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW

             image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
-            image = rearrange(image, "c f pf h ph w pw -> (f h w) (pf ph pw c)")
+            # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
+            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)

             image_ori_len = len(image)
             image_padding_len = (-image_ori_len) % SEQ_MULTI_OF

0 commit comments

Comments
 (0)