update pretraining scripts, ckpts, and logs
fwtan committed Jul 22, 2023
1 parent 0901630 commit 0934b88
Showing 41 changed files with 167,694 additions and 20 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -56,7 +56,6 @@ coverage.xml
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
43 changes: 41 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
-# [Effective Self-supervised Pre-training on Low-compute networks without Distillation](https://openreview.net/forum?id=cbpRzMy-UZH)
+# [Effective Self-supervised Pre-training on Low-compute networks without Distillation](https://arxiv.org/abs/2210.02808)
Fuwen Tan, Fatemeh Saleh, Brais Martinez, ICLR 2023.

## Abstract
@@ -22,12 +22,14 @@ This repo supports pre-training [DINO|SwAV|MoCo] with [MobileNet V2|ResNets|ViTs
python3 main.py --cfg config/exp_yamls/dino/dino_cnn_sslight.yaml DATA.PATH_TO_DATA_DIR $IN1K_PATH OUTPUT_DIR $OUTPUT_PATH
```
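
The trailing `KEY VALUE` pairs in the command above (`DATA.PATH_TO_DATA_DIR $IN1K_PATH`, `OUTPUT_DIR $OUTPUT_PATH`) read like yacs-style overrides applied on top of the YAML file. A minimal sketch of such a merge, assuming plain nested dicts in place of `CfgNode`; `merge_overrides` is a hypothetical helper and not a function from this repo, and unlike a real config library it does no type coercion or key validation:

```python
def merge_overrides(cfg, opts):
    """Apply a flat [key1, value1, key2, value2, ...] list to a nested dict.

    Dotted keys such as "DATA.PATH_TO_DATA_DIR" address nested sections.
    Hypothetical helper for illustration only.
    """
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for key, value in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            # Walk down into the section, creating it if it is missing.
            node = node.setdefault(name, {})
        node[leaf] = value
    return cfg


cfg = {"DATA": {"PATH_TO_DATA_DIR": ""}, "OUTPUT_DIR": ""}
merge_overrides(cfg, ["DATA.PATH_TO_DATA_DIR", "/data/in1k", "OUTPUT_DIR", "/tmp/out"])
```

The paths `/data/in1k` and `/tmp/out` are placeholders standing in for `$IN1K_PATH` and `$OUTPUT_PATH`.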

-In order to assess the quality of features during pre-training, an additional linear classifier can be trained on the separated features. This ensures that the gradient from the linear classifier does not interfere with the feature learning process:
+To assess the quality of features during pre-training, an additional linear classifier can be trained on the separated features. This ensures that the gradient from the linear classifier does not interfere with the feature learning process:

```
python3 main.py --cfg config/exp_yamls/dino/dino_cnn_sslight.yaml DATA.PATH_TO_DATA_DIR $IN1K_PATH OUTPUT_DIR $OUTPUT_PATH TRAIN.JOINT_LINEAR_PROBE True
```

Note that the accuracy of this extra classifier is typically lower than a standard linear probing evaluation.
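
Why a probe trained on separated (detached) features cannot disturb pre-training can be seen in a toy scalar model with hand-written gradients. This is only an illustration under assumed names (`probe_gradients`, `w`, `v`); the repo's actual implementation is not shown in this commit:

```python
def probe_gradients(w, v, x, y, detach=True):
    """Gradients of a squared-error probe loss in a toy scalar model.

    Backbone: feature = w * x.  Probe: pred = v * feature.
    Loss: 0.5 * (pred - y) ** 2.
    With detach=True the feature is treated as a constant, so no gradient
    from the probe loss reaches the backbone weight w.
    """
    feature = w * x
    pred = v * feature
    err = pred - y
    grad_v = err * feature                    # the probe weight always learns
    grad_w = 0.0 if detach else err * v * x   # chain rule through the feature
    return grad_v, grad_w


detached = probe_gradients(w=2.0, v=0.5, x=3.0, y=1.0, detach=True)
attached = probe_gradients(w=2.0, v=0.5, x=3.0, y=1.0, detach=False)
```

In both cases the probe receives the same gradient; only the attached variant would also push on the backbone weight.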

The table below includes the scripts for specific pre-training experiments:


@@ -38,54 +40,91 @@ The table below includes the scripts for specific pre-training experiments:
<th valign="bottom">Backbone</th>
<th valign="bottom">IN1K Linear Accu.</th>
<th valign="bottom">Training</th>
<th valign="bottom">Log (re-trained)</th>
<!-- TABLE BODY -->
<tr>
<td align="left">DINO baseline</td>
<td align="center">MobileNet V2</td>
<td align="center">66.2</td>
<td align="center"><a href=src/experiments/dino/mnv2/baseline.sh>script</a></td>
<td align="center"><a href=src/experiments/dino/mnv2/baseline.log>log</a></td>
</tr>
<tr>
<td align="left">DINO SSLight</td>
<td align="center">MobileNet V2</td>
<td align="center">68.3 (+2.1)</td>
<td align="center"><a href=src/experiments/dino/mnv2/sslight.sh>script</a></td>
<td align="center"><a href=src/experiments/dino/mnv2/sslight.log>log</a></td>
</tr>
<tr>
<td align="left">DINO baseline</td>
<td align="center">ResNet18</td>
<td align="center">62.2</td>
<td align="center"><a href=src/experiments/dino/resnet18/baseline.sh>script</a></td>
<td align="center"><a href=src/experiments/dino/resnet18/baseline.log>log</a></td>
</tr>
<tr>
<td align="left">DINO SSLight</td>
<td align="center">ResNet18</td>
<td align="center">65.7 (+3.5)</td>
<td align="center"><a href=src/experiments/dino/resnet18/sslight.sh>script</a></td>
<td align="center"><a href=src/experiments/dino/resnet18/sslight.log>log</a></td>
</tr>
<tr>
<td align="left">DINO baseline</td>
<td align="center">ResNet34</td>
<td align="center">67.7</td>
<td align="center"><a href=src/experiments/dino/resnet34/baseline.sh>script</a></td>
<td align="center"><a href=src/experiments/dino/resnet34/baseline.log>log</a></td>
</tr>
<tr>
<td align="left">DINO SSLight</td>
<td align="center">ResNet34</td>
<td align="center">69.7 (+2.0)</td>
<td align="center"><a href=src/experiments/dino/resnet34/sslight.sh>script</a></td>
<td align="center"><a href=src/experiments/dino/resnet34/sslight.log>log</a></td>
</tr>
<tr>
<td align="left">DINO baseline</td>
<td align="center">ViT-Tiny/16</td>
<td align="center">66.7</td>
<td align="center"><a href=src/experiments/dino/vit_tiny_16/baseline.sh>script</a></td>
<td align="center"><a href=src/experiments/dino/vit_tiny_16/baseline.log>log</a></td>
</tr>
<tr>
<td align="left">DINO SSLight</td>
<td align="center">ViT-Tiny/16</td>
<td align="center">69.5 (+2.8)</td>
<td align="center"><a href=src/experiments/dino/vit_tiny_16/sslight.sh>script</a></td>
<td align="center"><a href=src/experiments/dino/vit_tiny_16/sslight.log>log</a></td>
</tr>
<tr>
<td align="left">SWAV baseline</td>
<td align="center">MobileNet V2</td>
<td align="center">65.2</td>
<td align="center"><a href=src/experiments/swav/mnv2/baseline.sh>script</a></td>
<td align="center"><a href=src/experiments/swav/mnv2/baseline.log>log</a></td>
</tr>
<tr>
<td align="left">SWAV SSLight</td>
<td align="center">MobileNet V2</td>
<td align="center">67.3 (+2.1)</td>
<td align="center"><a href=src/experiments/swav/mnv2/sslight.sh>script</a></td>
<td align="center"><a href=src/experiments/swav/mnv2/sslight.log>log</a></td>
</tr>
<tr>
<td align="left">MoCo baseline</td>
<td align="center">MobileNet V2</td>
<td align="center">60.6 </td>
<td align="center"><a href=src/experiments/moco/mnv2/baseline.sh>script</a></td>
<td align="center"><a href=src/experiments/moco/mnv2/baseline.log>log</a></td>
</tr>
<tr>
<td align="left">MoCo SSLight</td>
<td align="center">MobileNet V2</td>
<td align="center">61.6 (+1.0)</td>
<td align="center"><a href=src/experiments/moco/mnv2/sslight.sh>script</a></td>
<td align="center"><a href=src/experiments/moco/mnv2/sslight.log>log</a></td>
</tr>
</tbody></table>

2 changes: 1 addition & 1 deletion src/config/defaults.py
@@ -171,7 +171,7 @@
# ---------------------------------------------------------------------------- #
_C.MOCO = CfgNode()

-_C.MOCO.TEMPERATURE = 0.02
+_C.MOCO.TEMPERATURE = 0.2

_C.MOCO.GLOBAL_ONLY = False
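
The effect of the temperature change above can be illustrated with a small InfoNCE computation, the contrastive loss used by MoCo. This is a sketch in plain Python, assuming cosine similarities as inputs; `info_nce` and the example similarity values are made up for illustration and are not taken from this repo:

```python
import math


def info_nce(pos_sim, neg_sims, temperature):
    """InfoNCE loss for one query: -log softmax of the positive logit.

    pos_sim:  similarity between the query and its positive key
    neg_sims: similarities between the query and the negative keys
    """
    logits = [pos_sim / temperature] + [s / temperature for s in neg_sims]
    # Log-sum-exp with the max subtracted for numerical stability.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum - logits[0]


# Same similarities, two temperatures: dividing by a smaller temperature
# stretches the logit gaps, so the loss for an already well-separated
# positive shrinks toward zero.
loss_sharp = info_nce(0.8, [0.5, 0.1], temperature=0.02)
loss_soft = info_nce(0.8, [0.5, 0.1], temperature=0.2)
```

With a temperature of 0.02 the logits are ten times larger than with 0.2, which makes the softmax far more peaked and concentrates the gradient on the hardest negatives.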

52 changes: 52 additions & 0 deletions src/config/exp_yamls/moco/moco_cnn_baseline.yaml
@@ -0,0 +1,52 @@
TRAIN:
BATCH_SIZE: 128
CHECKPOINT_PERIOD: 20
MULTI_VIEWS_TRANSFORMS:
GLOBAL_CROPS_SCALE: [0.14, 1.0]
LOCAL_CROPS_SCALE: [0.05, 0.14]
LOCAL_CROPS_NUMBER: 6
RANDOM_HORIZONTAL_FLIP_PROB: [0.5, 0.5]
GAUSSIAN_BLUR_PROB: [1.0, 0.1, 0.5] # global_1, global_2, local
COLOR_JITTER_PROB: [0.8, 0.8]
COLOR_JITTER_INTENSITY: [0.4, 0.4, 0.2, 0.1]
GREYSCALE_PROB: [0.2, 0.2]
CROP_PROB: [1.0, 1.0]
LOCAL_CROP_SIZE: 96
GLOBAL_CROP_SIZE: 224
SOLARIZATION_PROB: [0.0, 0.2]
CROPS_FOR_ASSIGN: [0, 1]
NMB_CROPS: [2, 6]
LAMBDAS: [0.142857, 0.857143]
MODEL:
BACKBONE_ARCH: mobilenet_v2
MODEL_MOMENTUM: 0.996
MOCO:
GLOBAL_ONLY: False
TEMPERATURE: 0.2
QUEUE_LENGTH: 65536
MOMENTUM: 0.999
OUTPUT_DIM: 128
HIDDEN_SIZE: 2048
NUM_LAYERS: 2
USE_BN_IN_HEAD: False
SOLVER:
OPTIMIZING_METHOD: LARS
WEIGHT_DECAY: 1.0e-6
WEIGHT_DECAY_END: 1.0e-6
TOTAL_EPOCHS: 200
BASE_LR: 0.48
MIN_LR: 0.0048
MOMENTUM: 0.9
CLIP_GRAD: 0.0
LOG_STEP: 20
SSL_METHOD: MOCO
STAGE: TRAIN
DISTRIBUTED: True
SEED: 0
DIST_BACKEND: 'nccl'
N_NODES: 1
WORLD_SIZE: 1
NODE_RANK: 0
DIST_URL: "env://"
WORKERS: 4
USE_FP16: False
52 changes: 52 additions & 0 deletions src/config/exp_yamls/moco/moco_cnn_sslight.yaml
@@ -0,0 +1,52 @@
TRAIN:
BATCH_SIZE: 128
CHECKPOINT_PERIOD: 20
MULTI_VIEWS_TRANSFORMS:
GLOBAL_CROPS_SCALE: [0.3, 1.0]
LOCAL_CROPS_SCALE: [0.05, 0.3]
LOCAL_CROPS_NUMBER: 6
RANDOM_HORIZONTAL_FLIP_PROB: [0.5, 0.5]
GAUSSIAN_BLUR_PROB: [1.0, 0.1, 0.5] # global_1, global_2, local
COLOR_JITTER_PROB: [0.8, 0.8]
COLOR_JITTER_INTENSITY: [0.4, 0.4, 0.2, 0.1]
GREYSCALE_PROB: [0.2, 0.2]
CROP_PROB: [1.0, 1.0]
LOCAL_CROP_SIZE: 128
GLOBAL_CROP_SIZE: 224
SOLARIZATION_PROB: [0.0, 0.2]
CROPS_FOR_ASSIGN: [0, 1]
NMB_CROPS: [2, 6]
LAMBDAS: [0.4, 0.6]
MODEL:
BACKBONE_ARCH: mobilenet_v2
MODEL_MOMENTUM: 0.996
MOCO:
GLOBAL_ONLY: False
TEMPERATURE: 0.2
QUEUE_LENGTH: 65536
MOMENTUM: 0.999
OUTPUT_DIM: 128
HIDDEN_SIZE: 2048
NUM_LAYERS: 2
USE_BN_IN_HEAD: False
SOLVER:
OPTIMIZING_METHOD: LARS
WEIGHT_DECAY: 1.0e-6
WEIGHT_DECAY_END: 1.0e-6
TOTAL_EPOCHS: 200
BASE_LR: 0.48
MIN_LR: 0.0048
MOMENTUM: 0.9
CLIP_GRAD: 0.0
LOG_STEP: 20
SSL_METHOD: MOCO
STAGE: TRAIN
DISTRIBUTED: True
SEED: 0
DIST_BACKEND: 'nccl'
N_NODES: 1
WORLD_SIZE: 1
NODE_RANK: 0
DIST_URL: "env://"
WORKERS: 4
USE_FP16: False
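
The two MoCo configs above differ in `LAMBDAS`: `[0.142857, 0.857143]` for the baseline and `[0.4, 0.6]` for SSLight. One possible reading, which this commit does not confirm, is that the two entries weight the loss terms for global and local views respectively. Under that assumption the baseline values are what uniform weighting over view pairs gives for `NMB_CROPS: [2, 6]`. The helper `uniform_pair_weights` below is hypothetical:

```python
def uniform_pair_weights(n_global, n_local):
    """Share of global and local views among the views paired with one global view."""
    others = (n_global - 1) + n_local  # every view except the target itself
    return (n_global - 1) / others, n_local / others


w_global, w_local = uniform_pair_weights(2, 6)
print(round(w_global, 6), round(w_local, 6))  # 0.142857 0.857143
```

Read this way, the SSLight setting `[0.4, 0.6]` moves weight away from the uniform split and toward the global pair.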
56 changes: 56 additions & 0 deletions src/config/exp_yamls/swav/swav_cnn_baseline.yaml
@@ -0,0 +1,56 @@
TRAIN:
BATCH_SIZE: 128
CHECKPOINT_PERIOD: 20
MULTI_VIEWS_TRANSFORMS:
GLOBAL_CROPS_SCALE: [0.14, 1.0]
LOCAL_CROPS_SCALE: [0.05, 0.14]
LOCAL_CROPS_NUMBER: 6
RANDOM_HORIZONTAL_FLIP_PROB: [0.5, 0.5]
GAUSSIAN_BLUR_PROB: [0.5, 0.5, 0.5] # global_1, global_2, local
COLOR_JITTER_PROB: [0.8, 0.8]
COLOR_JITTER_INTENSITY: [0.8, 0.8, 0.8, 0.2]
GREYSCALE_PROB: [0.2, 0.2]
CROP_PROB: [1.0, 1.0]
LOCAL_CROP_SIZE: 96
GLOBAL_CROP_SIZE: 224
SOLARIZATION_PROB: [0.0, 0.0]
CROPS_FOR_ASSIGN: [0, 1]
NMB_CROPS: [2, 6]
LAMBDAS: [0.142857, 0.857143]
MODEL:
BACKBONE_ARCH: mobilenet_v2
SOLVER:
OPTIMIZING_METHOD: LARS
WEIGHT_DECAY: 1.0e-6
WEIGHT_DECAY_END: 1.0e-6
TOTAL_EPOCHS: 200
WARMUP_EPOCHS: 10
START_WARMUP: 0.075
BASE_LR: 0.3
MOMENTUM: 0.9
MIN_LR: 0.0048
SWAV:
FREEZE_PROTOTYPES_EPOCHS: 1
EPOCH_QUEUE_STARTS: 15
TEMPERATURE: 0.1
EPSILON: 0.05
SINKHORN_ITERATIONS: 3
NMB_PROTOTYPES: 3000
QUEUE_LENGTH: 3840
OUTPUT_DIM: 128
HIDDEN_SIZE: 4096
NUM_LAYERS: 2
USE_BN_IN_HEAD: True
SSL_METHOD: SWAV
LOG_STEP: 20
# stage from [TRAIN, VAL, TEST, FT]
STAGE: TRAIN
DISTRIBUTED: True
SEED: 0
DIST_BACKEND: 'nccl'
N_NODES: 1
WORLD_SIZE: 1
NODE_RANK: 0
DIST_URL: "env://"
WORKERS: 4
USE_FP16: False
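
The `SOLVER` fields in the SwAV config above (`START_WARMUP`, `WARMUP_EPOCHS`, `BASE_LR`, `MIN_LR`, `TOTAL_EPOCHS`) fit a linear warmup followed by cosine decay, a common schedule for SwAV-style training. The sketch below assumes that shape and per-epoch granularity; neither is confirmed by this commit, and `lr_at_epoch` is a hypothetical name:

```python
import math


def lr_at_epoch(epoch, total_epochs=200, warmup_epochs=10,
                start_warmup=0.075, base_lr=0.3, min_lr=0.0048):
    """Learning rate at a given epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Ramp linearly from start_warmup up to base_lr.
        return start_warmup + (base_lr - start_warmup) * epoch / warmup_epochs
    # Cosine decay from base_lr down to min_lr over the remaining epochs.
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


schedule = [lr_at_epoch(e) for e in range(200)]
```

Training code usually steps such a schedule per iteration rather than per epoch; the per-epoch version keeps the sketch short.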
56 changes: 56 additions & 0 deletions src/config/exp_yamls/swav/swav_cnn_sslight.yaml
@@ -0,0 +1,56 @@
TRAIN:
BATCH_SIZE: 128
CHECKPOINT_PERIOD: 20
MULTI_VIEWS_TRANSFORMS:
GLOBAL_CROPS_SCALE: [0.3, 1.0]
LOCAL_CROPS_SCALE: [0.05, 0.3]
LOCAL_CROPS_NUMBER: 6
RANDOM_HORIZONTAL_FLIP_PROB: [0.5, 0.5]
GAUSSIAN_BLUR_PROB: [0.5, 0.5, 0.5] # global_1, global_2, local
COLOR_JITTER_PROB: [0.8, 0.8]
COLOR_JITTER_INTENSITY: [0.8, 0.8, 0.8, 0.2]
GREYSCALE_PROB: [0.2, 0.2]
CROP_PROB: [1.0, 1.0]
LOCAL_CROP_SIZE: 128
GLOBAL_CROP_SIZE: 224
SOLARIZATION_PROB: [0.0, 0.0]
CROPS_FOR_ASSIGN: [0, 1]
NMB_CROPS: [2, 6]
LAMBDAS: [0.4, 0.6]
MODEL:
BACKBONE_ARCH: mobilenet_v2
SOLVER:
OPTIMIZING_METHOD: LARS
WEIGHT_DECAY: 1.0e-6
WEIGHT_DECAY_END: 1.0e-6
TOTAL_EPOCHS: 200
WARMUP_EPOCHS: 10
START_WARMUP: 0.075
BASE_LR: 0.3
MOMENTUM: 0.9
MIN_LR: 0.0048
SWAV:
FREEZE_PROTOTYPES_EPOCHS: 1
EPOCH_QUEUE_STARTS: 15
TEMPERATURE: 0.1
EPSILON: 0.05
SINKHORN_ITERATIONS: 3
NMB_PROTOTYPES: 3000
QUEUE_LENGTH: 3840
OUTPUT_DIM: 128
HIDDEN_SIZE: 4096
NUM_LAYERS: 2
USE_BN_IN_HEAD: True
SSL_METHOD: SWAV
LOG_STEP: 20
# stage from [TRAIN, VAL, TEST, FT]
STAGE: TRAIN
DISTRIBUTED: True
SEED: 0
DIST_BACKEND: 'nccl'
N_NODES: 1
WORLD_SIZE: 1
NODE_RANK: 0
DIST_URL: "env://"
WORKERS: 4
USE_FP16: False
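
The `SWAV` section above sets `EPSILON: 0.05` and `SINKHORN_ITERATIONS: 3`, the two parameters of the Sinkhorn-Knopp step that SwAV uses to turn prototype scores into balanced soft assignments. A minimal single-process sketch in plain Python follows; the function name `sinkhorn` and the toy scores are illustrative, and the repo's own implementation (likely tensor-based and distributed) is not shown in this commit:

```python
import math


def sinkhorn(scores, epsilon=0.05, n_iters=3):
    """Balanced soft assignments from a B x K list of prototype scores."""
    n_samples, n_protos = len(scores), len(scores[0])
    # Subtracting the global max only rescales the matrix; the normalization
    # below removes that scale again, so the result is unchanged.
    top = max(max(row) for row in scores)
    q = [[math.exp((s - top) / epsilon) for s in row] for row in scores]
    total = sum(sum(row) for row in q)
    q = [[v / total for v in row] for row in q]
    for _ in range(n_iters):
        # Give every prototype (column) the same total mass 1 / K.
        col_sums = [sum(q[b][k] for b in range(n_samples)) for k in range(n_protos)]
        q = [[q[b][k] / (col_sums[k] * n_protos) for k in range(n_protos)]
             for b in range(n_samples)]
        # Give every sample (row) the same total mass 1 / B.
        row_sums = [sum(row) for row in q]
        q = [[v / (row_sums[b] * n_samples) for v in q[b]] for b in range(n_samples)]
    # Rescale so that each row is a probability distribution over prototypes.
    return [[v * n_samples for v in row] for row in q]


toy_scores = [[1.0, 0.2], [0.9, 0.1], [0.1, 0.8], [0.2, 0.7]]
assignments = sinkhorn(toy_scores)
```

A smaller `epsilon` makes the assignments closer to one-hot, while more iterations enforce the equal-usage constraint across prototypes more tightly.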
2 changes: 1 addition & 1 deletion src/data/transforms.py
@@ -174,4 +174,4 @@ def __call__(self, image):

DataAugmentationSWAV = MultiCropsDataAugmentation
DataAugmentationDINO = MultiCropsDataAugmentation
-DataAugmentationMoCo = MultiCropsDataAugmentation
+DataAugmentationMOCO = MultiCropsDataAugmentation
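
The rename above makes the alias match the upper-case `SSL_METHOD: MOCO` value in the new configs. That matters if the augmentation class is resolved by name, which is an assumption here since the lookup code is not part of this diff. A sketch of such a lookup, with `transforms` as a stand-in namespace for `src/data/transforms.py` and `build_augmentation` as a hypothetical helper:

```python
import types


class MultiCropsDataAugmentation:
    """Placeholder for the real augmentation class."""

    def __init__(self, cfg=None):
        self.cfg = cfg


# Stand-in for the module-level aliases in src/data/transforms.py.
transforms = types.SimpleNamespace(
    DataAugmentationSWAV=MultiCropsDataAugmentation,
    DataAugmentationDINO=MultiCropsDataAugmentation,
    DataAugmentationMOCO=MultiCropsDataAugmentation,
)


def build_augmentation(ssl_method, cfg=None):
    """Resolve 'DataAugmentation<SSL_METHOD>' by name and instantiate it."""
    name = "DataAugmentation" + ssl_method
    cls = getattr(transforms, name, None)
    if cls is None:
        raise ValueError("no augmentation registered under " + name)
    return cls(cfg)


aug = build_augmentation("MOCO")
```

With the old mixed-case alias, a lookup keyed on `"MOCO"` would fail in this scheme, which is consistent with the rename.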
1 change: 1 addition & 0 deletions src/datasets