Revision of SimCLR transforms #857

Merged
18 commits merged on Sep 16, 2022
105 changes: 32 additions & 73 deletions pl_bolts/models/self_supervised/simclr/transforms.py
@@ -1,31 +1,22 @@
import numpy as np

from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
else: # pragma: no cover
warn_missing_pkg("torchvision")

if _OPENCV_AVAILABLE:
import cv2
else: # pragma: no cover
warn_missing_pkg("cv2", pypi_name="opencv-python")


@under_review()
class SimCLRTrainDataTransform:
"""Transforms for SimCLR.
"""Transforms for SimCLR during training step of the pre-training stage.

Transform::

RandomResizedCrop(size=self.input_height)
RandomHorizontalFlip()
RandomApply([color_jitter], p=0.8)
RandomGrayscale(p=0.2)
GaussianBlur(kernel_size=int(0.1 * self.input_height))
RandomApply([GaussianBlur(kernel_size=int(0.1 * self.input_height))], p=0.5)
transforms.ToTensor()

Example::
@@ -34,7 +25,7 @@ class SimCLRTrainDataTransform:

transform = SimCLRTrainDataTransform(input_height=32)
x = sample()
(xi, xj) = transform(x)
(xi, xj, xk) = transform(x) # xk is only for the online evaluator if used
"""

def __init__(
@@ -68,16 +59,16 @@ def __init__(
if kernel_size % 2 == 0:
kernel_size += 1

data_transforms.append(GaussianBlur(kernel_size=kernel_size, p=0.5))
data_transforms.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5))

data_transforms = transforms.Compose(data_transforms)
self.data_transforms = transforms.Compose(data_transforms)

if normalize is None:
self.final_transform = transforms.ToTensor()
else:
self.final_transform = transforms.Compose([transforms.ToTensor(), normalize])

self.train_transform = transforms.Compose([data_transforms, self.final_transform])
self.train_transform = transforms.Compose([self.data_transforms, self.final_transform])

# add online train transform of the size of global view
self.online_transform = transforms.Compose(
@@ -93,9 +84,8 @@ def __call__(self, sample):
return xi, xj, self.online_transform(sample)


@under_review()
class SimCLREvalDataTransform(SimCLRTrainDataTransform):
"""Transforms for SimCLR.
"""Transforms for SimCLR during the validation step of the pre-training stage.

Transform::

@@ -109,7 +99,7 @@ class SimCLREvalDataTransform(SimCLRTrainDataTransform):

transform = SimCLREvalDataTransform(input_height=32)
x = sample()
(xi, xj) = transform(x)
(xi, xj, xk) = transform(x) # xk is only for the online evaluator if used
"""

def __init__(
@@ -129,70 +119,39 @@ def __init__(
)


@under_review()
class SimCLRFinetuneTransform:
class SimCLRFinetuneTransform(SimCLRTrainDataTransform):
"""Transforms for SimCLR during the fine-tuning stage.

Transform::

Resize(input_height + 10, interpolation=3)
transforms.CenterCrop(input_height),
transforms.ToTensor()

Example::

from pl_bolts.models.self_supervised.simclr.transforms import SimCLRFinetuneTransform

transform = SimCLRFinetuneTransform(input_height=32)
x = sample()
xk = transform(x)
"""

def __init__(
self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False
) -> None:

self.jitter_strength = jitter_strength
self.input_height = input_height
self.normalize = normalize

self.color_jitter = transforms.ColorJitter(
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.8 * self.jitter_strength,
0.2 * self.jitter_strength,
super().__init__(
normalize=normalize, input_height=input_height, gaussian_blur=None, jitter_strength=jitter_strength
)

if not eval_transform:
data_transforms = [
transforms.RandomResizedCrop(size=self.input_height),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([self.color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2),
]
else:
data_transforms = [
if eval_transform:
self.data_transforms = [
transforms.Resize(int(self.input_height + 0.1 * self.input_height)),
transforms.CenterCrop(self.input_height),
]

if normalize is None:
final_transform = transforms.ToTensor()
else:
final_transform = transforms.Compose([transforms.ToTensor(), normalize])

data_transforms.append(final_transform)
self.transform = transforms.Compose(data_transforms)
self.transform = transforms.Compose([self.data_transforms, self.final_transform])

def __call__(self, sample):
return self.transform(sample)


@under_review()
class GaussianBlur:
# Implements Gaussian blur as described in the SimCLR paper
def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `GaussianBlur` from `cv2` which is not installed yet.")

self.min = min
self.max = max

# kernel size is set to be 10% of the image height/width
self.kernel_size = kernel_size
self.p = p

def __call__(self, sample):
sample = np.array(sample)

# blur the image with a 50% chance
prob = np.random.random_sample()

if prob < self.p:
sigma = (self.max - self.min) * np.random.random_sample() + self.min
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

return sample
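
For reference, the net effect of this file's changes is that the cv2-based GaussianBlur helper above is removed and blurring now uses torchvision's own transforms.GaussianBlur inside a RandomApply. A minimal standalone sketch of the resulting training pipeline, assuming a torchvision version that provides transforms.GaussianBlur and using the defaults that appear in the diff (nothing here is new API):

    from torchvision import transforms

    input_height = 224      # default from the diff
    jitter_strength = 1.0   # default from the diff

    color_jitter = transforms.ColorJitter(
        0.8 * jitter_strength, 0.8 * jitter_strength, 0.8 * jitter_strength, 0.2 * jitter_strength
    )

    # kernel size is 10% of the image height and must be odd for GaussianBlur
    kernel_size = int(0.1 * input_height)
    if kernel_size % 2 == 0:
        kernel_size += 1

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=input_height),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5),
        transforms.ToTensor(),
    ])

This keeps the augmentation recipe from the SimCLR paper while dropping the optional opencv-python dependency; torchvision's default sigma range (0.1, 2.0) matches the min/max used by the deleted class.
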
3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -8,6 +8,7 @@
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.utilities.imports import _IS_WINDOWS

from pl_bolts.utils import _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_13
from pl_bolts.utils.stability import UnderReviewWarning

# GitHub Actions use this path to cache datasets.
@@ -27,6 +28,8 @@ def catch_warnings():
with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.simplefilter("ignore", UnderReviewWarning)
if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_13:
warnings.filterwarnings("ignore", "FLIP_LEFT_RIGHT is deprecated", DeprecationWarning)
yield


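
A note on the filter added above: the second argument to warnings.filterwarnings is a regex matched against the start of the warning message, so the short prefix is enough to silence the full Pillow deprecation text while every other warning still escalates to an error. A standalone sketch of the same pattern (the warning text below is just an illustration, not the exact message Pillow emits):

    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("error")  # any unexpected warning fails the test
        warnings.filterwarnings("ignore", "FLIP_LEFT_RIGHT is deprecated", DeprecationWarning)
        # silenced: the ignore filter was added last and is therefore matched first
        warnings.warn("FLIP_LEFT_RIGHT is deprecated and will be removed", DeprecationWarning)

In the fixture the filter is only installed for torchvision < 0.13, where the deprecated PIL constant is still used internally.
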
Empty file.
55 changes: 55 additions & 0 deletions tests/models/self_supervised/unit/test_transforms.py
@@ -0,0 +1,55 @@
import numpy as np
import pytest
import torch
from PIL import Image

from pl_bolts.models.self_supervised.simclr.transforms import (
SimCLREvalDataTransform,
SimCLRFinetuneTransform,
SimCLRTrainDataTransform,
)


@pytest.mark.parametrize(
"transform_cls",
[pytest.param(SimCLRTrainDataTransform, id="train-data"), pytest.param(SimCLREvalDataTransform, id="eval-data")],
)
def test_simclr_train_data_transform(catch_warnings, transform_cls):
# dummy image
img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
img = Image.fromarray(img)

# size of the generated views
input_height = 96
transform = transform_cls(input_height=input_height)
views = transform(img)

# the transform must output a list or a tuple of images
assert isinstance(views, (list, tuple))

# the transform must output three images
# (1st view, 2nd view, online evaluation view)
assert len(views) == 3

# all views are tensors
assert all(torch.is_tensor(v) for v in views)

# all views have expected sizes
assert all(v.size(1) == v.size(2) == input_height for v in views)


def test_simclr_finetune_transform(catch_warnings):
# dummy image
img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
img = Image.fromarray(img)

# size of the generated views
input_height = 96
transform = SimCLRFinetuneTransform(input_height=input_height)
view = transform(img)

# the view generator is a tensor
assert torch.is_tensor(view)

# view has expected size
assert view.size(1) == view.size(2) == input_height
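
As a usage note beyond what these tests cover, all three transform classes accept an optional normalize argument that is appended after ToTensor(). A minimal sketch with ImageNet statistics (the normalization values are illustrative and not part of this PR):

    from torchvision import transforms

    imagenet_normalization = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = SimCLRTrainDataTransform(input_height=96, normalize=imagenet_normalization)
    xi, xj, x_online = transform(img)  # `img` is a PIL image, as in the tests above
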