Skip to content

Commit

Permalink
MoCo CLI uses CIFAR10 so that the tests will pass (ImageNet cannot be…
Browse files Browse the repository at this point in the history
… downloaded automatically)
  • Loading branch information
senarvi committed Jul 2, 2023
1 parent f1c7be3 commit 7c0b0b6
Showing 1 changed file with 23 additions and 66 deletions.
89 changes: 23 additions & 66 deletions src/pl_bolts/models/self_supervised/moco/moco_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning import LightningModule
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor, nn, optim
from torch.nn import functional as F # noqa: N812
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset

# It seems to be impossible to avoid mypy errors if using import instead of getattr().
# See https://github.com/python/mypy/issues/8823
Expand All @@ -27,12 +26,12 @@
except AttributeError:
LRScheduler = getattr(optim.lr_scheduler, "_LRScheduler")

from pl_bolts.datasets import UnlabeledImagenet
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.metrics import precision_at_k
from pl_bolts.models.self_supervised.moco.utils import concatenate_all, shuffle_batch, sort_batch, validate_batch
from pl_bolts.transforms.self_supervised.moco_transforms import (
MoCo2EvalImagenetTransforms,
MoCo2TrainImagenetTransforms,
MoCo2EvalCIFAR10Transforms,
MoCo2TrainCIFAR10Transforms,
)
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down Expand Up @@ -294,80 +293,38 @@ def _calculate_loss(self, images: Tensor, queue: RepresentationQueue) -> Tuple[T
return loss, acc1, acc5


class Collate:
def __call__(
self, samples: List[Tuple[Tuple[Tensor, Tensor], int]]
) -> Tuple[List[Tuple[Tensor, Tensor]], List[int]]:
return tuple(zip(*samples)) # type: ignore
def collate(samples: List[Tuple[Tuple[Tensor, Tensor], int]]) -> Tuple[List[Tuple[Tensor, Tensor]], List[int]]:
return tuple(zip(*samples)) # type: ignore


class ImageNetDataModule(LightningDataModule):
def __init__(
self,
data_dir: str,
meta_dir: Optional[str] = None,
num_workers: int = 0,
batch_size: int = 32,
shuffle: bool = True,
pin_memory: bool = True,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use ImageNet dataset loaded from `torchvision`, which is not installed yet."
)

self.data_dir = data_dir
self.num_workers = num_workers
self.meta_dir = meta_dir
self.batch_size = batch_size
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last

def train_dataloader(self, num_images_per_class: int = -1) -> DataLoader:
dataset = UnlabeledImagenet(
self.data_dir,
num_imgs_per_class=num_images_per_class,
meta_dir=self.meta_dir,
split="train",
transform=MoCo2TrainImagenetTransforms(),
)
return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
collate_fn=Collate(),
class CIFAR10ContrastiveDataModule(CIFAR10DataModule):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(
*args,
train_transforms=MoCo2TrainCIFAR10Transforms(),
val_transforms=MoCo2EvalCIFAR10Transforms(),
**kwargs,
)

def val_dataloader(self, num_images_per_class: int = 50) -> DataLoader:
dataset = UnlabeledImagenet(
self.data_dir,
num_imgs_per_class_val_split=num_images_per_class,
meta_dir=self.meta_dir,
split="val",
transform=MoCo2EvalImagenetTransforms(),
)
def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
return DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
shuffle=shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
collate_fn=Collate(),
collate_fn=collate,
)


def cli_main() -> None:
LightningCLI(MoCo, ImageNetDataModule, seed_everything_default=42)
try: # Backward compatibility for Lightning CLI
from pytorch_lightning.cli import LightningCLI # PL v1.9+
except ImportError:
from pytorch_lightning.utilities.cli import LightningCLI # PL v1.8

LightningCLI(MoCo, CIFAR10ContrastiveDataModule, seed_everything_default=42)


if __name__ == "__main__":
Expand Down

0 comments on commit 7c0b0b6

Please sign in to comment.