update examples for v2
mwalmsley committed Apr 16, 2024
1 parent 1669a7c commit 3bf3806
Showing 2 changed files with 38 additions and 23 deletions.
zoobot/pytorch/examples/finetuning/finetune_counts_full_tree.py (19 changes: 12 additions & 7 deletions)
@@ -10,6 +10,7 @@
from zoobot.pytorch.training import finetune
from zoobot.pytorch.predictions import predict_on_catalog
from zoobot.shared.schemas import gz_candels_ortho_schema
+from zoobot.shared.load_predictions import prediction_hdf5_to_summary_parquet

"""
Example for finetuning Zoobot on counts of volunteer responses throughout a complex decision tree (here, GZ CANDELS).
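
For context (not part of the commit): FinetuneableZoobotTree learns from raw vote counts, one catalog column per answer in the decision tree. A minimal sketch for checking which columns the GZ CANDELS schema expects; print schema.label_cols rather than hard-coding names:

from zoobot.shared.schemas import gz_candels_ortho_schema

schema = gz_candels_ortho_schema
print(schema.label_cols)  # one count column per decision-tree answer

# the training catalog needs file_loc (image path), id_str (unique id),
# and an integer vote count for every column in schema.label_cols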
@@ -67,12 +68,12 @@
    resize_after_crop=resize_after_crop
)

-checkpoint_loc = os.path.join(
-    # TODO replace with path to downloaded checkpoints. See Zoobot README for download links.
-    repo_dir, 'gz-decals-classifiers/results/benchmarks/pytorch/evo/uploaded/effnetb0_greyscale_224px.ckpt')  # decals hparams
-model = finetune.FinetuneableZoobotTree(checkpoint_loc=checkpoint_loc, schema=schema)
+model = finetune.FinetuneableZoobotTree(
+    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
+    schema=schema
+)

# TODO set this to wherever you'd like to save your results
save_dir = os.path.join(
    repo_dir, f'gz-decals-classifiers/results/finetune_{np.random.randint(1e8)}')

@@ -86,12 +87,16 @@
# now save predictions on test set to evaluate performance
datamodule_kwargs = {'batch_size': batch_size, 'resize_after_crop': resize_after_crop}
trainer_kwargs = {'devices': 1, 'accelerator': accelerator}

+hdf5_loc = os.path.join(save_dir, 'test_predictions.hdf5')
predict_on_catalog.predict(
    test_catalog,
    model,
    n_samples=1,
    label_cols=schema.label_cols,
-    save_loc=os.path.join(save_dir, 'test_predictions.csv'),
+    save_loc=hdf5_loc,
    datamodule_kwargs=datamodule_kwargs,
    trainer_kwargs=trainer_kwargs
)

+prediction_hdf5_to_summary_parquet(hdf5_loc=hdf5_loc, save_loc=hdf5_loc.replace('.hdf5', 'summary.parquet'), schema=schema)
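
For context (not part of the commit): the summary parquet written above can then be inspected with pandas. A minimal sketch, assuming pandas is installed; the exact summary columns depend on the schema, so print them rather than relying on memory:

import pandas as pd

summary_loc = hdf5_loc.replace('.hdf5', 'summary.parquet')  # matches the save_loc above
summary = pd.read_parquet(summary_loc)
print(summary.columns)  # per-answer summaries derived from the raw predictions
print(summary.head())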
zoobot/pytorch/examples/representations/get_representations.py (42 changes: 26 additions & 16 deletions)
@@ -1,32 +1,45 @@
import logging
import os

+import timm
+
from galaxy_datasets import demo_rings

+from zoobot.pytorch.training import finetune, representations
from zoobot.pytorch.estimators import define_model
from zoobot.pytorch.predictions import predict_on_catalog
-from zoobot.pytorch.training import finetune
from zoobot.shared import load_predictions, schemas


-def main(catalog, checkpoint_loc, save_dir):
+def main(catalog, save_dir, name="hf_hub:mwalmsley/zoobot-encoder-convnext_nano"):

    assert all([os.path.isfile(x) for x in catalog['file_loc']])

    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

-    # can load from either ZoobotTree checkpoint (if trained from scratch)
-    encoder = define_model.ZoobotTree.load_from_checkpoint(checkpoint_loc).encoder
-    # or FinetuneableZoobotTree (if finetuned)
-    # currently, FinetuneableZoobotTree checkpoints should be loaded as ZoobotTree with the args below
-    # this is a bit awkward and I'm working on a clearer method - but it does work.
-    # encoder = define_model.ZoobotTree.load_from_checkpoint(checkpoint_loc, output_dim=TODO, question_index_groups=[]).encoder
+    # load the encoder
+
+    # OPTION 1
+    # Load a pretrained model from HuggingFace, with no finetuning, only as published
+    model = representations.ZoobotEncoder.load_from_name(name)
+    # or equivalently (the above is just a wrapper for these two lines below)
+    # encoder = timm.create_model(model_name=name, pretrained=True)
+    # model = representations.ZoobotEncoder(encoder=encoder)

-    # convert to simple pytorch lightning model
-    model = representations.ZoobotEncoder(encoder=encoder, pyramid=False)
-    label_cols = [f'feat_{n}' for n in range(1280)]
+    """
+    # OPTION 2
+    # Load a model that has been finetuned on your own data
+    # (...do your usual finetuning..., or load a finetuned model with finetune.FinetuneableZoobotClassifier(checkpoint_loc=....ckpt)
+    encoder = finetuned_model.encoder
+    # and then convert to simple pytorch lightning model. You can use any pytorch model here.
+    model = representations.ZoobotEncoder(encoder=encoder)
+    """
+
+    encoder_dim = define_model.get_encoder_dim(model.encoder)
+    label_cols = [f'feat_{n}' for n in range(encoder_dim)]
    save_loc = os.path.join(save_dir, 'representations.hdf5')

    accelerator = 'cpu'  # or 'gpu' if available
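
For context (not part of the commit): the rest of main, elided from this hunk, ultimately writes the representations to save_loc. A hedged sketch of one way to do that, reusing the predict_on_catalog.predict signature shown in the finetuning example above; the kwargs here are illustrative assumptions, not the file's actual values:

predict_on_catalog.predict(
    catalog,
    model,
    n_samples=1,  # a representation needs only one forward pass
    label_cols=label_cols,  # feat_0 ... feat_{encoder_dim - 1}, as built above
    save_loc=save_loc,  # representations.hdf5, as built above
    datamodule_kwargs={'batch_size': 32},  # illustrative
    trainer_kwargs={'devices': 1, 'accelerator': accelerator}
)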
@@ -52,20 +65,17 @@ def main(catalog, checkpoint_loc, save_dir):

logging.basicConfig(level=logging.INFO)

# load the gz evo model for representations
checkpoint_loc = '/home/walml/repos/gz-decals-classifiers/results/benchmarks/pytorch/evo/evo_py_gr_11941/checkpoints/epoch=73-step=42698.ckpt'

# use this demo dataset
# TODO change this to wherever you'd like, it will auto-download
data_dir = '/home/walml/repos/galaxy-datasets/roots/demo_rings'
data_dir = '/Users/user/repos/galaxy-datasets/roots/demo_rings'
catalog, _ = demo_rings(root=data_dir, download=True, train=True)
print(catalog.head())
# zoobot expects id_str and file_loc columns, so add these if needed

# save the representations here
# TODO change this to wherever you'd like
save_dir = os.path.join('/home/walml/repos/zoobot/results/pytorch/representations/example')
save_dir = os.path.join('/Users/user/repos/zoobot/results/pytorch/representations/example')

representations_loc = main(catalog, checkpoint_loc, save_dir)
representations_loc = main(catalog, save_dir)
rep_df = load_predictions.single_forward_pass_hdf5s_to_df(representations_loc)
print(rep_df)
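
For context (not part of the commit): once the representations are loaded into rep_df, the feature columns can feed any downstream analysis. A minimal sketch of nearest-neighbour similarity search with scikit-learn, assuming scikit-learn is installed and that rep_df holds id_str plus feat_N columns as above:

from sklearn.neighbors import NearestNeighbors

feature_cols = [c for c in rep_df.columns if c.startswith('feat_')]
features = rep_df[feature_cols].values

# fit on every galaxy, then find the 5 most similar to the first galaxy
neighbours = NearestNeighbors(n_neighbors=5).fit(features)
distances, indices = neighbours.kneighbors(features[:1])
print(rep_df['id_str'].iloc[indices[0]])  # ids of the closest matches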
