Skip to content

Commit

Permalink
Add sklearn sanity check to classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
inigoval committed May 8, 2024
1 parent 9dd6afc commit b907e2a
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions zoobot/pytorch/training/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from sklearn import linear_model
from sklearn.metrics import accuracy_score

from zoobot.pytorch.estimators import define_model
from zoobot.pytorch.training import losses, schedulers
from zoobot.shared import schemas
Expand Down Expand Up @@ -349,6 +350,7 @@ def run_step_through_model(self, batch):
x, y = batch
y_pred = self.forward(x)
loss = self.loss(y_pred, y) # must be subclasses and specified
loss.float()
return y, y_pred, loss

def step_to_dict(self, y, y_pred, loss):
Expand Down Expand Up @@ -484,6 +486,35 @@ def step_to_dict(self, y, y_pred, loss):
"class_predictions": y_class_preds,
}

# Sanity check embeddings with linear evaluation first
def on_train_start(self) -> None:
    """Sanity-check the encoder features with a linear probe before training.

    Embeds every batch from the datamodule's validation and training
    dataloaders with ``self.encoder``, fits an unpenalised logistic
    regression on the training embeddings, and logs its accuracy on both
    splits under ``finetuning/linear_eval/val`` and
    ``finetuning/linear_eval/train``.
    """
    datamodule = self.trainer.datamodule
    # Validation first, then train (same order as the loaders were read before).
    loader_factories = (
        ("val", datamodule.val_dataloader),
        ("train", datamodule.train_dataloader),
    )

    # Embed in eval mode so dropout is disabled and batch-norm layers use
    # (and do not update) their running statistics. Record every submodule's
    # flag so selectively-frozen layers are restored exactly as they were.
    previous_modes = [
        (module, module.training) for module in self.encoder.modules()
    ]
    self.encoder.eval()

    embeddings, labels = {}, {}
    try:
        with torch.no_grad():
            for split, make_loader in loader_factories:
                split_embeddings, split_labels = [], []
                for x, y in make_loader():
                    batch_embeddings = self.encoder(x.to(self.device))
                    # float() so half-precision outputs convert to NumPy cleanly
                    split_embeddings.append(batch_embeddings.float().cpu())
                    split_labels.append(torch.as_tensor(y).cpu())
                # One array per split instead of a Python list of row tensors.
                embeddings[split] = torch.cat(split_embeddings).numpy()
                labels[split] = torch.cat(split_labels).numpy()
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    # The probe is fit on the train split, so the train score is a training
    # accuracy; that is acceptable here, it is only a test of the features.
    model = linear_model.LogisticRegression(penalty=None, max_iter=200)
    model.fit(embeddings["train"], labels["train"])

    self.log(
        "finetuning/linear_eval/val",
        accuracy_score(labels["val"], model.predict(embeddings["val"])),
    )
    self.log(
        "finetuning/linear_eval/train",
        accuracy_score(labels["train"], model.predict(embeddings["train"])),
    )
    # A plain sklearn metric is used (not a torchmetric) because this is
    # computed in one go.
    # NOTE(review): nothing here aggregates across processes, so behaviour
    # under distributed training should be confirmed.

def on_train_batch_end(self, step_output, *args):
super().on_train_batch_end(step_output, *args)

Expand Down

0 comments on commit b907e2a

Please sign in to comment.