@@ -18,7 +18,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 from torch.nn.utils.rnn import pad_sequence
 
 from ...configuration_utils import ConfigMixin, register_to_config
@@ -429,9 +428,12 @@ def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size
         for i in range(bsz):
             F, H, W = size[i]
             ori_len = (F // pF) * (H // pH) * (W // pW)
-            x[i] = rearrange(
-                x[i][:ori_len].view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels),
-                "f h w pf ph pw c -> c (f pf) (h ph) (w pw)",
+            # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
+            x[i] = (
+                x[i][:ori_len]
+                .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
+                .permute(6, 0, 3, 1, 4, 2, 5)
+                .reshape(self.out_channels, F, H, W)
             )
         return x
 
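The `permute(6, 0, 3, 1, 4, 2, 5)` order comes from reading the einops pattern left to right: the output axes `c (f pf) (h ph) (w pw)` correspond to input dimensions `c=6, f=0, pf=3, h=1, ph=4, w=2, pw=5`. A minimal sketch to sanity-check the equivalence on a dummy tensor; the shape values are made up for illustration, and einops is imported only for the comparison, not by the patched module:

```python
import torch
from einops import rearrange

# Illustrative sizes only: pF/pH/pW are patch sizes, C the channel count.
pF, pH, pW, C = 1, 2, 2, 16
F, H, W = 4, 8, 8

# A flat token sequence, as unpatchify receives it, reshaped into patch blocks.
tokens = torch.randn((F // pF) * (H // pH) * (W // pW), pF * pH * pW * C)
blocks = tokens.view(F // pF, H // pH, W // pW, pF, pH, pW, C)

expected = rearrange(blocks, "f h w pf ph pw c -> c (f pf) (h ph) (w pw)")
actual = blocks.permute(6, 0, 3, 1, 4, 2, 5).reshape(C, F, H, W)
torch.testing.assert_close(actual, expected)  # same data movement, exact match
```

Note that `.reshape` rather than `.view` is needed after the permute, because the permuted tensor is no longer contiguous.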
@@ -497,7 +499,8 @@ def patchify_and_embed(
             F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
 
             image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
-            image = rearrange(image, "c f pf h ph w pw -> (f h w) (pf ph pw c)")
+            # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
+            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
 
             image_ori_len = len(image)
             image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
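The patchify direction follows the same recipe in reverse: the output axes `(f h w) (pf ph pw c)` map to input dimensions `f=1, h=3, w=5, pf=2, ph=4, pw=6, c=0`, hence `permute(1, 3, 5, 2, 4, 6, 0)`. Again a sketch with made-up shapes, using einops only as the reference:

```python
import torch
from einops import rearrange

C, F, H, W = 16, 4, 8, 8  # illustrative channel/frame/spatial sizes
pF, pH, pW = 1, 2, 2      # illustrative patch sizes
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW

# Split each frame/spatial dim of a dummy image into (tokens, patch) pairs.
image = torch.randn(C, F, H, W).view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)

expected = rearrange(image, "c f pf h ph w pw -> (f h w) (pf ph pw c)")
actual = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(
    F_tokens * H_tokens * W_tokens, pF * pH * pW * C
)
torch.testing.assert_close(actual, expected)
```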