Skip to content

Commit

Permalink
Add unpatchify model utils operation (#1544)
Browse files Browse the repository at this point in the history
* Add utils.unpatchify
  • Loading branch information
randombenj authored May 21, 2024
1 parent a3de571 commit 8878b5b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
27 changes: 27 additions & 0 deletions lightly/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,33 @@ def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
return patches


def unpatchify(
patches: torch.Tensor, patch_size: int, channels: int = 3
) -> torch.Tensor:
"""
Reconstructs images from their patches.
Args:
patches:
Patches tensor with shape (batch_size, num_patches, channels * patch_size ** 2).
patch_size:
The patch size in pixels used to create the patches.
channels:
The number of channels the image must have
Returns:
Reconstructed images tensor with shape (batch_size, channels, height, width).
"""
N, C = patches.shape[0], channels
patch_h = patch_w = int(patches.shape[1] ** 0.5)
assert patch_h * patch_w == patches.shape[1]

images = patches.reshape(shape=(N, patch_h, patch_w, patch_size, patch_size, C))
images = torch.einsum("nhwpqc->nchpwq", images)
images = images.reshape(shape=(N, C, patch_h * patch_size, patch_h * patch_size))
return images


def random_token_mask(
size: Tuple[int, int],
mask_ratio: float = 0.6,
Expand Down
10 changes: 10 additions & 0 deletions tests/models/test_ModelUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,16 @@ def test_patchify(self, seed=0):
img_patch = img_patches[i * width_patches + j]
self._assert_tensor_equal(img_patch, expected_patch)

def test_unpatchify(self, seed=0):
torch.manual_seed(seed)
batch_size, channels, height, width = (2, 3, 8, 8)
patch_size = 4
images = torch.rand(batch_size, channels, height, width)
batch_patches = utils.patchify(images, patch_size)
unpatched_images = utils.unpatchify(batch_patches, patch_size)

self._assert_tensor_equal(images, unpatched_images)

def _test_random_token_mask(
self, seed=0, mask_ratio=0.6, mask_class_token=False, device="cpu"
):
Expand Down

0 comments on commit 8878b5b

Please sign in to comment.