From 9dd6afc2e789dcc2dcc71470df1224d1d67d7263 Mon Sep 17 00:00:00 2001
From: Inigo Val Slijepcevic
Date: Wed, 8 May 2024 15:56:49 +0100
Subject: [PATCH] Add vision transformer to `finetune.py`

---
 zoobot/pytorch/training/finetune.py | 40 +++++++++++++++++++----------
 1 file changed, 27 insertions(+), 13 deletions(-)

diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index ae27eac8..3c356e80 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -1,22 +1,22 @@
 import logging
 import os
-from typing import Any, Union, Optional
 import warnings
 from functools import partial
+from typing import Any, Optional, Union
 
 import numpy as np
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks.early_stopping import EarlyStopping
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-from pytorch_lightning.callbacks import LearningRateMonitor
-
+import timm
 import torch
 import torch.nn.functional as F
 import torchmetrics as tm
-import timm
-
-from zoobot.pytorch.training import losses, schedulers
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from sklearn import linear_model
+from sklearn.metrics import accuracy_score
 from zoobot.pytorch.estimators import define_model
+from zoobot.pytorch.training import losses, schedulers
 from zoobot.shared import schemas
 
 # https://discuss.pytorch.org/t/how-to-freeze-bn-layers-while-training-the-rest-of-network-mean-and-var-wont-freeze/89736/7
@@ -131,6 +131,7 @@ def __init__(
                 zoobot_checkpoint_loc
             )  # extracts the timm encoder
             self.encoder_dim = self.encoder.num_features
+
         else:
             assert (
                 zoobot_checkpoint_loc is None
@@ -140,7 +141,10 @@ def __init__(
             ), "Must pass either checkpoint to load or encoder to use"
             self.encoder = encoder
            # work out encoder dim 'manually'
-            self.encoder_dim = define_model.get_encoder_dim(self.encoder)
+            if isinstance(self.encoder, timm.models.VisionTransformer):
+                self.encoder_dim = self.encoder.embed_dim
+            else:
+                self.encoder_dim = define_model.get_encoder_dim(self.encoder)
 
         self.n_blocks = n_blocks
 
@@ -171,6 +175,12 @@ def __init__(
         self.prog_bar = prog_bar
         self.visualize_images = visualize_images
 
+        # Remove head if it exists
+        if hasattr(self.encoder, "head"):
+            # If the encoder has a 'head' attribute, replace it with Identity()
+            self.encoder.head = torch.nn.Identity()
+            print("Replaced encoder.head with Identity()")
+
     def configure_optimizers(self):
         """
         This controls which parameters get optimized
@@ -243,19 +253,23 @@ def configure_optimizers(self):
 
         # take n blocks, ordered highest layer to lowest layer
         tuneable_blocks.reverse()
-        logging.info("possible blocks to tune: {}".format(len(tuneable_blocks)))
+        logging.info(f"possible blocks to tune: {len(tuneable_blocks)}")
+
         # will finetune all params in first N
-        logging.info("blocks that will be tuned: {}".format(self.n_blocks))
+        logging.info(f"blocks that will be tuned: {self.n_blocks}")
         blocks_to_tune = tuneable_blocks[: self.n_blocks]
+
         # optionally, can finetune batchnorm params in remaining layers
         remaining_blocks = tuneable_blocks[self.n_blocks :]
-        logging.info("Remaining blocks: {}".format(len(remaining_blocks)))
+        logging.info(f"Remaining blocks: {len(remaining_blocks)}")
+
         assert not any(
             [block in remaining_blocks for block in blocks_to_tune]
         ), "Some blocks are in both tuneable and remaining"
 
         # Append parameters of layers for finetuning along with decayed learning rate
         for i, block in enumerate(blocks_to_tune):  # _ is the block name e.g. '3'
+            logging.info(f"Adding block {block} with lr {lr * (self.lr_decay**i)}")
             params.append({"params": block.parameters(), "lr": lr * (self.lr_decay**i)})
 
         # optionally, for the remaining layers (not otherwise finetuned) you can choose to still FT the batchnorm layers
@@ -268,7 +282,7 @@ def configure_optimizers(self):
             #     "lr": lr * (self.lr_decay**i)
             # })
 
-        logging.info("param groups: {}".format(len(params)))
+        logging.info(f"param groups: {len(params)}")
 
         # because it iterates through the generators, THIS BREAKS TRAINING so only uncomment to debug params
         # for param_group_n, param_group in enumerate(params):
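
Note for reviewers (not part of the patch): below is a minimal standalone sketch of the behaviour the new ViT handling adds, using a bare timm model in place of the zoobot classes. The model name "vit_tiny_patch16_224" and the use of encoder.num_features as a stand-in for define_model.get_encoder_dim are assumptions for illustration only.

# Sketch only: mirrors the patched encoder_dim detection and head removal.
import timm
import torch

encoder = timm.create_model("vit_tiny_patch16_224", pretrained=False)

# ViTs expose their feature width as embed_dim, so the patch branches on encoder type
if isinstance(encoder, timm.models.VisionTransformer):
    encoder_dim = encoder.embed_dim
else:
    encoder_dim = encoder.num_features  # stand-in for define_model.get_encoder_dim

# Replace any classification head with Identity so the encoder returns pooled features
if hasattr(encoder, "head"):
    encoder.head = torch.nn.Identity()

features = encoder(torch.randn(1, 3, 224, 224))
assert features.shape == (1, encoder_dim)  # pooled features of width encoder_dim (192 here)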