diff --git a/pl_bolts/models/self_supervised/simclr/transforms.py b/pl_bolts/models/self_supervised/simclr/transforms.py
index aecc1388cc..37eccfd6c6 100644
--- a/pl_bolts/models/self_supervised/simclr/transforms.py
+++ b/pl_bolts/models/self_supervised/simclr/transforms.py
@@ -1,7 +1,4 @@
-import numpy as np
-
-from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE
-from pl_bolts.utils.stability import under_review
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
 
 if _TORCHVISION_AVAILABLE:
@@ -9,15 +6,9 @@
 else:  # pragma: no cover
     warn_missing_pkg("torchvision")
 
-if _OPENCV_AVAILABLE:
-    import cv2
-else:  # pragma: no cover
-    warn_missing_pkg("cv2", pypi_name="opencv-python")
-
 
-@under_review()
 class SimCLRTrainDataTransform:
-    """Transforms for SimCLR.
+    """Transforms for SimCLR during training step of the pre-training stage.
 
     Transform::
 
@@ -25,7 +16,7 @@ class SimCLRTrainDataTransform:
         RandomHorizontalFlip()
         RandomApply([color_jitter], p=0.8)
         RandomGrayscale(p=0.2)
-        GaussianBlur(kernel_size=int(0.1 * self.input_height))
+        RandomApply([GaussianBlur(kernel_size=int(0.1 * self.input_height))], p=0.5)
         transforms.ToTensor()
 
     Example::
@@ -34,7 +25,7 @@ class SimCLRTrainDataTransform:
 
         transform = SimCLRTrainDataTransform(input_height=32)
         x = sample()
-        (xi, xj) = transform(x)
+        (xi, xj, xk) = transform(x)  # xk is only for the online evaluator if used
     """
 
     def __init__(
@@ -68,16 +59,16 @@ def __init__(
             if kernel_size % 2 == 0:
                 kernel_size += 1
 
-            data_transforms.append(GaussianBlur(kernel_size=kernel_size, p=0.5))
+            data_transforms.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5))
 
-        data_transforms = transforms.Compose(data_transforms)
+        self.data_transforms = transforms.Compose(data_transforms)
 
         if normalize is None:
             self.final_transform = transforms.ToTensor()
         else:
             self.final_transform = transforms.Compose([transforms.ToTensor(), normalize])
 
-        self.train_transform = transforms.Compose([data_transforms, self.final_transform])
+        self.train_transform = transforms.Compose([self.data_transforms, self.final_transform])
 
         # add online train transform of the size of global view
         self.online_transform = transforms.Compose(
@@ -93,9 +84,8 @@ def __call__(self, sample):
 
         return xi, xj, self.online_transform(sample)
 
 
-@under_review()
 class SimCLREvalDataTransform(SimCLRTrainDataTransform):
-    """Transforms for SimCLR.
+    """Transforms for SimCLR during the validation step of the pre-training stage.
 
     Transform::
 
@@ -109,7 +99,7 @@ class SimCLREvalDataTransform(SimCLRTrainDataTransform):
 
         transform = SimCLREvalDataTransform(input_height=32)
         x = sample()
-        (xi, xj) = transform(x)
+        (xi, xj, xk) = transform(x)  # xk is only for the online evaluator if used
     """
 
     def __init__(
@@ -129,70 +119,39 @@ def __init__(
         )
 
 
-@under_review()
-class SimCLRFinetuneTransform:
+class SimCLRFinetuneTransform(SimCLRTrainDataTransform):
+    """Transforms for SimCLR during the fine-tuning stage.
+
+    Transform::
+
+        Resize(input_height + 10, interpolation=3)
+        transforms.CenterCrop(input_height),
+        transforms.ToTensor()
+
+    Example::
+
+        from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform
+
+        transform = SimCLREvalDataTransform(input_height=32)
+        x = sample()
+        xk = transform(x)
+    """
+
     def __init__(
         self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False
     ) -> None:
 
-        self.jitter_strength = jitter_strength
-        self.input_height = input_height
-        self.normalize = normalize
-
-        self.color_jitter = transforms.ColorJitter(
-            0.8 * self.jitter_strength,
-            0.8 * self.jitter_strength,
-            0.8 * self.jitter_strength,
-            0.2 * self.jitter_strength,
+        super().__init__(
+            normalize=normalize, input_height=input_height, gaussian_blur=None, jitter_strength=jitter_strength
         )
 
-        if not eval_transform:
-            data_transforms = [
-                transforms.RandomResizedCrop(size=self.input_height),
-                transforms.RandomHorizontalFlip(p=0.5),
-                transforms.RandomApply([self.color_jitter], p=0.8),
-                transforms.RandomGrayscale(p=0.2),
-            ]
-        else:
-            data_transforms = [
+        if eval_transform:
+            self.data_transforms = [
                 transforms.Resize(int(self.input_height + 0.1 * self.input_height)),
                 transforms.CenterCrop(self.input_height),
             ]
 
-        if normalize is None:
-            final_transform = transforms.ToTensor()
-        else:
-            final_transform = transforms.Compose([transforms.ToTensor(), normalize])
-
-        data_transforms.append(final_transform)
-        self.transform = transforms.Compose(data_transforms)
+        self.transform = transforms.Compose([self.data_transforms, self.final_transform])
 
     def __call__(self, sample):
         return self.transform(sample)
-
-
-@under_review()
-class GaussianBlur:
-    # Implements Gaussian blur as described in the SimCLR paper
-    def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
-        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
-            raise ModuleNotFoundError("You want to use `GaussianBlur` from `cv2` which is not installed yet.")
-
-        self.min = min
-        self.max = max
-
-        # kernel size is set to be 10% of the image height/width
-        self.kernel_size = kernel_size
-        self.p = p
-
-    def __call__(self, sample):
-        sample = np.array(sample)
-
-        # blur the image with a 50% chance
-        prob = np.random.random_sample()
-
-        if prob < self.p:
-            sigma = (self.max - self.min) * np.random.random_sample() + self.min
-            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
-
-        return sample
diff --git a/tests/conftest.py b/tests/conftest.py
index bf233e2185..6b63d39b70 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,6 +8,7 @@
 from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
 from pytorch_lightning.utilities.imports import _IS_WINDOWS
 
+from pl_bolts.utils import _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13
 from pl_bolts.utils.stability import UnderReviewWarning
 
 # GitHub Actions use this path to cache datasets.
@@ -27,6 +28,8 @@ def catch_warnings():
     with warnings.catch_warnings():
         warnings.simplefilter("error")
         warnings.simplefilter("ignore", UnderReviewWarning)
+        if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_13:
+            warnings.filterwarnings("ignore", "FLIP_LEFT_RIGHT is deprecated", DeprecationWarning)
 
         yield
 
diff --git a/tests/models/self_supervised/unit/__init__.py b/tests/models/self_supervised/unit/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/self_supervised/unit/test_transforms.py b/tests/models/self_supervised/unit/test_transforms.py
new file mode 100644
index 0000000000..737af74bcb
--- /dev/null
+++ b/tests/models/self_supervised/unit/test_transforms.py
@@ -0,0 +1,55 @@
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+
+from pl_bolts.models.self_supervised.simclr.transforms import (
+    SimCLREvalDataTransform,
+    SimCLRFinetuneTransform,
+    SimCLRTrainDataTransform,
+)
+
+
+@pytest.mark.parametrize(
+    "transform_cls",
+    [pytest.param(SimCLRTrainDataTransform, id="train-data"), pytest.param(SimCLREvalDataTransform, id="eval-data")],
+)
+def test_simclr_train_data_transform(catch_warnings, transform_cls):
+    # dummy image
+    img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
+    img = Image.fromarray(img)
+
+    # size of the generated views
+    input_height = 96
+    transform = transform_cls(input_height=input_height)
+    views = transform(img)
+
+    # the transform must output a list or a tuple of images
+    assert isinstance(views, (list, tuple))
+
+    # the transform must output three images
+    # (1st view, 2nd view, online evaluation view)
+    assert len(views) == 3
+
+    # all views are tensors
+    assert all(torch.is_tensor(v) for v in views)
+
+    # all views have expected sizes
+    assert all(v.size(1) == v.size(2) == input_height for v in views)
+
+
+def test_simclr_finetune_transform(catch_warnings):
+    # dummy image
+    img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
+    img = Image.fromarray(img)
+
+    # size of the generated views
+    input_height = 96
+    transform = SimCLRFinetuneTransform(input_height=input_height)
+    view = transform(img)
+
+    # the view generator is a tensor
+    assert torch.is_tensor(view)
+
+    # view has expected size
+    assert view.size(1) == view.size(2) == input_height
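Note on the blur augmentation used in the updated train transform (a hedged sketch, not part of the patch): torchvision's transforms.GaussianBlur samples sigma uniformly from (0.1, 2.0) by default, the same range the deleted cv2-based GaussianBlur used, and wrapping it in transforms.RandomApply with p=0.5 preserves the 50% application chance. A minimal standalone illustration follows; the 96-pixel input height and the dummy image are arbitrary choices for the example, not values taken from the patch.

import numpy as np
from PIL import Image
from torchvision import transforms

input_height = 96                       # arbitrary example size
kernel_size = int(0.1 * input_height)   # kernel is ~10% of the image height, as in SimCLRTrainDataTransform
if kernel_size % 2 == 0:                # GaussianBlur requires an odd kernel size
    kernel_size += 1

# torchvision's GaussianBlur draws sigma from (0.1, 2.0) by default, matching the
# min/max of the removed cv2 implementation; RandomApply applies it with p=0.5.
blur = transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5)

img = Image.fromarray(np.random.randint(0, 255, (input_height, input_height, 3), dtype=np.uint8))
blurred = blur(img)  # still a PIL image; compose with ToTensor() downstream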