Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor datamodule/model testing #329

Merged
merged 21 commits into from
Dec 30, 2021
Merged

Refactor datamodule/model testing #329

merged 21 commits into from
Dec 30, 2021

Conversation

adamjstewart
Copy link
Collaborator

@adamjstewart adamjstewart commented Dec 25, 2021

Note: The word "model" in this PR refers to pytorch-lightning models, AKA torchgeo.trainers.*Task objects. The word "trainer" in this PR refers to pl.Trainer objects, not torchgeo.trainers.*Task objects. I really wish naming conventions between torchvision and pytorch-lightning were more consistent...

Motivation

Previously, we attempted to unit test all datamodules and models. However, pytorch-lightning isn't designed for unit testing (datamodules/models don't work standalone, they need a pl.Trainer class wrapping them). This meant that we needed to monkeypatch large parts of the code to add fake trainers and loggers. In order to properly test #286, we would need to monkeypatch even more features. This isn't sustainable, and would break if pytorch-lightning ever changed their API.

Implementation

This PR removes almost all tests from tests/datamodules and converts tests/trainers to use real datamodules and models for integration testing with a pl.Trainer. This has a number of advantages:

  • removes the need for all monkeypatching
  • significantly reduces the number of lines of code needed for testing (-35%)
  • tests all datamodules/models in a more realistic setting
  • testing a new datamodule would only involve adding a single line of code

This change also necessitates modifying some of our testing data to increase image sizes. In these cases, I've added a data.py script that can be used to generate testing data. We should start encouraging these for contributors adding new datasets. In the future, this could allow us to remove all fake data from the repo and generate it on-the-fly for testing purposes.

@github-actions github-actions bot added the testing Continuous integration testing label Dec 25, 2021
@adamjstewart adamjstewart force-pushed the tests/trainer-refactor branch from 153b60c to a102194 Compare December 27, 2021 16:02
@github-actions github-actions bot added datamodules PyTorch Lightning datamodules trainers PyTorch Lightning trainers datasets Geospatial or benchmark datasets labels Dec 27, 2021
@calebrob6
Copy link
Member

calebrob6 commented Dec 27, 2021 via email

@adamjstewart adamjstewart mentioned this pull request Dec 27, 2021


# The values here are taken from the defaults here https://pytorch-lightning.readthedocs.io/en/1.3.8/common/trainer.html#init
# this probably should be made into a schema, e.g. as shown https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs
trainer: # These are the parameters passed to the pytorch lightning Trainer object
logger: True
checkpoint_callback: True
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many of these defaults have been renamed. The previous values give deprecation warnings. We may want to not hardcode these values and instead let pytorch-lightning assign these values automatically.

ignore_zeros: False
datamodule:
naip_root_dir: "tests/data/naip"
chesapeake_root_dir: "tests/data/chesapeake/BAYWIDE"
batch_size: 32
batch_size: 2
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With a batch size of 1 (what I'm using on all other tests) this breaks and I don't know why.

root_dir: "tests/data/oscd"
batch_size: 1
num_workers: 0
val_split_pct: 0.5
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val_split_pct == 0 breaks the tests and I don't know why

"ucmerced",
],
)
def test_tasks(task: str, tmp_path: Path) -> None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are now in tests/trainers/*.py and will be run on every commit instead of just on releases.

batch_size: 64
patches_per_tile: 2
patch_size: 64
batch_size: 2
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BYOLTask tests fail with a batch size of 1 and I have no idea why. SemanticSegmentationTask tests work fine with a batch size of 1. Oh, the mysteries of life...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If any preprocessing methods use .squeeze(), then they will remove the batch dimension which will in turn break the forward pass

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should update these to define the only dim that should be squeezed like .squeeze(dim=1)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just tried this (both dim=0 and dim=1) and as soon as it fixes one issue, it creates another one. I don't think I have the time to debug this any further, but if anyone wants to submit a follow-up PR to fix this I would be very happy.

@@ -261,7 +261,7 @@ def __init__(
image_size: the size of the training images
hidden_layer: the hidden layer in ``model`` to attach the projection
head to, can be the name of the layer or index of the layer
input_channels: number of input channels to the model
in_channels: number of input channels to the model
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This trainer was pretty out-of-sync with our other trainers and uses different key names. Tried to sync them a bit more.

@adamjstewart adamjstewart marked this pull request as ready for review December 28, 2021 18:12
@adamjstewart adamjstewart requested review from calebrob6 and isaaccorley and removed request for calebrob6 December 28, 2021 18:12
@adamjstewart adamjstewart added this to the 0.2.0 milestone Dec 28, 2021
@adamjstewart adamjstewart force-pushed the tests/trainer-refactor branch from 082a078 to e3811cf Compare December 28, 2021 21:11
batch_size: 2
num_workers: 0
class_set: ${experiment.module.num_classes}
use_prior_labels: True
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These datamodule settings (use_prior_labels: True) work with BYOLTask but not with SemanticSegmentationTask and I have no idea why.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what error you're getting but BYOLTask doesn't use the masks so not surprised it passes. The prior labels I believe are soft probabilities so I don't think we've set up the SegmentationTask loss to handle that.

@adamjstewart
Copy link
Collaborator Author

For OSCD, I think a lot of the confusion is that our image is [2 x C x H x W] instead of [2C x H x W] unlike all other datasets. We could concatenate the two images instead of stacking to resolve some of these issues. Alternatively, we could create a new ChangeDetectionTask for these kinds of datasets.

Copy link
Collaborator

@isaaccorley isaaccorley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lgtm. My only concern is that adding data.py scripts as an additional step for adding datasets may be confusing/complicated to new contributors.

@adamjstewart
Copy link
Collaborator Author

adamjstewart commented Dec 30, 2021

I'm not sure if I would consider data.py to be an additional step since all of those commands need to be run anyway to create the fake data. That script doesn't need to be fancy, it could even be a data.sh script that includes the same commands they would run on the command line. I think it helps us know that the fake data is actually fake and not real data stuffed into a tarball. It also gives examples for how fake data was created for other datasets.

@adamjstewart adamjstewart merged commit 744078f into main Dec 30, 2021
@adamjstewart adamjstewart deleted the tests/trainer-refactor branch December 30, 2021 19:54
@adamjstewart adamjstewart added utilities Utilities for working with geospatial data and removed utilities Utilities for working with geospatial data labels Jan 2, 2022
@adamjstewart adamjstewart mentioned this pull request Dec 26, 2022
12 tasks
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* Refactor RegressionTask testing

* Programmatically determine max value

* Refactor ClassificationTask testing

* Silence warnings

* Refactor SegmentationTask testing

* Fix training mappings

* Fix GeoDataset trainers

* Fix ETCI trainer fake data

* Update OSCD training data

* Get LandCoverAI tests to pass

* Fix OSCD checksum handling

* Fix NAIP-Chesapeake tests

* Fix OSCD tests

* Keep BoundingBox icy

* Fix other datamodules

* Fix chesapeake testing

* Refactor BYOLTask tests

* Style fixes

* Silence pytorch-lightning warnings

* Get coverage for Chesapeake CVPR prior

* Fix trainer tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants