Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Contribution Guide for Datamodule #2452

Merged
merged 5 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ torchgeo
tutorials/geospatial
tutorials/contribute_non_geo_dataset
tutorials/custom_raster_dataset
tutorials/contribute_datamodule
tutorials/transforms
tutorials/indices
tutorials/trainers
Expand Down
240 changes: 240 additions & 0 deletions docs/tutorials/contribute_datamodule.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Contribute a New DataModule"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Written by: Nils Lehmann_\n",
"\n",
"TorchGeo provides Lightning `DataModules` and trainers to faciliate easy and scalabel model training based on simple configuration files. Essentially, a `DataModule` implements the logic for splitting a dataset into train, validation and test splits for reproducability, wrapping them in PyTorch `DataLoaders` and apply augmentations to batches of data. This tutorial will outline a guide to adding a new datamodule to TorchGeo. It is often easy to do so alongside a new dataset and will make the dataset directly useable for a Lightning training and evaluation pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adding the datamodule"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding a datamodule to TorchGeo consists of roughly four parts:\n",
"\n",
"1. a `dataset_name.py` file under `torchgeo/datamodules` that implements the split logic and defines augmentation\n",
"2. a `dataset_name.yaml` file under `tests/configs` that defines arguments to directly test the datamodule with the appropriate task\n",
"3. add the above yaml file to the list of files to be tested in the corresponding `test_{task}.py` file under `tests/trainers`\n",
"4. an entry to the documentation page file `datamodules.rst` under `docs/api/`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The datamodule `dataset_name.py` file"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The vast majority of new DataModules can inherit from one of the base classes that take care of the majority of the work. The goal of the dataset specific DataModule is to specify how the dataset should be split into train/val/test and any augmentations that should be applied to batches of data.\n",
"\n",
"\n",
"```python\n",
"\n",
"\"\"\"NewDatasetDataModule datamodule.\"\"\"\n",
"\n",
"import os\n",
"from typing import Any\n",
"\n",
"import kornia.augmentation as K\n",
"import torch\n",
"from torch.utils.data import Subset\n",
"\n",
"from .geo import NonGeoDataModule\n",
"from .utils import group_shuffle_split\n",
"\n",
"\n",
"# We follow the convention of appending the dataset_name with \"DataModule\"\n",
"class NewDatasetDataModule(NonGeoDataModule):\n",
" \"\"\"LightningDataModule implementation for the NewDataset dataset.\n",
"\n",
" Make a comment here about how the dataset is split into train/val/test.\n",
"\n",
" You can also add any other comments or references that are helpful to \n",
" understand implementation decisions\n",
"\n",
" .. versionadded:: for example 0.7\n",
" \"\"\"\n",
" # you can define channelwise normalization statistics that will be applied\n",
" # to data batches, which is usually crucial for training stability and decent performance\n",
" mean = torch.Tensor([0.5, 0.4, 0.3])\n",
" std = torch.Tensor([1.5, 1.4, 1.3])\n",
"\n",
" def __init__(\n",
" self, batch_size: int = 64, num_workers: int = 0, size: int = 256, **kwargs: Any\n",
" ) -> None:\n",
" \"\"\"Initialize a new NewDatasetModule instance.\n",
"\n",
" Args:\n",
" batch_size: Size of each mini-batch.\n",
" num_workers: Number of workers for parallel data loading.\n",
" size: resize images of input size 1000x1000 to size x size\n",
" **kwargs: Additional keyword arguments passed to\n",
" :class:`~torchgeo.datasets.NewDataset`.\n",
" \"\"\"\n",
" # in the init method of the base class the dataset will be instantiated with **kwargs\n",
" super().__init__(NewDatasetName, batch_size, num_workers, **kwargs)\n",
"\n",
" # you can specify a series of Kornia augmentations that will be\n",
" # applied to a batch of training data in `on_after_batch_transfer` in the NonGeoDataModule base class\n",
" self.train_aug = K.AugmentationSequential(\n",
" K.Resize(size),\n",
" K.Normalize(self.mean, self.std),\n",
" K.RandomHorizontalFlip(p=0.5),\n",
" K.RandomVerticalFlip(p=0.5),\n",
" data_keys=None,\n",
" keepdim=True,\n",
" )\n",
"\n",
" # you can also define specific augmentations for other experiment phases, if not specified\n",
" # self.aug Augmentations will be applied\n",
" self.aug = K.AugmentationSequential(\n",
" K.Normalize(self.mean, self.std),\n",
" K.Resize(size), data_keys=None, keepdim=True\n",
" )\n",
"\n",
" self.size = size\n",
"\n",
" # setup defines how the dataset should be split\n",
" # this could either be predefined from the dataset authors or\n",
" # done in a prescribed way if some or no splits are specified\n",
" def setup(self, stage: str) -> None:\n",
" \"\"\"Set up datasets.\n",
"\n",
" Args:\n",
" stage: Either 'fit', 'validate', 'test', or 'predict'.\n",
" \"\"\"\n",
" if stage in ['fit', 'validate']:\n",
" dataset = NewDatasetName(split='train', **self.kwargs)\n",
" # perhaps the dataset contains some geographical metadata based on which you would create reproducible random\n",
" # splits\n",
" grouping_paths = [os.path.dirname(path) for path in dataset.file_list]\n",
" train_indices, val_indices = group_shuffle_split(\n",
" grouping_paths, test_size=0.2, random_state=0\n",
" )\n",
" self.train_dataset = Subset(dataset, train_indices)\n",
" self.val_dataset = Subset(dataset, val_indices)\n",
" if stage in ['test']:\n",
" self.test_dataset = NewDatasetName(split='test', **self.kwargs)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linters\n",
"\n",
"See the [linter docs](https://torchgeo.readthedocs.io/en/stable/user/contributing.html#linters) for an overview of linters that TorchGeo employs and how to apply them during commits for example. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Unit tests"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TorchGeo maintains a test coverage of 100%. This means, that every line of code written within the torchgeo directory is being called by some unit test. For new datasets, we commonly write a separate test file, however, for datamodules we would like to test them directly with one of the task trainers. To do this, you simply need to define a `config.yaml` file and add it to the list of files to be tested by a task. For example, if you added a new datamodule for image segmentation you would write a config file that should look something like this:\n",
"\n",
"```yaml\n",
"model:\n",
" class_path: SemanticSegmentationTask\n",
" init_args:\n",
" loss: 'ce'\n",
" model: 'unet'\n",
" backbone: 'resnet18'\n",
" in_channels: 3 # number of input channels for the dataset\n",
" num_classes: 7 # number of segmentation models\n",
" num_filters: 1 # a smaller model version for faster unit tests\n",
" ignore_index: null # one can ignore certain classes during the loss computation\n",
"data:\n",
" class_path: NewDatasetNameDataModule # arguments to the DataModule above you wrote\n",
" init_args:\n",
" batch_size: 1 # \n",
" dict_kwargs:\n",
" root: 'tests/data/deepglobelandcover' # necessary arguments for the underlying dataset class that the datamodule builds on\n",
"```\n",
"\n",
"The yaml file should \"simulate\" how you would use this datamodule for an actual experiment. Add this file with `dataset_name.yaml` to the `tests/conf` directory."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Final Checklist"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This final checklist might provide a useful overview of the individual parts discussed in this tutorial. You definitely do not need to check all boxes, before submitting a PR. If you have any questions feel free to ask in the Slack channel or open a PR already such that maintainers or other community members can answer specific questions or give pointers. If you want to run your PR as a work of progress, such that the CI tests are run against your code while you work on ticking more boxes you can also convert the PR to a draft on the right side."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- The datamodule implementation\n",
" - define training/val/test split\n",
" - if there are dataset specific augmentations, implement and reference them\n",
" - add microsoft copyright notice to top of the file\n",
"- The config test file\n",
" - select the appropriate task, if the dataset supports multiple ones, you can create one for each task\n",
" - correct arguments such as the number of targets (classes)\n",
" - add the config file to the list of files to be tested in the corresponding `test_{task}.py` file under `tests/trainers`\n",
"- Unit Tests\n",
" - 100% test coverage\n",
"- Documentation\n",
" - an entry to the documentation page file `datamodules.rst` under `docs/api/`"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading