Generalized Contrastive Alignment (GCA) provides a robust framework for self-supervised learning tasks, supporting various datasets and augmentation methods.
To set up the required environment, follow these steps:
```bash
# Create and activate the environment
conda create -n GCA python=3.11.9 -y
conda activate GCA

# Install dependencies
pip install hydra-core numpy==1.26.4 matplotlib seaborn scikit-image scikit-learn \
    pytorch-lightning==1.9.5 torch==2.2.1 torchaudio==2.2.1 \
    torchmetrics==1.4.2 torchvision==0.17.1
```
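After installation, a quick sanity check can confirm the pinned versions resolved correctly (an illustrative snippet, not part of the repository):

```python
# Verify that the pinned packages import and report the expected versions.
import torch
import pytorch_lightning as pl
import torchvision

print(torch.__version__)          # expect 2.2.1
print(pl.__version__)             # expect 1.9.5
print(torchvision.__version__)    # expect 0.17.1
print(torch.cuda.is_available())  # True if a CUDA build of torch is installed
```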
This implementation draws on the following resources:

- SimCLR CIFAR-10 implementation by Damrich et al. (GitHub link)
- SimCLR by Ting Chen et al. (GitHub link)
- IOT: Liangliang Shi et al., "Understanding and Generalizing Contrastive Learning from the Inverse Optimal Transport Perspective," ICML 2023.
The framework supports the following tasks:

- `simclr`
- `hs_ince`
- `gca_ince`
- `rince`
- `gca_rince`
- `gca_uot`
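For orientation, the `simclr` task corresponds to the standard NT-Xent (InfoNCE) objective. Below is a minimal, self-contained sketch of that loss for a batch of paired projections; it is illustrative only and not the repository's implementation (function name and the temperature default are assumptions):

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z1, z2, temperature=0.5):
    """Minimal NT-Xent sketch: z1, z2 are (N, D) projections of two views."""
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    z = torch.cat([z1, z2], dim=0)            # (2N, D) stacked views
    sim = z @ z.t() / temperature             # cosine similarity logits
    sim.fill_diagonal_(float("-inf"))         # mask each sample's self-pair
    n = z1.size(0)
    # Positives: view i is paired with view i + N, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)
```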
The following datasets are supported:

- `SVHN`
- `imagenet100`
- `cifar100`
- `cifar10`
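Most of these can be fetched automatically via torchvision into the directory you later pass as `dataset_dir` (a sketch assuming the standard torchvision layout; ImageNet-100 typically has to be prepared manually):

```python
from torchvision import datasets

root = "./datasets"
datasets.CIFAR10(root=root, download=True)
datasets.CIFAR100(root=root, download=True)
datasets.SVHN(root=root, split="train", download=True)
# imagenet100 is a subset of ImageNet and is usually downloaded and
# arranged by hand rather than through torchvision.
```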
You can configure data augmentation using the `strong_DA` option:

- `None` (standard augmentation)
- `large_erase`
- `brightness`
- `strong_crop`
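The option names suggest the following kinds of torchvision transforms; the exact parameters below are guesses for illustration, not the repository's settings:

```python
from torchvision import transforms as T

# Hypothetical mapping from strong_DA values to an extra transform.
strong_da = {
    "large_erase": T.RandomErasing(p=1.0, scale=(0.1, 0.4)),    # erase a large patch
    "brightness":  T.ColorJitter(brightness=0.8),               # aggressive brightness jitter
    "strong_crop": T.RandomResizedCrop(32, scale=(0.05, 0.3)),  # very tight crops (CIFAR size)
}
```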
To pretrain a model using self-supervised learning, run the following script:
```bash
python ssl_pretrain.py \
    --config-name "simclr_cifar10.yaml" \
    --config-path "./config/" \
    task=gca_uot \
    dataset=cifar10 \
    dataset_dir="./datasets" \
    batch_size=512 \
    seed=64 \
    backbone=resnet18 \
    projection_dim=128 \
    strong_DA=None \
    gpus=1 \
    workers=16 \
    optimizer='Adam' \
    learning_rate=0.03 \
    momentum=0.9 \
    weight_decay=1e-6 \
    lam=0.01 \
    q=0.6 \
    max_epochs=500
```
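The `lam` and `q` flags mirror the robustness parameters of the RINCE family of losses (Chuang et al., "Robust Contrastive Learning against Noisy Views"). A minimal sketch of that loss, under the assumption that the `rince`/`gca_rince` tasks build on the same parameterization:

```python
import torch

def rince_loss(pos_sim, neg_sim, lam=0.01, q=0.6, temperature=0.5):
    """Sketch of the RINCE objective (Chuang et al., 2022).

    pos_sim: (N,) cosine similarity of each positive pair.
    neg_sim: (N, K) similarities to K negatives per anchor.
    As q -> 0 this recovers InfoNCE; larger q increases robustness to
    noisy positives, and lam weights the partition-function term.
    """
    pos = torch.exp(pos_sim / temperature)
    neg = torch.exp(neg_sim / temperature).sum(dim=1)
    loss = -(pos ** q) / q + (lam * (pos + neg)) ** q / q
    return loss.mean()
```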
To evaluate the pretrained model with a linear classifier, use the following script:
```bash
python linear_evaluation.py \
    --config-name="simclr_cifar10.yaml" \
    --config-path="./config/" \
    task=gca_uot \
    dataset=cifar10 \
    batch_size=512 \
    seed=64 \
    backbone=resnet18 \
    projection_dim=128 \
    strong_DA=None \
    lam=0.01 \
    q=0.6 \
    load_epoch=500
```
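Linear evaluation follows the usual protocol: freeze the pretrained encoder and fit a linear classifier on its features. A minimal sketch of that protocol (the encoder interface and `output_dim` attribute are placeholders, not the repo's API):

```python
import torch
import torch.nn as nn

def linear_eval(encoder, train_loader, num_classes=10, epochs=100, lr=0.1):
    """Freeze the backbone; train only a linear head on its features."""
    encoder.eval()
    for p in encoder.parameters():
        p.requires_grad_(False)
    head = nn.Linear(encoder.output_dim, num_classes)  # output_dim: placeholder attribute
    opt = torch.optim.SGD(head.parameters(), lr=lr, momentum=0.9)
    for _ in range(epochs):
        for x, y in train_loader:
            with torch.no_grad():
                feats = encoder(x)                     # frozen features
            loss = nn.functional.cross_entropy(head(feats), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return head
```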
- Task: specify the self-supervised learning task (e.g., `gca_uot`).
- Dataset: choose from the supported datasets (e.g., `cifar10`).
- Data Augmentation: use `strong_DA` to set the augmentation type.
- Training Parameters:
  - `batch_size`: batch size for training.
  - `backbone`: backbone architecture (e.g., `resnet18`).
  - `projection_dim`: dimension of the projection head.
  - `lam` and `q`: regularization and scaling parameters.
  - `max_epochs`: maximum number of epochs for training.
- Ensure that `dataset_dir` contains the datasets in the correct structure.
- Customize the parameters in the scripts to fit your experimental needs. For example, for SVHN:
```bash
python ssl_pretrain.py \
    --config-name "simclr_svhn.yaml" \
    --config-path "./config/" \
    task=gca_uot \
    dataset=SVHN \
    dataset_dir="./datasets" \
    batch_size=512 \
    seed=48 \
    backbone=resnet18 \
    projection_dim=128 \
    strong_DA=None \
    gpus=1 \
    workers=16 \
    optimizer='Adam' \
    learning_rate=0.03 \
    momentum=0.9 \
    weight_decay=1e-6 \
    lam=0.01 \
    q=0.6 \
    max_epochs=500 \
    relax_item1=1 \
    relax_item2=0.01
```
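`relax_item1` and `relax_item2` plausibly correspond to the two marginal-relaxation weights used by unbalanced optimal transport solvers for the `gca_uot` task. The sketch below shows the generic unbalanced Sinkhorn scaling scheme (Chizat et al. style) to illustrate what such weights do; it is not the repository's solver, and the variable names are assumptions:

```python
import torch

def unbalanced_sinkhorn(C, eps=0.05, relax1=1.0, relax2=0.01, n_iters=100):
    """Unbalanced Sinkhorn scaling on a cost matrix C of shape (n, m).

    relax1/relax2 soften the row/column marginal constraints: small values
    let the corresponding marginal drift, large values approach balanced OT.
    """
    n, m = C.shape
    a = torch.full((n,), 1.0 / n)   # row marginal (uniform)
    b = torch.full((m,), 1.0 / m)   # column marginal (uniform)
    K = torch.exp(-C / eps)         # Gibbs kernel
    u, v = torch.ones(n), torch.ones(m)
    f1 = relax1 / (relax1 + eps)    # dampened row update exponent
    f2 = relax2 / (relax2 + eps)    # dampened column update exponent
    for _ in range(n_iters):
        u = (a / (K @ v)) ** f1
        v = (b / (K.t() @ u)) ** f2
    return u[:, None] * K * v[None, :]   # transport plan
```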