From 8878b5b5e816bbd1f11b5c0a866eaee45a9a6004 Mon Sep 17 00:00:00 2001
From: Benj Fassbind
Date: Tue, 21 May 2024 09:08:08 +0200
Subject: [PATCH] Add unpatchify model utils operation (#1544)

* Add utils.unpatchify
---
 lightly/models/utils.py         | 29 +++++++++++++++++++++++++++++
 tests/models/test_ModelUtils.py | 10 ++++++++++
 2 files changed, 39 insertions(+)

diff --git a/lightly/models/utils.py b/lightly/models/utils.py
index 1fcba9acf..d5180c377 100644
--- a/lightly/models/utils.py
+++ b/lightly/models/utils.py
@@ -434,6 +434,35 @@ def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
     return patches
 
 
+def unpatchify(
+    patches: torch.Tensor, patch_size: int, channels: int = 3
+) -> torch.Tensor:
+    """
+    Reconstructs square images from their patches (inverse of patchify).
+
+    Args:
+        patches:
+            Patches tensor with shape (batch_size, num_patches, channels * patch_size ** 2).
+            num_patches must be a perfect square.
+        patch_size:
+            The patch size in pixels used to create the patches.
+        channels:
+            The number of channels of the reconstructed images.
+
+    Returns:
+        Reconstructed images tensor with shape (batch_size, channels, height, width).
+    """
+    N, C = patches.shape[0], channels
+    # round() guards against the float square root landing just below an integer.
+    patch_h = patch_w = round(patches.shape[1] ** 0.5)
+    assert patch_h * patch_w == patches.shape[1], "num_patches must be a square"
+
+    images = patches.reshape(shape=(N, patch_h, patch_w, patch_size, patch_size, C))
+    images = torch.einsum("nhwpqc->nchpwq", images)
+    images = images.reshape(shape=(N, C, patch_h * patch_size, patch_w * patch_size))
+    return images
+
+
 def random_token_mask(
     size: Tuple[int, int],
     mask_ratio: float = 0.6,
diff --git a/tests/models/test_ModelUtils.py b/tests/models/test_ModelUtils.py
index 2b7174993..ec39230e2 100644
--- a/tests/models/test_ModelUtils.py
+++ b/tests/models/test_ModelUtils.py
@@ -199,6 +199,16 @@ def test_patchify(self, seed=0):
                     img_patch = img_patches[i * width_patches + j]
                     self._assert_tensor_equal(img_patch, expected_patch)
 
+    def test_unpatchify(self, seed=0):
+        torch.manual_seed(seed)
+        batch_size, channels, height, width = (2, 3, 8, 8)
+        patch_size = 4
+        images = torch.rand(batch_size, channels, height, width)
+        batch_patches = utils.patchify(images, patch_size)
+        unpatched_images = utils.unpatchify(batch_patches, patch_size)
+
+        self._assert_tensor_equal(images, unpatched_images)
+
     def _test_random_token_mask(
         self, seed=0, mask_ratio=0.6, mask_class_token=False, device="cpu"
     ):