diff --git a/README.md b/README.md
index f74e7e1b..f919f978 100755
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ Zoobot is trained using millions of answers by Galaxy Zoo volunteers. This code
## Installation
-You can retrain Zoobot in the cloud with a free GPU using this [Google Colab notebook](https://colab.research.google.com/drive/17bb_KbA2J6yrIm4p4Ue_lEBHMNC1I9Jd?usp=sharing). To install locally, keep reading.
+You can retrain Zoobot in the cloud with a free GPU using this [Google Colab notebook](https://colab.research.google.com/drive/1A_-M3Sz5maQmyfW2A7rEu-g_Zi0RMGz5?usp=sharing). To install locally, keep reading.
Download the code using git:
@@ -49,7 +49,7 @@ I share my install steps [here](#install_cuda). GPUs are optional - Zoobot will
## Quickstart
-The [Colab notebook](https://colab.research.google.com/drive/17bb_KbA2J6yrIm4p4Ue_lEBHMNC1I9Jd?usp=sharing) is the quickest way to get started. Alternatively, the minimal example below illustrates how Zoobot works.
+The [Colab notebook](https://colab.research.google.com/drive/1A_-M3Sz5maQmyfW2A7rEu-g_Zi0RMGz5?usp=sharing) is the quickest way to get started. Alternatively, the minimal example below illustrates how Zoobot works.
Let's say you want to find ringed galaxies and you have a small labelled dataset of 500 ringed or not-ringed galaxies. You can retrain Zoobot to find rings like so:
@@ -97,7 +97,7 @@ Zoobot includes many guides and working examples - see the [Getting Started](#ge
## Getting Started
-I suggest starting with the [Colab notebook](https://colab.research.google.com/drive/17bb_KbA2J6yrIm4p4Ue_lEBHMNC1I9Jd?usp=sharing) or the worked examples below, which you can copy and adapt.
+I suggest starting with the [Colab notebook](https://colab.research.google.com/drive/1A_-M3Sz5maQmyfW2A7rEu-g_Zi0RMGz5?usp=sharing) or the worked examples below, which you can copy and adapt.
For context and explanation, see the [documentation](https://zoobot.readthedocs.io/).
diff --git a/docs/autodoc/api.rst b/docs/autodoc/api.rst
deleted file mode 100755
index e60ef207..00000000
--- a/docs/autodoc/api.rst
+++ /dev/null
@@ -1,15 +0,0 @@
-
-API
-====
-
-We encourage you to explore the code directly.
-There are many comments (and commented-out examples) which might be helpful.
-However, for convenience, you can check the docstrings directly here.
-
-
-.. toctree::
- :maxdepth: 2
-
- pytorch
- tensorflow
- shared
diff --git a/docs/autodoc/pytorch/training/finetune.rst b/docs/autodoc/pytorch/training/finetune.rst
index a23e767c..bd8a261c 100644
--- a/docs/autodoc/pytorch/training/finetune.rst
+++ b/docs/autodoc/pytorch/training/finetune.rst
@@ -7,6 +7,7 @@ See the `README `_ for a minimal example.
See zoobot/pytorch/examples for more worked examples.
.. autoclass:: zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract
+ :members: configure_optimizers
|
@@ -14,12 +15,27 @@ See zoobot/pytorch/examples for more worked examples.
|
+.. autoclass:: zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor
+
+|
+
.. autoclass:: zoobot.pytorch.training.finetune.FinetuneableZoobotTree
|
+.. autoclass:: zoobot.pytorch.training.finetune.LinearHead
+ :members: forward
+
+|
+
+.. autofunction:: zoobot.pytorch.training.finetune.load_pretrained_zoobot
+
+|
+
.. autofunction:: zoobot.pytorch.training.finetune.get_trainer
|
-.. autofunction:: zoobot.pytorch.training.finetune.load_pretrained_encoder
+.. autofunction:: zoobot.pytorch.training.finetune.download_from_name
+
+|
\ No newline at end of file
diff --git a/docs/autodoc/shared/schemas.rst b/docs/autodoc/shared/schemas.rst
index 7df8e0a9..afafe8a1 100755
--- a/docs/autodoc/shared/schemas.rst
+++ b/docs/autodoc/shared/schemas.rst
@@ -26,6 +26,5 @@ See :ref:`training_on_vote_counts` for full explanation.
|
.. autoclass:: zoobot.shared.schemas.Schema
- :members:
|
\ No newline at end of file
diff --git a/docs/autodoc/tensorflow.rst b/docs/autodoc/tensorflow.rst
deleted file mode 100644
index c36b0943..00000000
--- a/docs/autodoc/tensorflow.rst
+++ /dev/null
@@ -1,27 +0,0 @@
-tensorflow
-=============
-
-estimators
--------------
-
-.. toctree::
-
- tensorflow/estimators/define_model
- tensorflow/estimators/efficientnet_custom
-
-training
--------------
-
-.. toctree::
-
- tensorflow/training/finetune
- tensorflow/training/train_with_keras
- tensorflow/training/training_config
- tensorflow/training/losses
-
-predictions
--------------
-
-.. toctree::
-
- tensorflow/predictions/predict_on_dataset
diff --git a/docs/autodoc/tensorflow/estimators/define_model.rst b/docs/autodoc/tensorflow/estimators/define_model.rst
deleted file mode 100755
index 3bbe02ed..00000000
--- a/docs/autodoc/tensorflow/estimators/define_model.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-define_model
-===================
-
-This module contains functions for defining an EfficientNet model (:meth:`zoobot.estimators.define_model.get_model`),
-with or without the GZ DECaLS head, and optionally to load the weights of a pretrained model.
-
-Models are defined using functions in ``efficientnet_standard`` and ``efficientnet_custom``.
-
-.. autofunction:: zoobot.tensorflow.estimators.define_model.get_model
-
-|
-
-.. autofunction:: zoobot.tensorflow.estimators.define_model.load_weights
-
-|
-
-.. autofunction:: zoobot.tensorflow.estimators.define_model.load_model
diff --git a/docs/autodoc/tensorflow/estimators/efficientnet_custom.rst b/docs/autodoc/tensorflow/estimators/efficientnet_custom.rst
deleted file mode 100755
index ba656134..00000000
--- a/docs/autodoc/tensorflow/estimators/efficientnet_custom.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-efficientnet_custom
-===================
-
-.. autofunction:: zoobot.tensorflow.estimators.efficientnet_custom.define_headless_efficientnet
-
-|
-
-.. autofunction:: zoobot.tensorflow.estimators.efficientnet_custom.custom_top_dirichlet
-
diff --git a/docs/autodoc/tensorflow/predictions/predict_on_dataset.rst b/docs/autodoc/tensorflow/predictions/predict_on_dataset.rst
deleted file mode 100755
index eda2c76a..00000000
--- a/docs/autodoc/tensorflow/predictions/predict_on_dataset.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-predict_on_dataset
-===================
-
-This module includes utilities to make predictions with a trained model on a list of images.
-
-.. autofunction:: zoobot.tensorflow.predictions.predict_on_dataset.predict
-
-|
-
-.. autofunction:: zoobot.tensorflow.predictions.predict_on_dataset.paths_in_folder
diff --git a/docs/autodoc/tensorflow/training/finetune.rst b/docs/autodoc/tensorflow/training/finetune.rst
deleted file mode 100644
index 6d0ceee3..00000000
--- a/docs/autodoc/tensorflow/training/finetune.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-.. _tensorflow_finetune:
-
-finetune
-===================
-
-Functions to load and adapt a trained (TensorFlow) Zoobot model to a new problem.
-
-:.. warning:: PyTorch is recommended for new users. See :ref:`pytorch_or_tensorflow` for more.
-
-
-.. autofunction:: zoobot.tensorflow.training.finetune.run_finetuning
diff --git a/docs/autodoc/tensorflow/training/losses.rst b/docs/autodoc/tensorflow/training/losses.rst
deleted file mode 100755
index e744c44c..00000000
--- a/docs/autodoc/tensorflow/training/losses.rst
+++ /dev/null
@@ -1,19 +0,0 @@
-losses
-===================
-
-This module contains functions for calculating the custom Dirichlet-Multinomial loss used for Galaxy Zoo decision trees.
-
-
-.. autofunction:: zoobot.tensorflow.training.losses.get_multiquestion_loss
-
-|
-
-.. autofunction:: zoobot.tensorflow.training.losses.calculate_multiquestion_loss
-
-|
-
-.. autofunction:: zoobot.tensorflow.training.losses.dirichlet_loss
-
-|
-
-.. autofunction:: zoobot.tensorflow.training.losses.get_dirichlet_neg_log_prob
diff --git a/docs/autodoc/tensorflow/training/train_with_keras.rst b/docs/autodoc/tensorflow/training/train_with_keras.rst
deleted file mode 100644
index 2c7026b2..00000000
--- a/docs/autodoc/tensorflow/training/train_with_keras.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-train_with_keras
-===================
-
-This is the interface to train new tensorflow models from scratch.
-
-.. autofunction:: zoobot.tensorflow.training.train_with_keras.train
diff --git a/docs/autodoc/tensorflow/training/training_config.rst b/docs/autodoc/tensorflow/training/training_config.rst
deleted file mode 100755
index e12d4b69..00000000
--- a/docs/autodoc/tensorflow/training/training_config.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-.. _training_config:
-
-training_config
-===================
-
-This module creates the :class:`Trainer` class for training a Zoobot model (itself a tf.keras.Model).
-Implements common features training like early stopping and tensorboard logging.
-
-Follows the same idea as the PyTorch Lightning Trainer object.
-
-.. autoclass:: zoobot.tensorflow.training.training_config.Trainer
- :members:
diff --git a/docs/guides/advanced_finetuning.rst b/docs/guides/advanced_finetuning.rst
index 59a59aff..6554f69c 100644
--- a/docs/guides/advanced_finetuning.rst
+++ b/docs/guides/advanced_finetuning.rst
@@ -4,48 +4,48 @@ Advanced Finetuning
=====================
-Zoobot includes the :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier` and :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotTree`
-classes to help you finetune Zoobot on classification or decision tree problems, respectively.
-But what about other problems, like regression or object detection?
+Zoobot includes the :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier`, :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor`, and :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotTree`
+classes to help you finetune Zoobot on classification, regression, or decision tree problems, respectively.
+But what about other problems, like object detection?
Here's how to integrate pretrained Zoobot models into your own code.
Using Zoobot's Encoder Directly
------------------------------------
-To get Zoobot's encoder, load the model and access the .encoder attribute:
+To get Zoobot's encoder, load any FinetuneableZoobot class and access its ``.encoder`` attribute:
.. code-block:: python
- model = ZoobotTree.load_from_checkpoint(pretrained_checkpoint_loc)
+    from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier
+
+    model = FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', num_classes=2)
encoder = model.encoder
- model = FinetuneableZoobotClassifier.load_from_checkpoint(finetuned_checkpoint_loc)
- encoder = model.encoder
+or, because Zoobot encoders are ``timm`` models, you can use ``timm`` directly:
+
+.. code-block:: python
+
+ import timm
+
+ encoder = timm.create_model('hf_hub:mwalmsley/zoobot-encoder-convnext_nano', pretrained=True, num_classes=0)
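+
+You can then embed a batch of images with a plain forward pass (a minimal sketch; the batch shape and values are illustrative):
+
+.. code-block:: python
+
+    import torch
+
+    images = torch.rand(16, 3, 224, 224)  # 16 RGB images with values in [0, 1]
+    with torch.no_grad():
+        representations = encoder(images)  # shape (16, encoder output dimension)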
- # for ZoobotTree, there's also a utility function to do this in one line
- encoder = finetune.load_pretrained_encoder(pretrained_checkpoint_loc)
-:class:`zoobot.pytorch.estimators.define_model.ZoobotTree`, :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier` and :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotTree`
-all have ``.encoder`` and ``.head`` attributes. These are the plain PyTorch (Sequential) models used for encoding or task predictions.
-The Zoobot classes simply wrap these with instructions for training, logging, checkpointing, and so on.
+You can use it like any other `timm` model. For example, we did this to `add contrastive learning `_. Good luck!
-You can use the encoder separately like any PyTorch Sequential for any machine learning task. We did this to `add contrastive learning `_. Go nuts.
Subclassing FinetuneableZoobotAbstract
---------------------------------------
-If you'd like to finetune Zoobot on a new task that isn't classification or vote counts,
+If you'd like to finetune Zoobot on a new task that isn't classification, regression, or vote counts,
you could instead subclass :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract`.
-This is less general but avoids having to write out your own finetuning training code in e.g. PyTorch Lightning.
+This lets you use our finetuning code with your own head and loss functions.
-For example, to make a regression version:
+Imagine there wasn't a regression version and you wanted to finetune Zoobot on a regression task. You could do:
.. code-block:: python
- class FinetuneableZoobotRegression(FinetuneableZoobotAbstract):
+ class FinetuneableZoobotCustomRegression(FinetuneableZoobotAbstract):
def __init__(
self,
@@ -56,12 +56,12 @@ For example, to make a regression version:
super().__init__(**super_kwargs)
self.foo = foo
- self.loss = torch.nn.MSELoss()
- self.head = torch.nn.Sequential(...)
+ self.loss = torch.nn.SomeCrazyLoss()
+ self.head = torch.nn.Sequential(my_crazy_head)
# see zoobot/pytorch/training/finetune.py for more examples and all methods required
-You can then finetune this new class just as with e.g. FinetuneableZoobotClassifier.
+You can then finetune this new class just as with e.g. :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotRegressor`.
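+
+For example, a minimal sketch (``FinetuneableZoobotCustomRegression`` and ``foo`` are the hypothetical names from the snippet above, and ``datamodule`` is assumed to be any LightningDataModule yielding (image, target) batches):
+
+.. code-block:: python
+
+    from zoobot.pytorch.training import finetune
+
+    model = FinetuneableZoobotCustomRegression(
+        name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',  # loading arguments are handled by FinetuneableZoobotAbstract
+        foo=42,
+        learning_rate=1e-4
+    )
+    save_dir = 'finetuned_regression'  # illustrative path for logs and checkpoints
+    trainer = finetune.get_trainer(save_dir)
+    trainer.fit(model, datamodule)
+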
Extracting Frozen Representations
@@ -71,27 +71,21 @@ Once you've finetuned to your survey, or if you're using a pretrained survey, (S
the representations can be stored as frozen vectors and used as features.
We use this at Galaxy Zoo to power our upcoming similary search and anomaly-finding tools.
-As above, we can get Zoobot's encoder from the .encoder attribute:
-
-.. code-block:: python
-
- # can load from either ZoobotTree (if trained from scratch)
- # or FinetuneableZoobotTree (if finetuned)
- encoder = finetune.FinetuneableZoobotTree.load_from_checkpoint(checkpoint_loc).encoder
-
-``encoder`` is a PyTorch Sequential object, so we could use ``encoder.predict()`` to calculate our representations.
+As above, we can get Zoobot's encoder from the ``.encoder`` attribute. We could then call ``encoder()`` directly to calculate our representations.
But then we'd have to deal with batching, looping, etc.
To avoid this boilerplate, Zoobot includes a PyTorch Lightning class that lets you pass ``encoder`` to the same :func:`zoobot.pytorch.predictions.predict_on_catalog.predict`
utility function used for making predictions with a full Zoobot model.
.. code-block:: python
+ from zoobot.pytorch.training import representations
+
# convert to simple pytorch lightning model
- model = representations.ZoobotEncoder(encoder=encoder, pyramid=False)
+    lightning_encoder = representations.ZoobotEncoder.load_from_name('hf_hub:mwalmsley/zoobot-encoder-convnext_nano')
predict_on_catalog.predict(
catalog,
- model,
+ lightning_encoder,
n_samples=1,
label_cols=label_cols,
save_loc=save_loc,
@@ -101,9 +95,9 @@ utility function used for making predictions with a full Zoobot model.
See `zoobot/pytorch/examples/representations `_ for a full working example.
-We plan on adding precalculated representations for all our DESI galaxies - but we haven't done it yet. Sorry.
-Please raise an issue if you really need these.
+We are sharing precalculated representations for all our DESI galaxies, and soon for HSC as well.
+Check the data notes at :doc:`/data_notes`.
-The representations are typically quite high-dimensional (1280 for EfficientNetB0) and therefore highly redundant.
+The representations are typically quite high-dimensional (e.g. 1280 for EfficientNetB0) and therefore highly redundant.
We suggest using PCA to compress them down to a more reasonable dimension (e.g. 40) while preserving most of the information.
This was our approach in the `Practical Morphology Tools paper `_.
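+
+For example, a minimal sketch with scikit-learn (assuming the representations are in a pandas DataFrame ``rep_df`` whose feature columns are named ``feat_0``, ``feat_1``, ...; these names are illustrative):
+
+.. code-block:: python
+
+    from sklearn.decomposition import PCA
+
+    feature_cols = [col for col in rep_df.columns if col.startswith('feat_')]
+
+    pca = PCA(n_components=40)  # compress e.g. 1280-dim representations to 40 dims
+    compressed = pca.fit_transform(rep_df[feature_cols].values)  # shape (n_galaxies, 40)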
diff --git a/docs/guides/finetuning.rst b/docs/guides/finetuning.rst
index 1ab59003..d46eee1b 100755
--- a/docs/guides/finetuning.rst
+++ b/docs/guides/finetuning.rst
@@ -30,7 +30,7 @@ Examples
Zoobot includes many working examples of finetuning:
-- `Google Colab notebook `__ (for binary classification in the cloud)
+- `Google Colab notebook `__ (for binary classification in the cloud)
- `finetune_binary_classification.py `__ (script version of the Colab notebook)
- `finetune_counts_full_tree.py `__ (for finetuning on a complicated GZ-style decision tree)
diff --git a/docs/guides/guides.rst b/docs/guides/guides.rst
index e5ab3399..1de9e932 100755
--- a/docs/guides/guides.rst
+++ b/docs/guides/guides.rst
@@ -9,10 +9,8 @@ Below are some practical guides for tasks that we hope Zoobot will be helpful fo
/guides/finetuning
/guides/advanced_finetuning
- /guides/training_on_vote_counts
/guides/how_the_code_fits_together
- /guides/pytorch_or_tensorflow
-
-If you'd prefer worked examples, you can find those under `zoobot/pytorch/examples `_ and `zoobot/tensorflow/examples `_.
+ /guides/loading_data
+ /guides/training_on_vote_counts
-There's also this `Colab notebook `_ demonstrating finetuning which you can run in the cloud (with free access to a powerful GPU, courtesy of Google Research)
+If you'd prefer worked examples, you can find those under `zoobot/pytorch/examples `_.
diff --git a/docs/guides/how_the_code_fits_together.rst b/docs/guides/how_the_code_fits_together.rst
index 6bb8109e..9c816ad5 100644
--- a/docs/guides/how_the_code_fits_together.rst
+++ b/docs/guides/how_the_code_fits_together.rst
@@ -6,37 +6,46 @@ How the Code Fits Together
The Zoobot package has many classes and methods.
This guide aims to be a map summarising how they fit together.
-.. note:: For simplicity, we will only consider the PyTorch version (see :ref:`pytorch_or_tensorflow`).
-
-Defining PyTorch Models
+The Map
-------------------------
-The deep learning part is the simplest piece.
-``define_model.py`` has functions to that define pure PyTorch ``nn.Modules`` (a.k.a. models).
+The Zoobot package has two roles:
+1. **Finetuning**: ``pytorch/training/finetune.py`` is the heart of the package. You will use these classes to load pretrained models and finetune them on new data.
+2. **Training from Scratch**: ``pytorch/estimators/define_model.py`` and ``pytorch/training/train_with_pytorch_lightning.py`` create and train the Zoobot models from scratch. These are *not required* for finetuning and will eventually be migrated out.
-Encoders (a.k.a. models that take an image and compress it to a representation vector) are defined using the third party library ``timm``.
-Specifically, ``timm.create_model(architecture_name)`` is used to get the EfficientNet, ResNet, ViT, etc. architectures used to encode our galaxy images.
-This is helpful because defining complicated architectures becomes someone else's job (thanks, Ross Wightman!)
+Let's zoom in on the finetuning part.
-Heads (a.k.a. models that take a representation vector and make a prediction) are defined using ``torch.nn.Sequential``.
-The function :func:`zoobot.pytorch.estimators.define_model.get_pytorch_dirichlet_head`, for example, returns the custom head used to predict vote counts (see :ref:`training_on_vote_counts`).
+Finetuning with Zoobot Classes
+--------------------------------
-The encoders and heads in ``define_model.py`` are used for both training from scratch and finetuning
-Training with PyTorch Lightning
---------------------------------
+There are three Zoobot classes for finetuning:
+1. :class:`FinetuneableZoobotClassifier ` for classification tasks (including multi-class).
+2. :class:`FinetuneableZoobotRegressor ` for regression tasks (including on a unit interval e.g. a fraction).
+3. :class:`FinetuneableZoobotTree ` for training on a tree of labels (e.g. Galaxy Zoo vote counts).
+
+Each user-facing class is actually a subclass of a non-user-facing abstract class, :class:`FinetuneableZoobotAbstract <zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract>`.
+:class:`FinetuneableZoobotAbstract <zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract>` specifies how to finetune a general PyTorch model, and the user-facing classes inherit this behaviour.
+
+:class:`FinetuneableZoobotAbstract <zoobot.pytorch.training.finetune.FinetuneableZoobotAbstract>` controls the core finetuning process: loading a pretrained model, accepting the arguments that control finetuning, and running the finetuning itself.
+Each user-facing class adds features specific to its task. For example, :class:`FinetuneableZoobotClassifier <zoobot.pytorch.training.finetune.FinetuneableZoobotClassifier>` adds arguments like ``num_classes``.
+It also specifies an appropriate head and loss function.
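+
+For example, a minimal sketch (the checkpoint name and hyperparameter values are illustrative):
+
+.. code-block:: python
+
+    from zoobot.pytorch.training import finetune
+
+    model = finetune.FinetuneableZoobotClassifier(
+        name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',  # pretrained encoder, loaded by FinetuneableZoobotAbstract
+        num_classes=2,  # classifier-specific argument
+        n_blocks=0  # shared argument: how many encoder blocks to also finetune
+    )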
+
+
+
+Finetuning with PyTorch Lightning
+-----------------------------------
-PyTorch requires a lot of boilerplate code to train models, especially at scale (e.g. multi-node, multi-GPU).
-We use PyTorch Lightning, a third party wrapper API, to make this boilerplate code someone else's job as well.
-The core Zoobot classes you'll use - :class:`ZoobotTree `, :class:`FinetuneableZoobotClassifier ` and :class:`FinetuneableZoobotTree ` -
-are all "LightningModule" classes.
+The finetuning classes above are all ``LightningModule`` classes.
These classes have (custom) methods like ``training_step``, ``validation_step``, etc., which specify what should happen at each training stage.
-:class:`FinetuneableZoobotClassifier ` and :class:`FinetuneableZoobotTree `
-are actually subclasses of a non-user-facing abstract class, :class:`FinetuneableZoobotAbstract `.
-:class:`FinetuneableZoobotAbstract ` has specifying how to finetune a general PyTorch model,
-which `FinetuneableZoobotClassifier ` and :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotTree` inherit.
+
+Zoobot is written in PyTorch, a popular deep learning library for Python.
+PyTorch requires a lot of boilerplate code to train models, especially at scale (e.g. multi-node, multi-GPU).
+We use PyTorch Lightning, a third party wrapper API, to make this boilerplate code someone else's job.
+
:class:`ZoobotTree ` is similar to :class:`FinetuneableZoobotAbstract ` but has methods for training from scratch.
@@ -66,28 +75,17 @@ Slightly confusingly, Lightning's ``Trainer`` can also be used to make predictio
and that's how we make predictions with :func:`zoobot.pytorch.predictions.predict_on_catalog.predict`.
-Loading Data
---------------------------
-
-You might notice ``datamodule`` in the examples above.
-Zoobot often includes code like:
-
-.. code-block:: python
+As you can see, there are quite a few layers (pun intended) to training Zoobot models. But we hope this setup is both simple to use and easy to extend, whichever (PyTorch) frameworks you're using.
- from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
- datamodule = GalaxyDataModule(
- train_catalog=train_catalog,
- val_catalog=val_catalog,
- test_catalog=test_catalog,
- batch_size=batch_size,
- # ...
- )
+.. The deep learning part is the simplest piece.
+.. ``define_model.py`` has functions that define pure PyTorch ``nn.Modules`` (a.k.a. models).
-Note the import - Zoobot actually doesn't have any code for loading data!
-That's in the separate repository `mwalmsley/galaxy-datasets `.
+.. Encoders (a.k.a. models that take an image and compress it to a representation vector) are defined using the third party library ``timm``.
+.. Specifically, ``timm.create_model(architecture_name)`` is used to get the EfficientNet, ResNet, ViT, etc. architectures used to encode our galaxy images.
+.. This is helpful because defining complicated architectures becomes someone else's job (thanks, Ross Wightman!)
-``galaxy-datasets`` has custom code to turn catalogs of galaxies into the ``LightningDataModule``s that Lightning `expects https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html<>`_.
-These ``LightningDataModule``s themselves have attributes like ``.train_dataloader()`` and ``.predict_dataloader()`` that Lightning's ``Trainer`` object uses to demand data when training, making predictions, and so forth.
+.. Heads (a.k.a. models that take a representation vector and make a prediction) are defined using ``torch.nn.Sequential``.
+.. The function :func:`zoobot.pytorch.estimators.define_model.get_pytorch_dirichlet_head`, for example, returns the custom head used to predict vote counts (see :ref:`training_on_vote_counts`).
-As you can see, there's quite a few layers (pun intended) to training Zoobot models. But we hope this setup is both simple to use and easy to extend, whichever (PyTorch) frameworks you're using.
+.. The encoders and heads in ``define_model.py`` are used for both training from scratch and finetuning
diff --git a/docs/guides/loading_data.rst b/docs/guides/loading_data.rst
new file mode 100644
index 00000000..c6c74857
--- /dev/null
+++ b/docs/guides/loading_data.rst
@@ -0,0 +1,52 @@
+
+Loading Data
+--------------------------
+
+Using GalaxyDataModule
+=========================
+
+Zoobot often includes code like:
+
+.. code-block:: python
+
+ from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
+
+ datamodule = GalaxyDataModule(
+ train_catalog=train_catalog,
+ val_catalog=val_catalog,
+ test_catalog=test_catalog,
+ batch_size=batch_size,
+ label_cols=['is_cool_galaxy']
+ # ...
+ )
+
+Note the import - Zoobot actually doesn't have any code for loading data!
+That's in the separate repository `mwalmsley/galaxy-datasets `_.
+
+``galaxy-datasets`` has custom code to turn catalogs of galaxies into the ``LightningDataModule`` that Lightning `expects `_.
+Each ``LightningDataModule`` has attributes like ``.train_dataloader()`` and ``.predict_dataloader()`` that Lightning's ``Trainer`` object uses to demand data when training, making predictions, and so forth.
+
+You can pass ``GalaxyDataModule`` train, val, test and predict catalogs. Each catalog needs the columns:
+
+* ``file_loc``: the path to the image file
+* ``id_str``: a unique identifier for the galaxy
+* plus any columns for labels, which you will specify with ``label_cols``. Setting ``label_cols=None`` will load the data without labels (returning batches of (image, id_str)).
+
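+For instance, a minimal training catalog might look like this (the paths and label column are illustrative):
+
+.. code-block:: python
+
+    import pandas as pd
+
+    train_catalog = pd.DataFrame({
+        'id_str': ['gal_0', 'gal_1'],
+        'file_loc': ['/data/images/gal_0.jpg', '/data/images/gal_1.jpg'],
+        'is_cool_galaxy': [1, 0]  # any label columns you list in label_cols
+    })
+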
+``GalaxyDataModule`` will load the images from disk and apply any transformations you specify. Specify transforms in one of three ways:
+
+* through the `default arguments `_ of ``GalaxyDataModule`` (e.g. ``GalaxyDataModule(resize_after_crop=(128, 128))``)
+* through a torchvision or albumentations ``Compose`` object e.g. ``GalaxyDataModule(custom_torchvision_transforms=Compose([RandomHorizontalFlip(), RandomVerticalFlip()]))``
+* through a tuple of ``Compose`` objects. The first element will be used for the train dataloaders, and the second for the other dataloaders.
+
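+For example, a minimal sketch of the second option (assuming ``train_catalog`` and ``val_catalog`` are prepared as above; the exact transforms are illustrative and depend on how your images are stored):
+
+.. code-block:: python
+
+    from torchvision.transforms import Compose, ToTensor, Resize, RandomHorizontalFlip, RandomVerticalFlip
+    from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
+
+    transforms = Compose([ToTensor(), Resize((128, 128)), RandomHorizontalFlip(), RandomVerticalFlip()])
+
+    datamodule = GalaxyDataModule(
+        train_catalog=train_catalog,
+        val_catalog=val_catalog,
+        label_cols=['is_cool_galaxy'],
+        batch_size=128,
+        custom_torchvision_transforms=transforms
+    )
+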
+Using the default arguments is simplest and should work well for loading Galaxy-Zoo-like ``jpg`` images. Passing Compose objects offers full customization (short of writing your own ``LightningDataModule``). On that note...
+
+I Want To Do It Myself
+========================
+
+Using ``galaxy-datasets`` is optional. Zoobot is designed to work with any PyTorch ``LightningDataModule`` that returns batches of (images, labels).
+And advanced users can pass data to Zoobot's encoder however they like (see :doc:`advanced_finetuning`).
+
+Images should be PyTorch tensors of shape (batch_size, channels, height, width).
+Values should be floats normalized from 0 to 1 (though in practice, Zoobot can handle other ranges provided you use end-to-end finetuning).
+If your images contain raw flux values, apply a dynamic range rescaling like ``np.arcsinh`` before normalizing to [0, 1].
+Galaxies should appear large and centered in the image.
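+
+A minimal sketch of that flux preprocessing (``flux`` is assumed to be a float numpy array of raw flux values for one image):
+
+.. code-block:: python
+
+    import numpy as np
+    import torch
+
+    def to_normalized_tensor(flux: np.ndarray) -> torch.Tensor:
+        x = np.arcsinh(flux)  # dynamic range compression
+        x = (x - x.min()) / (x.max() - x.min() + 1e-12)  # rescale to [0, 1]
+        return torch.from_numpy(x).float()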
diff --git a/docs/guides/pytorch_or_tensorflow.rst b/docs/guides/pytorch_or_tensorflow.rst
deleted file mode 100644
index 9c5bb244..00000000
--- a/docs/guides/pytorch_or_tensorflow.rst
+++ /dev/null
@@ -1,40 +0,0 @@
-.. _pytorch_or_tensorflow:
-
-
-
-PyTorch or TensorFlow?
-===========================
-
-.. warning:: You should use the PyTorch version if possible. This is being actively developed and has the latest features.
-
-Zoobot is really two separate sets of code: `zoobot/pytorch `_ and `zoobot/tensorflow `_.
-They can both train the same EfficientNet model architecture on the same Galaxy Zoo data in the same way, for extracting representations and for finetuning - but they use different underlying deep learning frameworks to do so.
-
-We originally created two versions of Zoobot so that astronomers can use their preferred framework.
-But maintaining two almost entirely separate sets of code is too much work for our current resources (Mike's time, basically).
-Going forward, the PyTorch version will be actively developed and gain new features, while the TensorFlow version will be kept up-to-date but will not otherwise improve.
-
-Tell Me More About What's Different
--------------------------------------
-
-The TensorFlow version was the original version.
-It was used for the `GZ DECaLS catalog `_ and the `Practical Morphology Tools `_ paper.
-You can train EfficientNetB0 and achieve the same performance as with PyTorch (see the "benchmarks folder").
-You can also finetune the trained model, although the process is slightly clunkier.
-
-The PyTorch version was introduced to support other researchers and to integrate with Bootstrap Your Own Latent for the `Towards Foundation Models `_ paper.
-This version is actively developed and includes the latest features.
-
-PyTorch-specific features include:
-
-- Any architecture option from timm (including ResNet and Max-ViT)
-- Improved interface for easy finetuning
-- Layerwise learning rate decay during finetuning
-- Integration with AstroAugmentations (courtesy Micah Bowles) for custom astronomy image augmentations
-- Per-question loss tracking on WandB
-
-
-Can I have a JAX version?
-----------------------------
-
-Only if you build it yourself.
diff --git a/docs/index.rst b/docs/index.rst
index a3a4cfc1..64fbcac4 100755
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -5,7 +5,7 @@ Zoobot Documentation
====================
Zoobot makes it easy to finetune a state-of-the-art deep learning classifier to solve your galaxy morphology problem.
-For example, you can finetune a classifier to find ring galaxies with `just a few hundred examples `_.
+For example, you can finetune a classifier to find ring galaxies with `just a few hundred examples `_.
.. figure:: finetuning_rings.png
:alt: Ring galaxies found using Zoobot
@@ -15,14 +15,16 @@ For example, you can finetune a classifier to find ring galaxies with `just a fe
The easiest way to learn to use Zoobot is simply to use Zoobot.
We suggest you start with our worked examples.
-The `Colab notebook `_ is the fastest way to get started.
-See the README for many scripts that you can run and adapt locally.
+* This `Colab notebook `_ will walk you through using Zoobot to classify galaxy images.
+* There's a similar `notebook `_ for using Zoobot for regression on galaxy images.
-Guides
+For more explanation, read on.
+
+User Guides
-------------
-If you'd like more explanation and context, we've written these guides.
+We've written these guides to add explanation and context.
.. toctree::
:maxdepth: 2
@@ -43,24 +45,17 @@ To choose and download a pretrained model, see here.
API reference
--------------
-Look here for information on a specific function, class or
-method.
+We've added docstrings to all the key methods you might use. Feel free to check the code or reach out if you have questions.
.. toctree::
- :maxdepth: 2
-
- autodoc/api
-
+ :maxdepth: 4
-.. You do not need to be a machine learning expert to use Zoobot.
-.. Zoobot includes :ref:`components ` for common tasks like loading images, managing training, and making predictions.
-.. You simply need to assemble these together.
-
-.. .. toctree::
-.. :maxdepth: 2
-
-.. components/overview
+ autodoc/pytorch
+.. different level to not expand schema too much
+.. toctree::
+ :maxdepth: 3
+ autodoc/shared
.. Indices
@@ -78,6 +73,7 @@ method.
.. To build:
.. install sphinx https://www.sphinx-doc.org/en/master/usage/installation.html is confusing, you can just use pip install -U sphinx
+.. and pip install furo
.. run from in docs folder: make html
.. can also check docs with
diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index 76f73990..c1e7402f 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -43,16 +43,16 @@ class FinetuneableZoobotAbstract(pl.LightningModule):
- When provided `learning_rate` it will set the optimizer to use that learning rate.
Any FinetuneableZoobot model can be loaded in one of three ways:
- - HuggingFace name e.g. FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended.
- - Any PyTorch model in memory e.g. FinetuneableZoobotX(encoder=some_model, ...)
- - ZoobotTree checkpoint e.g. FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...)
+ - HuggingFace name e.g. `FinetuneableZoobotX(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...)`. Recommended.
+ - Any PyTorch model in memory e.g. `FinetuneableZoobotX(encoder=some_model, ...)`
+ - ZoobotTree checkpoint e.g. `FinetuneableZoobotX(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...)`
You could subclass this class to solve new finetuning tasks - see :ref:`advanced_finetuning`.
Args:
name (str, optional): Name of a model on HuggingFace Hub e.g.'hf_hub:mwalmsley/zoobot-encoder-convnext_nano'. Defaults to None.
encoder (torch.nn.Module, optional): A PyTorch model already loaded in memory
- zoobot_checkpoint_loc (str, optional): Path to ZoobotTree lightning checkpoint to load. Loads with Load with :func:`zoobot.pytorch.training.finetune.load_pretrained_encoder`. Defaults to None.
+        zoobot_checkpoint_loc (str, optional): Path to ZoobotTree lightning checkpoint to load. Loads with :func:`zoobot.pytorch.training.finetune.load_pretrained_zoobot`. Defaults to None.
n_blocks (int, optional):
lr_decay (float, optional): For each layer i below the head, reduce the learning rate by lr_decay ^ i. Defaults to 0.75.
@@ -174,11 +174,11 @@ def configure_optimizers(self):
and then pick the top self.n_blocks to finetune
weight_decay is applied to both the head and (if relevant) the encoder
- learning rate decay is applied to the encoder only: lr * (lr_decay**block_n), ignoring the head (block 0)
+ learning rate decay is applied to the encoder only: lr x (lr_decay^block_n), ignoring the head (block 0)
What counts as a "block" is a bit fuzzy, but I generally use the self.encoder.stages from timm. I also count the stem as a block.
- *batch norm layers may optionally still have updated statistics using always_train_batchnorm
+ batch norm layers may optionally still have updated statistics using always_train_batchnorm
"""
lr = self.learning_rate
@@ -395,10 +395,8 @@ class FinetuneableZoobotClassifier(FinetuneableZoobotAbstract):
These are shared between classifier, regressor, and tree models.
See the docstring of :class:``FinetuneableZoobotAbstract`` for more.
- Models can be loaded in one of three ways:
- - HuggingFace name e.g. FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended.
- - Any PyTorch model in memory e.g. FinetuneableZoobotClassifier(encoder=some_model, ...)
- - ZoobotTree checkpoint e.g. FinetuneableZoobotClassifier(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...)
+ Models can be loaded with `FinetuneableZoobotClassifier(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...)`.
+ See :class:``FinetuneableZoobotAbstract`` for other loading options (e.g. in-memory models or local checkpoints).
Args:
num_classes (int): num. of target classes (e.g. 2 for binary classification).
@@ -511,10 +509,8 @@ class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract):
These are shared between classifier, regressor, and tree models.
See the docstring of :class:``FinetuneableZoobotAbstract`` for more.
- Models can be loaded in one of three ways:
- - HuggingFace name e.g. FinetuneableZoobotRegressor(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended.
- - Any PyTorch model in memory e.g. FinetuneableZoobotRegressor(encoder=some_model, ...)
- - ZoobotTree checkpoint e.g. FinetuneableZoobotRegressor(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...)
+ Models can be loaded with `FinetuneableZoobotRegressor(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...)`.
+ See :class:``FinetuneableZoobotAbstract`` for other loading options (e.g. in-memory models or local checkpoints).
Args:
@@ -619,10 +615,8 @@ class FinetuneableZoobotTree(FinetuneableZoobotAbstract):
These are shared between classifier, regressor, and tree models.
See the docstring of :class:``FinetuneableZoobotAbstract`` for more.
- Models can be loaded in one of three ways:
- - HuggingFace name e.g. FinetuneableZoobotRegressor(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...). Recommended.
- - Any PyTorch model in memory e.g. FinetuneableZoobotRegressor(encoder=some_model, ...)
- - ZoobotTree checkpoint e.g. FinetuneableZoobotRegressor(zoobot_checkpoint_loc='path/to/zoobot_tree.ckpt', ...)
+ Models can be loaded with `FinetuneableZoobotTree(name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', ...)`.
+ See :class:``FinetuneableZoobotAbstract`` for other loading options (e.g. in-memory models or local checkpoints).
Args:
schema (schemas.Schema): description of the layout of the decision tree. See :class:`zoobot.shared.schemas.Schema`.
@@ -680,7 +674,15 @@ def __init__(self, input_dim: int, output_dim: int, dropout_prob=0.5, activation
self.activation = activation
def forward(self, x):
- # returns logits, as recommended for CrossEntropy loss
+ """returns logits, as recommended for CrossEntropy loss
+
+ Args:
+ x (torch.Tensor): encoded representation
+
+ Returns:
+            torch.Tensor: output of the head; logits from the linear layer, or the activated output if ``self.activation`` is set
+        """
x = self.dropout(x)
x = self.linear(x)
if self.activation is not None:
diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py
index 9e22b508..2c9e7524 100644
--- a/zoobot/pytorch/training/train_with_pytorch_lightning.py
+++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -66,42 +66,48 @@ def train_default_zoobot_from_scratch(
) -> Tuple[define_model.ZoobotTree, pl.Trainer]:
"""
Train Zoobot from scratch on a big galaxy catalog.
- Zoobot is a base deep learning model (anything from timm, typically a CNN) plus a dirichlet head.
- Images are augmented using the default transforms (flips, rotations, zooms)
- from `the galaxy-datasets repo `_.
- Once trained, Zoobot can be finetuned to new data.
- For finetuning, see zoobot/pytorch/training/finetune.py.
- Many pretrained models are already available - see :ref:`datanotes`.
+ **You don't need to use this**.
+ Training from scratch is becoming increasingly complicated (as you can see from the arguments) due to ongoing research on the best methods.
+ This will be refactored to a dedicated "foundation" repo.
Args:
save_dir (str): folder to save training logs and trained model checkpoints
+ schema (shared.schemas.Schema): Schema object with label_cols, question_answer_pairs, and dependencies
catalog (pd.DataFrame, optional): Galaxy catalog with columns `id_str` and `file_loc`. Will be automatically split to train and val (no test). Defaults to None.
train_catalog (pd.DataFrame, optional): As above, but already split by you for training. Defaults to None.
val_catalog (pd.DataFrame, optional): As above, for validation. Defaults to None.
test_catalog (pd.DataFrame, optional): As above, for testing. Defaults to None.
+ train_urls (list, optional): List of URLs to webdatasets for training. Defaults to None.
+ val_urls (list, optional): List of URLs to webdatasets for validation. Defaults to None.
+ test_urls (list, optional): List of URLs to webdatasets for testing. Defaults to None.
+ cache_dir (str, optional): Directory to cache webdatasets. Defaults to None.
epochs (int, optional): Max. number of epochs to train for. Defaults to 1000.
patience (int, optional): Max. number of epochs to wait for any loss improvement before ending training. Defaults to 8.
architecture_name (str, optional): Architecture to use. Passed to timm. Must be in timm.list_models(). Defaults to 'efficientnet_b0'.
+ timm_kwargs (dict, optional): Additional kwargs to pass to timm model init method, for example {'drop_connect_rate': 0.2}. Defaults to {}.
+ batch_size (int, optional): Batch size. Defaults to 128.
dropout_rate (float, optional): Randomly drop activations prior to the output layer, with this probability. Defaults to 0.2.
- drop_connect_rate (float, optional): Randomly drop blocks with this probability, for regularisation. For supported timm models only. Defaults to 0.2.
learning_rate (float, optional): Base learning rate for AdamW. Defaults to 1e-3.
betas (tuple, optional): Beta args (i.e. momentum) for adamW. Defaults to (0.9, 0.999).
weight_decay (float, optional): Weight decay arg (i.e. L2 penalty) for AdamW. Defaults to 0.01.
- scheduler_params (dict, optional): Specify a learning rate scheduler. See code. Not recommended with AdamW. Defaults to {}.
+ scheduler_params (dict, optional): Specify a learning rate scheduler. See code below. Defaults to {}.
color (bool, optional): Train on RGB images rather than channel-averaged. Defaults to False.
resize_after_crop (int, optional): Input image size. After all transforms, images will be resized to this size. Defaults to 224.
crop_scale_bounds (tuple, optional): Off-center crop fraction (<1 means zoom in). Defaults to (0.7, 0.8).
crop_ratio_bounds (tuple, optional): Aspect ratio of crop above. Defaults to (0.9, 1.1).
nodes (int, optional): Multi-node training Unlikely to work on your cluster without tinkering. Defaults to 1 (i.e. one node).
gpus (int, optional): Multi-GPU training. Uses distributed data parallel - essentially, full dataset is split by GPU. See torch docs. Defaults to 2.
+ sync_batchnorm (bool, optional): Use synchronized batchnorm. Defaults to False.
num_workers (int, optional): Processes for loading data. See torch dataloader docs. Should be < num cpus. Defaults to 4.
prefetch_factor (int, optional): Num. batches to queue in memory per dataloader. See torch dataloader docs. Defaults to 4.
mixed_precision (bool, optional): Use (mostly) half-precision to halve memory requirements. May cause instability. See Lightning Trainer docs. Defaults to False.
+ compile_encoder (bool, optional): Compile the encoder with torch.compile (new in torch v2). Defaults to False.
wandb_logger (pl.loggers.wandb.WandbLogger, optional): Logger to track experiments on Weights and Biases. Defaults to None.
checkpoint_file_template (str, optional): formatting for checkpoint filename. See Lightning docs. Defaults to None.
auto_insert_metric_name (bool, optional): escape "/" in metric names when naming checkpoints. See Lightning docs. Defaults to True.
save_top_k (int, optional): Keep the k best checkpoints. See Lightning docs. Defaults to 3.
+ extra_callbacks (list, optional): Additional callbacks to pass to the Trainer. Defaults to None.
random_state (int, optional): Seed. Defaults to 42.
Returns:
diff --git a/zoobot/shared/schemas.py b/zoobot/shared/schemas.py
index 57ecb537..3f85dbbe 100755
--- a/zoobot/shared/schemas.py
+++ b/zoobot/shared/schemas.py
@@ -130,6 +130,7 @@ def set_dependencies(questions, dependencies):
class Schema():
+
def __init__(self, question_answer_pairs:dict, dependencies: dict):
"""
Relate the df label columns tor question/answer groups and to tfrecod label indices
@@ -141,6 +142,23 @@ def __init__(self, question_answer_pairs:dict, dependencies: dict):
- answers in between will be included: these are used to slice
- df columns must be contigious by question (e.g. not smooth_yes, bar_no, smooth_no) for this to work!
+ The following schemas are available via the module (e.g. `from zoobot.shared.schemas import decals_dr5_ortho_schema`):
+ - decals_dr5_ortho_schema
+ - decals_dr8_ortho_schema
+ - decals_all_campaigns_ortho_schema
+ - gz2_ortho_schema
+ - gz_candels_ortho_schema
+ - gz_hubble_ortho_schema
+ - cosmic_dawn_ortho_schema
+ - cosmic_dawn_schema
+ - gz_rings_schema
+ - desi_schema
+ - gz_evo_v1_schema (this is the schema currently used for pretraining)
+ - gz_ukidss_schema
+ - gz_jwst_schema
+
+ "ortho" refers to the orthogonal question suffix (-cd, -dr8, etc).
+
Args:
question_answer_pairs (dict): e.g. {'smooth-or-featured: ['_smooth, _featured-or-disk, ...], ...}
dependencies (dict): dict mapping each question (e.g. disk-edge-on) to the answer on which it depends (e.g. smooth-or-featured_featured-or-disk)