Add vision transformer to finetune.py
inigoval committed May 8, 2024
1 parent e512126 commit 9dd6afc
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions zoobot/pytorch/training/finetune.py
@@ -1,22 +1,22 @@
import logging
import os
from typing import Any, Union, Optional
import warnings
from functools import partial
from typing import Any, Optional, Union

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor

import timm
import torch
import torch.nn.functional as F
import torchmetrics as tm
import timm

from zoobot.pytorch.training import losses, schedulers
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from sklearn import linear_model
from sklearn.metrics import accuracy_score
from zoobot.pytorch.estimators import define_model
from zoobot.pytorch.training import losses, schedulers
from zoobot.shared import schemas

# https://discuss.pytorch.org/t/how-to-freeze-bn-layers-while-training-the-rest-of-network-mean-and-var-wont-freeze/89736/7
@@ -131,6 +131,7 @@ def __init__(
zoobot_checkpoint_loc
) # extracts the timm encoder
self.encoder_dim = self.encoder.num_features

else:
assert (
zoobot_checkpoint_loc is None
@@ -140,7 +141,10 @@
), "Must pass either checkpoint to load or encoder to use"
self.encoder = encoder
# work out encoder dim 'manually'
self.encoder_dim = define_model.get_encoder_dim(self.encoder)
if isinstance(self.encoder, timm.models.VisionTransformer):
self.encoder_dim = self.encoder.embed_dim
else:
self.encoder_dim = define_model.get_encoder_dim(self.encoder)

self.n_blocks = n_blocks

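For context, a minimal standalone sketch of the encoder-dimension logic added in this hunk (the model name vit_small_patch16_224 is an illustrative assumption, not something this commit uses): timm VisionTransformer models expose their token width as embed_dim, while other timm encoders report their pooled feature width via num_features.

import timm

# Sketch only: mirrors the isinstance check above, outside of Zoobot.
encoder = timm.create_model("vit_small_patch16_224", pretrained=False, num_classes=0)
if isinstance(encoder, timm.models.VisionTransformer):
    encoder_dim = encoder.embed_dim  # ViT token embedding width
else:
    encoder_dim = encoder.num_features  # pooled feature width for CNN-style encoders
print(encoder_dim)  # 384 for this small ViT variant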
@@ -171,6 +175,12 @@ def __init__(
self.prog_bar = prog_bar
self.visualize_images = visualize_images

# Remove head if it exists
if hasattr(self.encoder, "head"):
# If the encoder has a 'head' attribute, replace it with Identity()
self.encoder.head = torch.nn.Identity()
print("Replaced encoder.head with Identity()")

def configure_optimizers(self):
"""
This controls which parameters get optimized
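A hedged sketch of what the new head-removal step achieves: once encoder.head is swapped for torch.nn.Identity(), the forward pass returns pooled features of width encoder_dim rather than classification logits. The model name, dummy batch size, and image size below are illustrative assumptions.

import timm
import torch

encoder = timm.create_model("vit_small_patch16_224", pretrained=False)
if hasattr(encoder, "head"):
    encoder.head = torch.nn.Identity()  # strip the classification head, keep features

with torch.no_grad():
    features = encoder(torch.randn(2, 3, 224, 224))  # dummy batch of 2 RGB images
print(features.shape)  # torch.Size([2, 384]), matching encoder.embed_dim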
@@ -243,19 +253,23 @@ def configure_optimizers(self):

# take n blocks, ordered highest layer to lowest layer
tuneable_blocks.reverse()
logging.info("possible blocks to tune: {}".format(len(tuneable_blocks)))
logging.info(f"possible blocks to tune: {len(tuneable_blocks)}")

# will finetune all params in first N
logging.info("blocks that will be tuned: {}".format(self.n_blocks))
logging.info(f"blocks that will be tuned: {self.n_blocks}")
blocks_to_tune = tuneable_blocks[: self.n_blocks]

# optionally, can finetune batchnorm params in remaining layers
remaining_blocks = tuneable_blocks[self.n_blocks :]
logging.info("Remaining blocks: {}".format(len(remaining_blocks)))
logging.info(f"Remaining blocks: {len(remaining_blocks)}")

assert not any(
[block in remaining_blocks for block in blocks_to_tune]
), "Some blocks are in both tuneable and remaining"

# Append parameters of layers for finetuning along with decayed learning rate
for i, block in enumerate(blocks_to_tune): # ordered highest layer to lowest, so lr decays with depth
logging.info(f"Adding block {block} with lr {lr * (self.lr_decay**i)}")
params.append({"params": block.parameters(), "lr": lr * (self.lr_decay**i)})

# optionally, for the remaining layers (not otherwise finetuned) you can choose to still FT the batchnorm layers
@@ -268,7 +282,7 @@ def configure_optimizers(self):
# "lr": lr * (self.lr_decay**i)
# })

logging.info("param groups: {}".format(len(params)))
logging.info(f"param groups: {len(params)}")

# because it iterates through the generators, THIS BREAKS TRAINING so only uncomment to debug params
# for param_group_n, param_group in enumerate(params):
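A simplified sketch of the layer-wise learning-rate decay pattern used in this hunk (not Zoobot's exact configure_optimizers; the model name and the lr, lr_decay, and n_blocks values are illustrative assumptions): blocks are ordered from the top of the network down, the first n_blocks each get their own parameter group, and each deeper group's learning rate is scaled by a further factor of lr_decay.

import timm
import torch

encoder = timm.create_model("vit_small_patch16_224", pretrained=False, num_classes=0)

lr, lr_decay, n_blocks = 1e-4, 0.75, 3  # illustrative hyperparameters
tuneable_blocks = list(encoder.blocks)
tuneable_blocks.reverse()  # highest (closest to the head) block first

params = [
    {"params": block.parameters(), "lr": lr * (lr_decay ** i)}
    for i, block in enumerate(tuneable_blocks[:n_blocks])
]
optimizer = torch.optim.AdamW(params)

Blocks beyond n_blocks are simply left out of the optimizer here; the commented-out section above shows how their batchnorm parameters could optionally still be finetuned.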
