Add AIM Model from Scalable Pre-training of Large Autoregressive Image Models #1479

Merged
merged 71 commits on Jan 23, 2024
Changes from 69 commits
Commits
951d43e
Add MAE evaluation
guarin May 30, 2023
503cc44
Add stochastic depth dropout
guarin May 31, 2023
ac43499
Add MAE
guarin May 31, 2023
15bfe3a
Drop assertion
guarin May 31, 2023
49c85c0
Fix smooth cross entropy loss and mixup
guarin May 31, 2023
9d95783
Update comments
guarin May 31, 2023
0bb601d
Add layer lr decay and weight decay
guarin Jun 5, 2023
d7d69af
Update comment
guarin Jun 5, 2023
ec05437
Add test for MAE images_to_tokens
guarin Jun 5, 2023
923a606
Disable BN update
guarin Jun 5, 2023
bdce8a6
Add BN before classification head
guarin Jun 6, 2023
316f918
Format
guarin Jun 6, 2023
a6943fd
Fix BN freezing
guarin Jun 6, 2023
1a2b454
Cleanup
guarin Jun 6, 2023
bc066ae
Use torch.no_grad instead of deactivating gradients manually
guarin Jun 6, 2023
d56e340
Create new stochastic depth instances
guarin Jun 6, 2023
5ed6803
Add mask token to learnable params
guarin Jun 6, 2023
4f0baf1
Add sine-cosine positional embedding
guarin Jun 6, 2023
9c4a8cf
Initialize parameters as in paper
guarin Jun 6, 2023
9904c10
Merge branch 'master' into guarin-lig-3056-add-mae-imagenet-benchmark
guarin Dec 6, 2023
83edd1c
Fix types
guarin Dec 6, 2023
e27946e
Format
guarin Dec 6, 2023
0672b0a
Merge branch 'guarin-lig-3056-add-mae-imagenet-benchmark' of github.c…
ersi-lightly Dec 17, 2023
45433c5
adjusted to existing interface
ersi-lightly Dec 18, 2023
c5cab9e
draft
ersi-lightly Dec 19, 2023
017168e
remove
ersi-lightly Dec 19, 2023
278423b
added modifications
ersi-lightly Jan 4, 2024
fde116c
added mae implementation with timm and example
ersi-lightly Jan 5, 2024
f008645
formatted
ersi-lightly Jan 5, 2024
c97112e
fixed import
ersi-lightly Jan 5, 2024
2e55d6b
removed
ersi-lightly Jan 5, 2024
484add1
fixed typing
ersi-lightly Jan 5, 2024
af9b76a
modified imagenette benchmark
ersi-lightly Jan 8, 2024
a0639a7
formatted
ersi-lightly Jan 8, 2024
57762b8
edited vitb16 benchmark
ersi-lightly Jan 8, 2024
971c19a
addressed comments
ersi-lightly Jan 9, 2024
1ec7470
fixed typing and formatted
ersi-lightly Jan 9, 2024
76ee356
addressed comments
ersi-lightly Jan 9, 2024
edb2d42
added docstring and formatted
ersi-lightly Jan 9, 2024
f00d320
removed images to tokens method
ersi-lightly Jan 10, 2024
b5b0ab5
Merge branch 'ersi-lig-3912-refactor-mae-to-use-timm-vit' into ersi-l…
ersi-lightly Jan 12, 2024
a79738e
added the possibility to handle images of different sizes
ersi-lightly Jan 12, 2024
9564748
formatted
ersi-lightly Jan 12, 2024
e30ca7b
removed comments
ersi-lightly Jan 12, 2024
43e48d4
revert
ersi-lightly Jan 12, 2024
7c1e477
changed import
ersi-lightly Jan 12, 2024
658b1fd
initialize class token
ersi-lightly Jan 12, 2024
304bbb6
specified that class token should be used
ersi-lightly Jan 12, 2024
90aa4f7
changed architecture
ersi-lightly Jan 15, 2024
d795b68
addressed comments
ersi-lightly Jan 19, 2024
a583747
formatted
ersi-lightly Jan 19, 2024
d246793
Add AIM
guarin Jan 19, 2024
c0374c7
Add aim to benchmarks
guarin Jan 19, 2024
1163cf2
Fix depth
guarin Jan 19, 2024
2ce38d1
Update args
guarin Jan 19, 2024
f82846a
Update head architecture
guarin Jan 22, 2024
bf613e6
Update masking
guarin Jan 22, 2024
78b391d
Add benchmark arguments
guarin Jan 22, 2024
0c7936a
Set default float32 matmul precision
guarin Jan 22, 2024
39b09f4
Remove MAE changes
guarin Jan 23, 2024
ce6060e
Update docstrings
guarin Jan 23, 2024
48e8513
Fix types
guarin Jan 23, 2024
22e3ac5
Merge branch 'master' into aim
guarin Jan 23, 2024
2fca7f2
More type ignore for timm
guarin Jan 23, 2024
64f6442
Merge branch 'aim' of https://github.com/lightly-ai/lightly into aim
guarin Jan 23, 2024
832a24d
More type ignores
guarin Jan 23, 2024
d16da74
More type ignore
guarin Jan 23, 2024
0d00f93
Drop extra norm
guarin Jan 23, 2024
aa3dde8
Disable bias in head
guarin Jan 23, 2024
299db09
Move mask creation to method
guarin Jan 23, 2024
6b93bf5
Add optional mask in forward
guarin Jan 23, 2024
4 changes: 2 additions & 2 deletions benchmarks/imagenet/resnet50/finetune_eval.py
@@ -15,7 +15,7 @@
from lightly.utils.scheduler import CosineWarmupScheduler


-class FinetuneLinearClassifier(LinearClassifier):
+class FinetuneEvalClassifier(LinearClassifier):
def configure_optimizers(self):
parameters = list(self.classification_head.parameters())
parameters += self.model.parameters()
@@ -119,7 +119,7 @@ def finetune_eval(
strategy="ddp_find_unused_parameters_true",
num_sanity_val_steps=0,
)
-classifier = FinetuneLinearClassifier(
+classifier = FinetuneEvalClassifier(
model=model,
batch_size_per_device=batch_size_per_device,
feature_dim=2048,
161 changes: 161 additions & 0 deletions benchmarks/imagenet/vitb16/aim.py
@@ -0,0 +1,161 @@
from typing import List, Tuple, Union

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import MSELoss
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer

from lightly.models import utils
from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer
from lightly.models.utils import get_2d_sincos_pos_embed, random_prefix_mask
from lightly.transforms import AIMTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler


class AIM(LightningModule):
def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
super().__init__()
self.save_hyperparameters()
self.batch_size_per_device = batch_size_per_device

img_size = 224
self.patch_size = 14
self.num_patches = (img_size // self.patch_size) ** 2

vit = MaskedCausalVisionTransformer(
img_size=img_size,
patch_size=self.patch_size,
num_classes=num_classes,
embed_dim=1536,
depth=24,
num_heads=12,
qk_norm=False,
class_token=False,
no_embed_class=True,
)
# Use absolute positional embedding.
pos_embed = get_2d_sincos_pos_embed(
embed_dim=vit.embed_dim,
grid_size=int(self.num_patches**0.5),
cls_token=False,
)
vit.pos_embed.requires_grad = False
vit.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

self.backbone = vit
self.projection_head = AIMPredictionHead(
input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2
)

self.criterion = MSELoss()

self.online_classifier = OnlineLinearClassifier(
feature_dim=vit.embed_dim, num_classes=num_classes
)

def forward(self, x: Tensor) -> Tensor:
features = self.backbone.forward_features(x)
# TODO: We use mean aggregation for simplicity. The paper uses
# AttentionPoolingClassifier to get the class features. But this is not great
# as it requires training an additional head.
# https://github.com/apple/ml-aim/blob/1eaedecc4d584f2eb7c6921212d86a3a694442e1/aim/torch/layers.py#L337
return features.mean(dim=1).flatten(start_dim=1)

def training_step(
self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
) -> Tensor:
images, targets = batch[0], batch[1]
images = images[0] # images is a list containing only one view
batch_size = images.shape[0]

mask = random_prefix_mask(
size=(batch_size, self.num_patches),
max_prefix_length=self.num_patches - 1,
device=images.device,
)
features = self.backbone.forward_features(images, mask=mask)
# Add positional embedding before head.
features = self.backbone._pos_embed(features)
predictions = self.projection_head(features)

# Convert images to patches and normalize them.
patches = utils.patchify(images, self.patch_size)
mean = patches.mean(dim=-1, keepdim=True)
var = patches.var(dim=-1, keepdim=True)
patches = (patches - mean) / (var + 1.0e-6) ** 0.5

loss = self.criterion(predictions, patches)

self.log(
"train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
)

# TODO: We could use AttentionPoolingClassifier instead of mean aggregation:
# https://github.com/apple/ml-aim/blob/1eaedecc4d584f2eb7c6921212d86a3a694442e1/aim/torch/layers.py#L337
cls_features = features.mean(dim=1).flatten(start_dim=1)
cls_loss, cls_log = self.online_classifier.training_step(
(cls_features.detach(), targets), batch_idx
)
self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
return loss + cls_loss

def validation_step(
self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
) -> Tensor:
images, targets = batch[0], batch[1]
cls_features = self.forward(images).flatten(start_dim=1)
cls_loss, cls_log = self.online_classifier.validation_step(
(cls_features.detach(), targets), batch_idx
)
self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
return cls_loss

def configure_optimizers(self):
# Don't use weight decay for batch norm, bias parameters, and classification
# head to improve performance.
params, params_no_weight_decay = utils.get_weight_decay_parameters(
[self.backbone, self.projection_head]
)
optimizer = AdamW(
[
{"name": "aim", "params": params},
{
"name": "aim_no_weight_decay",
"params": params_no_weight_decay,
"weight_decay": 0.0,
},
{
"name": "online_classifier",
"params": self.online_classifier.parameters(),
"weight_decay": 0.0,
},
],
lr=0.001 * self.batch_size_per_device * self.trainer.world_size / 4096,
weight_decay=0.05,
betas=(0.9, 0.95),
)
scheduler = {
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=31250 / 125000 * self.trainer.estimated_stepping_batches,
max_epochs=self.trainer.estimated_stepping_batches,
),
"interval": "step",
}
return [optimizer], [scheduler]

def configure_gradient_clipping(
self,
optimizer: Optimizer,
gradient_clip_val: Union[int, float, None] = None,
gradient_clip_algorithm: Union[str, None] = None,
) -> None:
self.clip_gradients(
optimizer=optimizer, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
)


transform = AIMTransform()
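
The TODO in forward() above notes that the AIM paper uses an attention-pooling head instead of mean aggregation to obtain class features. The following is a minimal, hypothetical sketch of such a pooling head built from torch.nn.MultiheadAttention; it is not the apple/ml-aim AttentionPoolingClassifier referenced in the linked code, only an illustration of the idea of letting a learnable query attend over the patch tokens.

import torch
from torch import nn

class AttentionPoolingSketch(nn.Module):
    # Hypothetical simplified pooling head: one learnable query cross-attends
    # over the patch tokens and returns a single feature vector per image.
    def __init__(self, dim: int, num_heads: int = 8) -> None:
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        query = self.query.expand(tokens.shape[0], -1, -1)  # (B, 1, dim)
        pooled, _ = self.attn(query, tokens, tokens)        # (B, 1, dim)
        return pooled[:, 0]                                 # (B, dim)

tokens = torch.randn(2, 256, 1536)  # (batch, num_patches, embed_dim) as in AIM above
pooled = AttentionPoolingSketch(dim=1536, num_heads=8)(tokens)  # (2, 1536)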
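
The regression targets in training_step above are per-patch normalized pixels. Below is a self-contained sketch of that target construction in plain PyTorch; patchify_sketch is a hypothetical stand-in for lightly.models.utils.patchify (the exact ordering of the flattened patch values may differ from lightly's helper), and the shapes assume the 224x224 inputs and patch size 14 used by this benchmark.

import torch

def patchify_sketch(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    # images: (B, C, H, W) with H and W divisible by patch_size.
    b, c, h, w = images.shape
    gh, gw = h // patch_size, w // patch_size
    # Split into non-overlapping patches and flatten each patch into a vector
    # of length C * patch_size**2.
    patches = images.reshape(b, c, gh, patch_size, gw, patch_size)
    patches = patches.permute(0, 2, 4, 3, 5, 1)  # (B, gh, gw, p, p, C)
    return patches.reshape(b, gh * gw, patch_size**2 * c)

images = torch.randn(2, 3, 224, 224)
patches = patchify_sketch(images, patch_size=14)  # (2, 256, 588), matching 3 * 14**2

# Per-patch normalization, exactly as in training_step above.
mean = patches.mean(dim=-1, keepdim=True)
var = patches.var(dim=-1, keepdim=True)
targets = (patches - mean) / (var + 1.0e-6) ** 0.5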
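
A short worked example of the optimizer schedule in configure_optimizers above, using hypothetical values (128 images per device on 8 devices) that are not defaults of this benchmark: the learning rate scales linearly with the effective batch size relative to a reference of 4096, and the warmup covers a fixed fraction of training.

batch_size_per_device, world_size = 128, 8          # hypothetical setup
peak_lr = 0.001 * batch_size_per_device * world_size / 4096  # 2.5e-4
warmup_fraction = 31250 / 125000                    # 25% of estimated_stepping_batches
print(peak_lr, warmup_fraction)                     # 0.00025 0.25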
214 changes: 214 additions & 0 deletions benchmarks/imagenet/vitb16/finetune_eval.py
@@ -0,0 +1,214 @@
from pathlib import Path
from typing import Dict, Tuple

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from timm.data import create_transform
from timm.data.mixup import Mixup
from torch import Tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision import transforms as T

from lightly.data import LightlyDataset
from lightly.models import utils
from lightly.models.utils import add_stochastic_depth_to_blocks
from lightly.transforms.utils import IMAGENET_NORMALIZE
from lightly.utils.benchmarking import LinearClassifier, MetricCallback
from lightly.utils.benchmarking.topk import mean_topk_accuracy
from lightly.utils.scheduler import CosineWarmupScheduler


class FinetuneEvalClassifier(LinearClassifier):
# Parameters follow MAE settings.
# Adapt initialization to include mixup.
def __init__(
self,
model: Module,
batch_size_per_device: int,
feature_dim: int,
num_classes: int,
topk: Tuple[int, ...] = (1, 5),
freeze_model: bool = False,
) -> None:
super().__init__(
model, batch_size_per_device, feature_dim, num_classes, topk, freeze_model
)
# Add path dropout.
add_stochastic_depth_to_blocks(self.model, prob=0.1)
# Add mixup and cutmix.
self.mixup = Mixup(
mixup_alpha=0.8,
cutmix_alpha=1.0,
label_smoothing=0.1,
num_classes=num_classes,
)

# Adapt step to include mixup.
def shared_step(self, batch, batch_idx) -> Tuple[Tensor, Dict[int, Tensor]]:
images, targets = batch[0], batch[1]
if self.trainer.state.stage == "train":
images, targets = self.mixup(images, targets)
predictions = self.forward(images)
loss = self.criterion(predictions, targets)
_, predicted_labels = predictions.topk(max(self.topk))
# Pass targets without mixup for topk accuracy calculation.
topk = mean_topk_accuracy(predicted_labels, batch[1], k=self.topk)
return loss, topk

# Adapt optimizer to match MAE settings. Parameters follow the original code from
# the authors: https://github.com/facebookresearch/mae/blob/main/FINETUNE.md#fine-tuning
# Note that lr and layerwise_lr_decay for ViT-B/16 are 1e-3 and 0.75 in the paper
# but 5e-4 and 0.65 in the code.
def configure_optimizers(self):
lr = 5e-4 * self.batch_size_per_device * self.trainer.world_size / 256
layerwise_lr_decay = 0.65

# Group parameters by weight decay and learning rate.
param_groups = {}
for name, module in utils.get_named_leaf_modules(self.model).items():
if "encoder_layer_" in name:
layer_index = int(name.split("encoder_layer_")[1].split(".")[0])
group_name = f"vit_layer_{layer_index}"
# ViT-B has 12 layers. LR increases from first layer with index 0 to
# last layer with index 11.
group_lr = lr * (layerwise_lr_decay ** (11 - layer_index))
else:
group_name = "vit"
group_lr = lr
params, params_no_weight_decay = utils.get_weight_decay_parameters([module])
group = param_groups.setdefault(
group_name,
{
"name": group_name,
"params": [],
"lr": group_lr,
"weight_decay": 0.05,
},
)
group["params"].extend(params)
group_no_weight_decay = param_groups.setdefault(
f"{group_name}_no_weight_decay",
{
"name": f"{group_name}_no_weight_decay",
"params": [],
"lr": group_lr,
"weight_decay": 0.0,
},
)
group_no_weight_decay["params"].extend(params_no_weight_decay)
param_groups["classification_head"] = {
"name": "classification_head",
"params": self.classification_head.parameters(),
"weight_decay": 0.0,
}
optimizer = AdamW(
list(param_groups.values()),
betas=(0.9, 0.999),
)
scheduler = {
"scheduler": CosineWarmupScheduler(
optimizer=optimizer,
warmup_epochs=(
self.trainer.estimated_stepping_batches
/ self.trainer.max_epochs
* 5
),
max_epochs=self.trainer.estimated_stepping_batches,
),
"interval": "step",
}
return [optimizer], [scheduler]


def finetune_eval(
model: Module,
train_dir: Path,
val_dir: Path,
log_dir: Path,
batch_size_per_device: int,
num_workers: int,
accelerator: str,
devices: int,
precision: str,
num_classes: int,
) -> None:
"""Runs fine-tune evaluation on the given model.

Parameters follow MAE settings.
"""
print("Running fine-tune evaluation...")

# Setup training data.
# NOTE: We use transforms from the timm library here as they are the default in MAE
# and torchvision does not provide all required parameters.
train_transform = create_transform(
input_size=224,
is_training=True,
auto_augment="rand-m9-mstd0.5-inc1",
interpolation="bicubic",
re_prob=0.25,
re_mode="pixel",
re_count=1,
mean=IMAGENET_NORMALIZE["mean"],
std=IMAGENET_NORMALIZE["std"],
)
train_dataset = LightlyDataset(input_dir=str(train_dir), transform=train_transform)
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size_per_device,
shuffle=True,
num_workers=num_workers,
drop_last=True,
persistent_workers=True,
)

# Setup validation data.
val_transform = T.Compose(
[
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=IMAGENET_NORMALIZE["mean"], std=IMAGENET_NORMALIZE["std"]),
]
)
val_dataset = LightlyDataset(input_dir=str(val_dir), transform=val_transform)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size_per_device,
shuffle=False,
num_workers=num_workers,
persistent_workers=True,
)

# Train linear classifier.
metric_callback = MetricCallback()
trainer = Trainer(
max_epochs=100,
accelerator=accelerator,
devices=devices,
callbacks=[
LearningRateMonitor(),
DeviceStatsMonitor(),
metric_callback,
],
logger=TensorBoardLogger(save_dir=str(log_dir), name="finetune_eval"),
precision=precision,
strategy="ddp_find_unused_parameters_true",
)
classifier = FinetuneEvalClassifier(
model=model,
batch_size_per_device=batch_size_per_device,
feature_dim=768,
num_classes=num_classes,
freeze_model=False,
)
trainer.fit(
model=classifier,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
for metric in ["val_top1", "val_top5"]:
print(f"max finetune {metric}: {max(metric_callback.val_metrics[metric])}")