Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove dataset-specific trainers #286

Merged
merged 18 commits into from
Jan 1, 2022
Merged

Conversation

@adamjstewart adamjstewart added the trainers PyTorch Lightning trainers label Dec 15, 2021
@adamjstewart adamjstewart added this to the 0.2.0 milestone Dec 15, 2021
@isaaccorley
Copy link
Collaborator

Just a note: I've recently been using kornia augmentations in the datamodule on a side project and one of the things I ran into was how to let the datamodule know if I'm loading data for a train/val/test set so that I can choose to augment or not. Found that you can access a bool attr self.trainer.training in the datamodule so you can do something like:

def on_after_batch_transfer(self, batch, dataloader_idx):
   if self.trainer.training:
      # Augment only if loading for train_step
      batch = augmentations(batch)
    return batch

@adamjstewart
Copy link
Collaborator Author

Note that none of this code currently gets hit by our tests. We aren't using a pl.Trainer and so things like self.trainer are None. Still trying to figure out the best way to test this.

@adamjstewart adamjstewart force-pushed the trainers/dataset-specific branch 3 times, most recently from c3c69bc to 9657a67 Compare December 24, 2021 22:37
@github-actions github-actions bot added datasets Geospatial or benchmark datasets documentation Improvements or additions to documentation testing Continuous integration testing datamodules PyTorch Lightning datamodules labels Dec 24, 2021
@adamjstewart adamjstewart force-pushed the trainers/dataset-specific branch from 9657a67 to 76f24fa Compare December 30, 2021 20:59
@adamjstewart
Copy link
Collaborator Author

Note: I don't think we're adding the predictions to the batch before plotting, we probably should

@@ -17,26 +16,6 @@
)


class FakeExperiment(object):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to remove these in #329

@adamjstewart
Copy link
Collaborator Author

Test failure is because the So2Sat dataset doesn't know how to plot any of the datamodule reduced band set options. We should probably move these to the dataset level.

@adamjstewart
Copy link
Collaborator Author

Another hiccup: self.trainer.datamodule.val_dataset.plot(...) doesn't work for datamodules that use Subset or random_split. One possible solution would be to use self.trainer.datamodule.plot(...) and add a plot(...) method to every DataModule that passes all args to the Dataset plot(...) method.

@isaaccorley
Copy link
Collaborator

isaaccorley commented Dec 31, 2021

You can access the plot method for Subset datasets like self.trainer.datamodule.val_dataset.dataset.plot. Not sure what workaround we should make for this.

Edit: I think adding a plot method to each datamodule that just calls the dataset plot method is a decent solution.

@adamjstewart adamjstewart force-pushed the trainers/dataset-specific branch from 6bf2b19 to 1035d5e Compare December 31, 2021 20:46
@adamjstewart adamjstewart marked this pull request as ready for review December 31, 2021 22:09
@adamjstewart adamjstewart force-pushed the trainers/dataset-specific branch from ccfe068 to e4f08d3 Compare January 1, 2022 02:16
@adamjstewart adamjstewart marked this pull request as draft January 1, 2022 02:38
@adamjstewart
Copy link
Collaborator Author

I believe the failing unit tests for ClassificationTask are because VisionClassificationDataset is overwriting self.classes and our fake data only has 2 classes. Simple fix would be to add more fake data.

Still haven't investigated the failing unit tests for SemanticSegmentationTask. Will do so tomorrow.

@adamjstewart adamjstewart marked this pull request as ready for review January 1, 2022 16:52
torchgeo/trainers/classification.py Show resolved Hide resolved
torchgeo/datamodules/landcoverai.py Outdated Show resolved Hide resolved
conf/task_defaults/eurosat.yaml Show resolved Hide resolved
@adamjstewart adamjstewart merged commit 42b9a6d into main Jan 1, 2022
@adamjstewart adamjstewart deleted the trainers/dataset-specific branch January 1, 2022 20:14
@adamjstewart adamjstewart added utilities Utilities for working with geospatial data and removed utilities Utilities for working with geospatial data labels Jan 2, 2022
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* Remove dataset-specific trainers

* Collation functions will be new in 0.2.0

* Clarify arg docstring

* Style fixes

* Remove files forgotten in rebase

* Fix bug in unbind_samples, add tests

* Fix bugs in datamodule augmentations

* Increase coverage for datamodules

* Fix bugs in logger plotting, properly test

* Fix tests

* Increase coverage of trainers

* Use datamodule plot instead of dataset plot

* Skip datamodules without tests

* Plot predictions

* Fix ClassificationTask tests

* Fix SemanticSegmentationTask tests

* EAFP -> LBYL

* Ensure that tensors are on the CPU before plotting
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets documentation Improvements or additions to documentation testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Reduce code duplication in trainers
3 participants