
Add random crop logic to DeepGlobeLandCover Datamodule #876

Merged
40 commits merged into microsoft:main on Dec 30, 2022

Conversation

nilsleh
Collaborator

@nilsleh nilsleh commented Oct 29, 2022

Addressing #855 by adding random crop logic to the datamodule. Additionally, add a test file for the datamodule.

@github-actions github-actions bot added datamodules PyTorch Lightning datamodules testing Continuous integration testing labels Oct 29, 2022
@calebrob6
Member

I confirmed this all works as expected with the actual dataset!

@github-actions github-actions bot added the datasets Geospatial or benchmark datasets label Nov 23, 2022
@nilsleh
Collaborator Author

nilsleh commented Nov 24, 2022

I noticed something after the last commit, namely the case when val_split_pct=0. In that case, val_dataset is assigned the training dataset together with its transforms (importantly, the random tile-crop logic), but the val_dataloader does not receive the collate_fn that the train_dataloader uses to combine the patches into a batch. I see two options:

  1. Also pass the collate_fn to val_dataloader when val_split_pct=0 (included in the last commit)
  2. Or, when val_split_pct=0, assign train_dataset to val_dataset but with the validation transforms, which do not contain the crop logic (roughly as sketched below)

Just wanted to ask which logic we should go for, and whether we can do that consistently across all the datamodules.
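A minimal sketch of option 2, assuming the datamodule keeps a separate validation transform pipeline (self.val_transforms and self.root_dir are hypothetical names here, not the actual PR code):

if self.val_split_pct > 0.0:
    self.train_dataset, self.val_dataset, _ = dataset_split(
        dataset, val_pct=self.val_split_pct, test_pct=0.0
    )
else:
    self.train_dataset = dataset
    # Reuse the same tiles, but with validation transforms that skip the
    # random crop, so val_dataloader needs no special collate_fn
    self.val_dataset = DeepGlobeLandCover(
        self.root_dir, split="train", transforms=self.val_transforms
    )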

@nilsleh nilsleh mentioned this pull request Nov 27, 2022
@adamjstewart adamjstewart added the backwards-incompatible Changes that are not backwards compatible label Nov 27, 2022
@adamjstewart adamjstewart added this to the 0.4.0 milestone Nov 27, 2022
@calebrob6 calebrob6 closed this Dec 17, 2022
@calebrob6 calebrob6 reopened this Dec 17, 2022
@calebrob6
Member

What is the minimum test complaining about?

@adamjstewart
Collaborator

> What is the minimum test complaining about?

Same thing I fixed here. Either merge that PR first or make the same change to this PR.

@ashnair1
Collaborator

Suggestions made in #929 (relative imports in the datamodule and an if guard for the collate fn) are applicable here as well.

@github-actions github-actions bot added the transforms Data augmentation transforms label Dec 20, 2022
@nilsleh
Collaborator Author

nilsleh commented Dec 20, 2022

Not sure what you think of the transform implementation I did. It is very similar to the existing AugmentationSequential transform, and I think we could consolidate the two in a future PR. Sticking the existing AugmentationSequential into a Compose doesn't seem to work, but in a sense this new version is more general. (For comparison, typical usage of the existing AugmentationSequential is sketched below.)
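A rough sketch of how torchgeo's existing AugmentationSequential is typically used (placeholder tensors and augmentations; not code from this PR):

import kornia.augmentation as K
import torch

from torchgeo.transforms import AugmentationSequential

# Wrap Kornia augmentations so they can be applied to dict-style samples
augs = AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    data_keys=["image", "mask"],
)
sample = {
    "image": torch.rand(1, 3, 512, 512),
    # Kornia expects float tensors with a channel dimension
    "mask": torch.randint(0, 7, (1, 1, 512, 512)).float(),
}
sample = augs(sample)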


@adamjstewart adamjstewart left a comment


I don't think this approach makes sense. Or rather, it's a perfectly valid approach, but it's not the way torchvision/torchgeo/kornia do things, and I don't see any reason why we need to introduce a completely different way of handling this. The way I think about it is:

  1. _pad_segmentation_sample is a transform/data augmentation
  2. All transforms/data augmentations in torchvision/torchgeo/kornia are classes that implement __call__, not functions
  3. Kornia's AugmentationSequential/AugmentationPatches and torchvision's Compose are containers, not transforms/data augmentations

There are also a lot of inconsistencies in the API that this introduces:

  1. All torchgeo transforms are in torchgeo/transforms, not torchgeo/datasets
  2. Kornia: AugmentationSequential/AugmentationPatches, TorchGeo: AugmentationSequential/PatchesAugmentation (names are not consistent)
  3. Kornia: AugmentationPatches splits into patches, performs augmentation, reassembles. TorchGeo: PatchesAugmentation applies augmentations to all patches (behavior is not consistent)

I still think we should use a transform instead of a function. Then you can use it with Compose or AugmentationSequential or whatever else. I would subclass nn.Module like all of our other existing transforms.

I also think this probably belongs in torchgeo/transforms/, although we can still keep it a hidden method until we have a better idea of where to fit it in the library. We can put it in transforms.py for now.
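A minimal sketch of that shape, assuming an nn.Module padding transform (the class name and parameters are placeholders, not the final API):

import torch.nn.functional as F
from torch import Tensor, nn


class PadSample(nn.Module):
    """Hypothetical transform: pad spatial dims up to a target size."""

    def __init__(self, size: int) -> None:
        super().__init__()
        self.size = size

    def forward(self, x: Tensor) -> Tensor:
        pad_h = max(self.size - x.shape[-2], 0)
        pad_w = max(self.size - x.shape[-1], 0)
        # F.pad pads the last dims in order: (left, right, top, bottom)
        return F.pad(x, (0, pad_w, 0, pad_h))

Because it subclasses nn.Module, such a transform can be dropped into Compose or AugmentationSequential like any other.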

@nilsleh
Copy link
Collaborator Author

nilsleh commented Dec 20, 2022

Okay, so I should:

  1. Create one transform class for _pad_segmentation_sample and put it in transforms.py instead of utils.py?
  2. Regarding what I currently named PatchesAugmentation in transforms.py: it is intended as a torchvision-style transform that uses Kornia internally (because Kornia handles dicts), but it is then used inside torchvision's Compose, since the dataset processing also includes the preprocess function. I am not sure what you are recommending here. I am not aiming to copy Kornia's AugmentationPatches, just to provide one transformation for the segmentation datamodules.

@adamjstewart
Collaborator

I think we want to move

FROM: function _pad_segmentation_sequential in torchgeo/datasets/utils.py
TO: class PadSegmentationSequential in torchgeo/transforms/transforms.py

We can bikeshed on the name later, but I'll probably want to rename it to match whatever torchvision/kornia calls it.

I also think we want to remove PatchesAugmentation entirely and use either Compose or AugmentationSequential like we do with all of our other transforms.

I'm starting to suspect that there's something I'm not understanding about why PatchesAugmentation is even required, or why we can't use torchvision's Pad or Kornia's PadTo transforms. It might be useful to clear up that misunderstanding first; maybe none of this is required. (PadTo usage is sketched below for reference.)
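For reference, a quick sketch of Kornia's built-in PadTo (shapes here are placeholders):

import torch
from kornia.augmentation import PadTo

# Pad a smaller tile up to a fixed (height, width)
pad = PadTo((256, 256))
image = torch.rand(1, 3, 200, 180)
padded = pad(image)  # shape: (1, 3, 256, 256)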

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Dec 27, 2022
@github-actions github-actions bot added the samplers Samplers for indexing datasets label Dec 28, 2022
Comment on lines +32 to +33
# kornia 0.6.5+ required due to change in kornia.augmentation API
kornia>=0.6.5,<0.7

In Kornia 0.6.5+, all augmentation instance methods have a new flags parameter. So the transforms I added won't work with Kornia 0.6.4 and older. Once we upstream these transforms to Kornia, we'll need to depend on an even newer version anyway.

Comment on lines -11 to +10
from torchvision.transforms import Compose
from kornia.augmentation import Normalize

I'm planning on removing all torchvision transforms. Torchvision relies on PIL for many of its transforms, and PIL doesn't support MSI (multispectral imagery). Kornia has all of the same transforms, but they are implemented in pure PyTorch, so they can run on the GPU and support MSI. I don't see a good reason not to use Kornia transforms exclusively.
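To illustrate, Kornia transforms operate on raw tensors with arbitrary channel counts (the band count and normalization stats below are placeholders):

import torch
from kornia.augmentation import Normalize

# Normalize a 6-band multispectral image; PIL-based transforms cannot
norm = Normalize(mean=torch.zeros(6), std=torch.ones(6))
msi = torch.rand(1, 6, 64, 64)
out = norm(msi)  # shape preserved: (1, 6, 64, 64)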

Comment on lines +56 to +57
*batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*,
and *patch_size*.

I tend to document only API changes, not internal changes. So the fact that we're now using random cropping isn't documented; only the fact that the parameters changed is.

        self.kwargs = kwargs

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:

We cannot use instance methods as transforms; see #886 for what happens when we do.

Comment on lines -73 to -82
self.train_dataset: Dataset[Any]
self.val_dataset: Dataset[Any]

if self.val_split_pct > 0.0:
    self.train_dataset, self.val_dataset, _ = dataset_split(
        dataset, val_pct=self.val_split_pct, test_pct=0.0
    )
else:
    self.train_dataset = dataset
    self.val_dataset = dataset

I have no idea why our previous logic was so complicated, but I don't think it needs to be.

Comment on lines +138 to +139
if self.trainer:
    if self.trainer.training:

So much cleaner than our previous logic!
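For context, this check typically lives in a Lightning on_after_batch_transfer hook, roughly like this (self.train_aug is a placeholder name, not necessarily the PR's attribute):

def on_after_batch_transfer(self, batch, dataloader_idx):
    # Apply the GPU augmentation pipeline (including the random crop)
    # only during training; validation/test batches pass through
    if self.trainer:
        if self.trainer.training:
            batch = self.train_aug(batch)
    return batch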

Comment on lines 144 to 145
# Kornia adds a channel dimension to the mask
batch["mask"] = batch["mask"].squeeze(1)

Kornia does a lot of weird stuff with transforms that I don't like. Masks are required to be floats (why? it's slower and takes more storage). If the mask you input doesn't have a channel dimension, it will add one. Some of the transforms actually break if the mask doesn't have a channel dimension when you input it, so we may need to add an unsqueeze above.
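A sketch of that round trip (assumed shapes; not the exact PR code):

mask = batch["mask"]                    # (B, H, W), integer class labels
mask = mask.unsqueeze(1).float()        # (B, 1, H, W), float as Kornia requires
# ... apply Kornia augmentations to image and mask ...
batch["mask"] = mask.squeeze(1).long()  # back to (B, H, W) integer labels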

Comment on lines +14 to +21
@overload
def _to_tuple(value: Union[Tuple[int, int], int]) -> Tuple[int, int]:
    ...


@overload
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
    ...

In Python typing, an int is accepted anywhere a float is expected, but not vice versa. This meant that if I passed an int as input, mypy would consider its output type to be float. These overloads ensure that int maps to int and float maps to float, as expected (see the sketch below for how they pair with a single implementation).
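A sketch of how the two overloads sit on top of one concrete implementation (the actual body in the PR may differ):

from typing import Tuple, Union

def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
    # Duplicate scalars; pass tuples through unchanged
    if isinstance(value, (int, float)):
        return (value, value)
    return value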

Returns:
    the transformation
"""
out: Tensor = self.identity_matrix(input)

This isn't correct, but something is required and we don't actually use it. I'll iron out the details when we upstream this.

@adamjstewart adamjstewart marked this pull request as ready for review December 29, 2022 19:06
@nilsleh
Collaborator Author

nilsleh commented Dec 29, 2022

I am using the following script to test:

from pytorch_lightning import Trainer

from torchgeo.datamodules import DeepGlobeLandCoverDataModule
from torchgeo.trainers import SemanticSegmentationTask

dm = DeepGlobeLandCoverDataModule(
    root="./data/archive/", num_tiles_per_batch=2, num_patches_per_tile=4, patch_size=64
)

model = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    weights="imagenet",
    in_channels=3,
    num_classes=2,
    loss="jaccard",
    ignore_index=0,
    learning_rate=0.01,
    learning_rate_schedule_patience=0,
)
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1)
trainer.fit(model, dm)

But in model.training_step the input has shape [8, 3, 64, 64] (2 tiles × 4 patches per tile = 8 patches, as expected), while the mask has shape [4, 2, 64, 64].

@nilsleh
Collaborator Author

nilsleh commented Dec 29, 2022

Here is a visualization; it looks correct.

Input image: [screenshot]

Sampled patches: [screenshot]

@adamjstewart adamjstewart merged commit c62d832 into microsoft:main Dec 30, 2022
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
* crop logic

* typo

* change train_batch_size logic

* fix failing test

* typos and naming

* return argument train dataloader

* typo

* fix failing test

* suggestions except about test file

* remove test_deepglobe and add test to trainer

* forgot new conf file

* rename collate function

* move cropping logic to transform and utils

* remove comment

* simplify

* move pad_segmentation to transforms

* another one

* naming and versionadded

* another transforms approach

* typo

* fix read the docs

* some checks for Ncrop

* add unit tests new transforms

* Remove cruft

* More simplification

* Add config file

* Implemented ExtractTensorPatches

* Remove tests

* Remove unnecessary attrs

* Apply to both input and mask

* Implement RandomNCrop

* Fix dimensions

* mypy fixes

* Fix docs

* Ensure that image and mask get the same transformation

* Bump min kornia version

* ignore still needed?

* Remove unneeded hacks

* Fix pydocstyle

* Fix dimensions

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Labels
backwards-incompatible (Changes that are not backwards compatible)
datamodules (PyTorch Lightning datamodules)
datasets (Geospatial or benchmark datasets)
documentation (Improvements or additions to documentation)
samplers (Samplers for indexing datasets)
testing (Continuous integration testing)
transforms (Data augmentation transforms)
5 participants