Update the documentation so that the main example works with solo-learn 1.0.6 #335

Changes to ``docs/source/tutorials/overview.rst`` (80 additions, 66 deletions):
Now, let's assume that we want to train Barlow Twins on CIFAR10 for 100 epochs.
For this, we won't use the ``main_pretrain.py`` file directly, but we'll build a minimal version of it in order to give a general overview of the library.

We start by importing everything that we will need (we will be relying on PyTorch Lightning to use our already implemented training/validation steps):

.. code-block:: python

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins import DDPPlugin

# solo-learn now uses OmegaConf and Hydra to manage its config files
from omegaconf import DictConfig
from solo.methods import BarlowTwins # imports the method class
from solo.utils.checkpointer import Checkpointer

# some data utilities
# we need one dataloader to train an online linear classifier
# (don't worry, the rest of the model has no idea of this classifier, so it doesn't use label info)
from solo.data.classification_dataloader import prepare_data as prepare_classification_dataloader

# and some utilities to perform data loading for the method itself, including augmentation pipelines
from solo.data.pretrain_dataloader import (
build_transform_pipeline,
prepare_dataloader,
prepare_datasets,
prepare_n_crop_transform,
)


However, for now, we won't rely on this, so let's just define all the needed parameters.

.. code-block:: python
# common parameters for all methods
# some parameters for extra functionality are missing, but don't mind this for now.
base_kwargs = {
"backbone": "resnet18",
"num_classes": 10,
"name": "barlow_twins-cifar10", # change here for cifar100
"backbone": {
"name": "resnet18",
"kwargs": {}
},
"data": {
"dataset": "cifar10",
"num_classes": 10,
"train_path": "./data", # replace with your own path
"val_path": "./data", # replace with your own path
"num_large_crops": 2, # must equal 2 for barlow twins
"num_small_crops": 0, # must equal 0 for barlow twins
"num_workers": 4,
},
"cifar": True,
"zero_init_residual": True,
"max_epochs": 100,
"optimizer": "sgd",
"lars": True,
"lr": 0.01,
"gpus": "0",
"grad_clip_lars": True,
"weight_decay": 0.00001,
"classifier_lr": 0.5,
"exclude_bias_n_norm_lars": True,
"accumulate_grad_batches": 1,
"extra_optimizer_args": {"momentum": 0.9},
"scheduler": "warmup_cosine",
"min_lr": 0.0,
"warmup_start_lr": 0.0,
"warmup_epochs": 10,
"num_crops_per_aug": [2, 0],
"num_large_crops": 2,
"num_small_crops": 0,
"eta_lars": 0.02,
"lr_decay_steps": None,
"optimizer": {
"name": "lars",
"lr": 0.01,
"batch_size": 256,
"weight_decay": 0.00001,
"classifier_lr": 0.1 # mandatory

},
"scheduler":{
"name": "warmup_cosine",
"min_lr": 0.0,
"warmup_start_lr": 0.0,
"warmup_epochs": 10,
},
"method": "barlow_twins",
"dali_device": "gpu",
"batch_size": 256,
"num_workers": 4,
"data_dir": "/data/datasets",
"train_dir": "cifar10/train",
"val_dir": "cifar10/val",
"dataset": "cifar10",
"name": "barlow-cifar10",
}

# barlow specific parameters
"backbone_args": {"cifar": True, "zero_init_residual": True},
}

cfg = DictConfig({**base_kwargs, "method_kwargs": method_kwargs})
model = BarlowTwins(cfg)
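
Instead of building this dictionary by hand, you could also start from one of the YAML files that solo-learn ships with and merge your overrides on top. This is only a sketch: the path below is an assumption about a typical checkout, and the shipped files rely on Hydra defaults lists, so loading a single file with OmegaConf alone may not give you the fully composed config.

.. code-block:: python

    from omegaconf import OmegaConf

    # hypothetical path inside a solo-learn checkout; adjust it to your setup
    yaml_cfg = OmegaConf.load("scripts/pretrain/cifar/barlow.yaml")

    # merge ad-hoc overrides on top of the file, e.g. a different epoch budget
    yaml_cfg = OmegaConf.merge(yaml_cfg, OmegaConf.create({"max_epochs": 100}))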


Now, let's create all the necessary data loaders.

.. code-block:: python

# we first prepare our single transformation pipeline config
transform_kwargs = {
"brightness": 0.4,
"contrast": 0.4,
"saturation": 0.2,
"hue": 0.1,
"gaussian_prob": 0.0,
"solarization_prob": 0.0,
"crop_size": 32,
"num_crops": 1,
"rrc": {
"enabled": True,
"crop_min_scale": 0.08,
"crop_max_scale": 1.0
},
"color_jitter": {
"prob": 0.8,
"brightness": 0.4,
"contrast": 0.4,
"saturation": 0.2,
"hue": 0.1,
},
# all below need to be specified but are unused
"grayscale": {"prob": 0.0},
"gaussian_blur": {"prob": 0.0},
"solarization": {"prob": 0.0},
"equalization": {"prob": 0.0},
"horizontal_flip": {"prob": 0.0},
}
aug_cfg = DictConfig(transform_kwargs)
augs = build_transform_pipeline("cifar10", aug_cfg)


# then, we wrap the pipeline using this utility function
# to make it produce an arbitrary number of crops
transform = prepare_n_crop_transform([augs], num_crops_per_aug=[2])

# finally, we produce the Dataset/Dataloader classes
train_dataset = prepare_datasets(
"cifar10",
transform,
data_dir="./",
train_dir=None,
dataset="cifar10",
transform=transform,
train_data_path=base_kwargs["data"]["train_path"],
no_labels=False,
)
train_loader = prepare_dataloader(
train_dataset=train_dataset,
batch_size=base_kwargs["optimizer"]["batch_size"],
num_workers=base_kwargs["data"]["num_workers"]
)

# we will also create a validation dataloader to automatically
# check how well our model is doing in an online fashion.
_, val_loader = prepare_classification_dataloader(
dataset=base_kwargs["data"]["dataset"], # "cifar10"
train_data_path=base_kwargs["data"]["train_path"],
val_data_path=base_kwargs["data"]["val_path"],
batch_size=base_kwargs["optimizer"]["batch_size"],
num_workers=base_kwargs["data"]["num_workers"],
)
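
If you want to double-check what the model will receive, you can peek at a single batch. As a rough sketch (solo-learn's pretraining batches follow an ``[indexes, list_of_crops, targets]`` layout, and the shapes below assume the batch size and crop settings used above):

.. code-block:: python

    # grab one batch from the pretraining dataloader
    indexes, crops, targets = next(iter(train_loader))

    print(len(crops))      # 2: the two large crops used by Barlow Twins
    print(crops[0].shape)  # e.g. torch.Size([256, 3, 32, 32]) for CIFAR10
    print(targets.shape)   # e.g. torch.Size([256])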


Now, we just need to define a few extra callbacks so that PyTorch Lightning automatically logs metrics and checkpoints for us, and then we can create our Lightning Trainer.

.. code-block:: python
callbacks.append(lr_monitor)

# the checkpointer can automatically log your parameters,
# and it saves a checkpoint after every epoch
ckpt = Checkpointer(
cfg,
logdir="checkpoints/barlow",
frequency=1,
)
callbacks.append(ckpt)

trainer = Trainer.from_argparse_args(
cfg,
logger=wandb_logger,
callbacks=callbacks,
plugins=DDPPlugin(find_unused_parameters=True),
checkpoint_callback=False,
terminate_on_nan=True,
accelerator="ddp",
accelerator="auto", # use whatever is available
strategy="ddp", # could change depending on your setup
)

trainer.fit(model, train_loader, val_loader)
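
Once training is done, the Checkpointer will have saved checkpoints under ``checkpoints/barlow``. Below is a minimal sketch of how the pretrained backbone could be recovered from one of them for downstream use; the checkpoint filename is hypothetical, since the real one depends on the run name and logger id.

.. code-block:: python

    import torch

    # hypothetical path; look inside checkpoints/barlow after training for the real name
    ckpt_path = "checkpoints/barlow/<run-id>/last.ckpt"

    state = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    # keep only the backbone weights and strip the "backbone." prefix
    backbone_state = {
        k.replace("backbone.", "", 1): v
        for k, v in state.items()
        if k.startswith("backbone.")
    }
    model.backbone.load_state_dict(backbone_state)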