From e70f90d9bd9775c4e3a846955e42c8bd85f3632e Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 28 Oct 2021 22:31:02 -0500 Subject: [PATCH 01/19] add additional bigearthnet test data for train/val/test split --- .../bigearthnet/BigEarthNet-S1-v1.0.tar.gz | Bin 1086 -> 1358 bytes .../bigearthnet/BigEarthNet-S2-v1.0.tar.gz | Bin 1110 -> 1762 bytes 2 files changed, 0 insertions(+), 0 deletions(-) diff --git a/tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz b/tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz index 9169a8a932c58a3e93789b8f209d8ff6a590043a..d9df455f105d44fc1fab8b90f739c1a585089166 100644 GIT binary patch literal 1358 zcmV-U1+n@ciwFP!000001MOT*ZzDArp0w3&Sz5FlSS}!0;fDG#p0BW6XiL#a&RT-p62{ER*G6R;`#zFm$A0~h zc z+0!1#Fr;xp7(yELA{LxucTxEj#nzDduj5U3XWxDB<@Q4t*SGgvgU^@lVDm9{Y~35>_jK#&( zcP;1`J37|T?j?<<58e4^QIwo#cTxEj#l-7>9AnqQ?%u<-gf?gWu>m<`%zt7158eOm zn$yq!hHcgO`qyQA0XKaA-`V+4TjcRwO{>gn6`q>op;AcC;)ppHJWe#~3{N%X_ZHa5l^-(>*!ANXG`|L+0+ zD+DF@-_c9_?-;=UGEnCKEELjR{ujLcy8JIfm;VKK595Ckt~>vqjyE6)P=Wsq-2wie z5v22fKpK7!yGMRP4oE`Qp2kr)>o4I3Lka#j3;b{A{BHvP%fRXK9mG-XEX(T0=<7~K zmiZ_OS1Dt(MH2r25k?w;j{*`NR96aCpGF~V`YcLg)Qp%PC)G~jH@Q*Oiz0SFL*9t! z;IvRR4SwJ``buBc$zWE%N>)L~s+04q+fkcEPkk?;QOFCc)n5OM533!8d=M=+HrzHz zyhCA9H#XhR>pQzV>Rq-v)=k{Gx?eRBdw%GV7kGBEM_AIBKPALSrp{zuk- zUhV(3#k0WwlK{|v%Jo41nJ$3;)rDmK2k}QS0RIF31OE@&1OHDG7wbRfApeUk0r+16 zQvDt$^k|y|!3l~yFKts^p>cvpi#M4DNy_{c^gT@uQA}GQZBmbf32)IS zF^ZZfB;0}o=-6-Z?IRO;3O#JGh#ag2J~`%%<}y^FcqbG?LPvzfhkjfv+2`YC$cs`I z^MYKs8me2BWvjAkRTiyEnpG}Y4Ogrt6s+`a3uV3P9J%W1s8xVzW3v6 zDFFW;{Qrs8|2W3r|0@R>^MB#*e~#7l{~cRA3nG611OHzNrf~jisO!D_e;)QVIU8o) zzaD49RN#Nm|7Hid{15t{YEXm!2k-wl|JQ*3WdNT4D%XSOzorY|e{~_5|Ka?vVo-wr z%jbV=9nSwq0i6FS*SksQe@cWN6fd5ynca!`-;D6wMiTn@xk77`Q|DD16KihJE|D^!*pK?9Wf2NCz_#e)H z&JJ?5{kx17%XzYY8^1LvOqJnspxsUH))zJDW#ug6hQ0000000000 Q0C2_eA1;maY5;fu0KvP>q5uE@ literal 1086 zcmV-E1i||siwFP!000001MQmcZ`(u|$DOv?%Ba!yg7F5VQ+S8k)W6O4hO{c$B83T& ztwln}x;RdH!?nZNE>vX_0;KSYm;42Yw`~6if5u+(C-9u(EN&8-mbQ+!J|F3F=ex&u zXP@uBPki2Hy-z8R_aAxjMoZl|RBOtnbjc`yZJHuhZBxm|BP7*Ob^Tp<*m>xvt({#*gYi-*D{4awuw@vd@t%wo zc&bi33yP{*iYbxi74_!?qw(iq7@uZ0ulzOT!t?(;#?J1;TFknOzGh@h5#yhm|1;Nr z)3zqof75_0q7-ohuKD`kXnZ8CLVR12D$7y@QcDmjx%5(&CJ>ghv5Wj@**I+qq|1KClAFQS4?Lzv+ zU@7f=x01i6e?B=GpOvwFCGF8oNm>!p}ec2Y>K7Bh<&W?OboLq&@2YWin6yHASD)f3YN>{ue;8`tQ@W=SR+g7t;=nY3*qg28+G}*P9aPzY3+} z_y3xj>OYvH{ujXUEdo@B6SI1se3c%jc{?m}!ysls5+)Hzc$)^KOWT~e)%CH}qcGTD zoU@Es>XSpS2g^sy!XV$mHV7sdJFMj*OG4k(ZI zy(nL@2mR*64HF(gL7M;du(;}mA@49~*AsWGHvZ5i&3KkI z(dXe)&rL7pKB-n;HO%1x`e6XA_nb&``ZRX;g+)u-az1M{nh=fW8O$JE9Zttl4Uy{w zo)f`Uxn90z>VBKBh=~TBm>mquZarJ{Q-~~zi>?3qy{p>)vvL2g5YYce|Nlbs|2)R% z|Cf#;#((bnKg$^T|CTAf{}J#1(El%lV|nrqfP5z3UY_H6%m3wJ%bANU&-Ah7X(wBr z_Oj(^H(Q?av*paumS=m~^2Jk#Zj#Omz8B2>7re9kw~+%X`*?fY#p`vtzt{F6$1ABneKe#`!*X6n>|^P%E1Hp6x2% z+vAp>YdxU&eIQ7=+CMaKx8W9#Lh)=ocPM4VNZ@qH{^!iA^f3Lr=!Na%*1p`!=d0l8 zUeP|pd^xNURJ+&oH|D&EIM{fQynM$>)>uY;OqcQ~KO3qC?m5G1LCq?(+iidCj}{O^L&vV0mB-+s^^T9BR0<&d9hZ?+)wD45(xSyH*G9i z)plnzZ7^44-1%4?9c#y{h;m%k80zxmXn}{Ti2P#Y-{7Dg|1(kMkweSe)1plWy%QeD zuKTe@$pMu;%d!n@XyabxC;P_wO@{|d1AOqF+W%b8L3cBxF-&IRCs-8Q4#j|b+ymKUX9?Sr=FzDO7L8%{tlpNDkrG_?zy>M` zqx{h#+yIt$zakp8WpaJZk%=l-0A4h>!Nk9!xhQcLWmiDC&hgUkk5jcEOX24=}ex`vH#Hq1T zqF5uC`JW9-{)5-=>1&HzEa)lLj8)F4br^Md0jU`)*TQ=5K9Y>O9i^p?!C%m;o3R-7 zoP%2>e2B={iwr0rqU*`1`y!%mAF`WsjO*MUG0QLx)(y4s*Ld=zPTW^>!Ldq+*cMP) z4J(jQ<_0ar%wS+1hkKHuD5x}=G6+jbtcglMw07AA2~p(GM_uWPA;J3ZEul%i0+|$l z&_z3zF3FNmYu{T#%Y&G0=o6Y}C7cpMr1_|dp9~97*3ePFU-23-0GG zX?NoDTmmLo*SD(oUeRb5*wu}EwMn`LTCK{Urjfdv02?re;~%6ytJ3aN4DKJD%7Sas ztO@E2_Chu?U?G9+jB_YF0hyiCdNC!wHo5(r3@J&DnNBBDz9Wfm)x8Ul{^d|EojE%7xoioJj_(`2xGJu5AHCvKK8~bU9{L`(Puts*}uKD%zIc! zd9%#HCnT|haI>9grF8qVcDNdzUi+2vXR$7e^8Lsg^~+Trug>7NY}dvV&aG^{{J>U5 zej?R4amfQdtK`@uLA?eCX3MA6ki+I~=vPs#ArZX~XF0(rR)=G{c$RZi^Y9~GDcqk& zwH1?TErEK@L;(DFv@q0qGURpgK)dD81UxEBz}k-S_S zi0?OCY_XLKgS6`(#7Pz19tA@M=^BZ$?5`Wl6Ncwes_dAK@Jdi6hiMLa0YdL6T4LS@ zRa`l&Se6lyGGcWJxUct?@7UzluuFsRfT7&~DL>z~9UgAg zDjA_Jn<|cpAH5Cc^a1Za`A(~JDxqe40TPUgueI1+MrD zNp|%uA9LT-&47gWj1dBwO>2(9OIg=r8TOosT$CoYZc`fN`f#C;mAQ8Fc30EFlh+kt zoy{v9+Oz3prn&dVV*T9@iGE4%I>clATDD1_uXq~dj+#)#@`e-KrY3P3$0y9KH6czH zcR+UuBZ3Gl_1jY5`!yg|(|&^?3>=FB1dBaE{5HsdiyXZx%FBoON}2wV!%jR9S^Q3g ztCls`Q%hPbTd2i~nU!MZX%KAv$(dpu-cAMeev)|GfaWTG02dDK-f1U*CqOFjcEsJ8 zSp4z|5WR-Dp5+RBnjl-QLEG*6^uZUj*{I9V4Cf4Y+40K?H5$_fuz28Xfa2%?%MOVE z%!V!IwxV?r>IE@H;5!$6QyPd;*Q)TT2)HUkM(i4;kZI=>qk%00d5{=4{}UO%3G!5{kEh;C%`FVY`s A1ONa4 From fbaa8b5e29eebdc580b8806f421edbeaec338089 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 28 Oct 2021 22:31:44 -0500 Subject: [PATCH 02/19] update bigearthnet dataset length test --- tests/datasets/test_bigearthnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 89416828fa9..f30d7c7d182 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -71,7 +71,7 @@ def test_getitem(self, dataset: BigEarthNet) -> None: assert x["image"].shape == (12, 120, 120) def test_len(self, dataset: BigEarthNet) -> None: - assert len(dataset) == 2 + assert len(dataset) == 4 def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None: BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True) From 5f50ec2b09e71bcdc0942f8f9ea9e877f44892a9 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 28 Oct 2021 22:32:16 -0500 Subject: [PATCH 03/19] add MultiLabelClassificationTask --- torchgeo/trainers/__init__.py | 6 +- torchgeo/trainers/tasks.py | 117 +++++++++++++++++++++++++++++++++- 2 files changed, 121 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index 9fb4f1d9c41..59e970c5483 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -3,6 +3,7 @@ """TorchGeo trainers.""" +from .bigearthnet import BigEarthNetClassificationTask, BigEarthNetDataModule from .byol import BYOLTask from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask @@ -11,13 +12,16 @@ from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask from .so2sat import So2SatClassificationTask, So2SatDataModule -from .tasks import ClassificationTask +from .tasks import ClassificationTask, MultiLabelClassificationTask from .ucmerced import UCMercedClassificationTask, UCMercedDataModule __all__ = ( # Tasks "ClassificationTask", + "MultiLabelClassificationTask", # Trainers + "BigEarthNetClassificationTask", + "BigEarthNetDataModule", "BYOLTask", "ChesapeakeCVPRSegmentationTask", "ChesapeakeCVPRDataModule", diff --git a/torchgeo/trainers/tasks.py b/torchgeo/trainers/tasks.py index 1a1b6e46638..679ee9877dc 100644 --- a/torchgeo/trainers/tasks.py +++ b/torchgeo/trainers/tasks.py @@ -73,7 +73,10 @@ def config_model(self) -> None: w_new = torch.clone( # type: ignore[attr-defined] self.model.conv1.weight ).detach() - w_new[:, :3, :, :] = w_old + if in_channels > 3: + w_new[:, :3, :, :] = w_old + else: + w_new[:, :in_channels, :, :] = w_old[:, :in_channels, :, :] self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501 w_new ) @@ -264,3 +267,115 @@ def configure_optimizers(self) -> Dict[str, Any]: "monitor": "val_loss", }, } + + +class MultiLabelClassificationTask(ClassificationTask): + """Abstract base class for multi label image classification LightningModules.""" + + #: number of classes in dataset + num_classes: int = 43 + + def config_task(self) -> None: + """Configures the task based on kwargs parameters passed to the constructor.""" + self.config_model() + + if self.hparams["loss"] == "bce": + self.loss = nn.BCEWithLogitsLoss() # type: ignore[attr-defined] + else: + raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LightningModule with a model and loss function. + + Keyword Args: + classification_model: Name of the classification model use + loss: Name of the loss function + weights: Either "random", "imagenet_only", "imagenet_and_random", or + "random_rgb" + """ + super().__init__(**kwargs) + self.save_hyperparameters() # creates `self.hparams` from kwargs + + self.config_task() + + self.train_metrics = MetricCollection( + { + "OverallAccuracy": Accuracy( + num_classes=self.num_classes, average="micro", multiclass=False + ), + "AverageAccuracy": Accuracy( + num_classes=self.num_classes, average="macro", multiclass=False + ), + "F1Score": FBeta( + num_classes=self.num_classes, + beta=1.0, + average="micro", + multiclass=False, + ), + }, + prefix="train_", + ) + self.val_metrics = self.train_metrics.clone(prefix="val_") + self.test_metrics = self.train_metrics.clone(prefix="test_") + + def training_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> Tensor: + """Training step. + + Args: + batch: Current batch + batch_idx: Index of current batch + + Returns: + training loss + """ + x = batch["image"] + y = batch["label"] + y_hat = self.forward(x) + + loss = self.loss(y_hat, y.to(torch.float)) + + # by default, the train step logs every `log_every_n_steps` steps where + # `log_every_n_steps` is a parameter to the `Trainer` object + self.log("train_loss", loss, on_step=True, on_epoch=False) + self.train_metrics(y_hat, y) + + return cast(Tensor, loss) + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["label"] + y_hat = self.forward(x) + + loss = self.loss(y_hat, y.to(torch.float)) + + self.log("val_loss", loss, on_step=False, on_epoch=True) + self.val_metrics(y_hat, y) + + def test_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Test step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["label"] + y_hat = self.forward(x) + + loss = self.loss(y_hat, y.to(torch.float)) + + # by default, the test and validation steps only log per *epoch* + self.log("test_loss", loss, on_step=False, on_epoch=True) + self.test_metrics(y_hat, y) From aa3891b33481e98466b058bfed12c16c844942b7 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 28 Oct 2021 22:32:40 -0500 Subject: [PATCH 04/19] add BigEarthNet trainer and datamodule --- torchgeo/trainers/bigearthnet.py | 149 +++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 torchgeo/trainers/bigearthnet.py diff --git a/torchgeo/trainers/bigearthnet.py b/torchgeo/trainers/bigearthnet.py new file mode 100644 index 00000000000..21ed896dd8b --- /dev/null +++ b/torchgeo/trainers/bigearthnet.py @@ -0,0 +1,149 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""BigEarthNet trainer.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose, Normalize + +from ..datasets import BigEarthNet +from ..datasets.utils import dataset_split +from .tasks import MultiLabelClassificationTask + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class BigEarthNetClassificationTask(MultiLabelClassificationTask): + """LightningModule for training models on the BigEarthNet Dataset.""" + + num_classes = 43 + + +class BigEarthNetDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the BigEarthNet dataset. + + Uses the train/val/test splits from the dataset. + """ + + # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) + band_mins = torch.tensor( # type: ignore[attr-defined] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ) + band_maxs = torch.tensor( # type: ignore[attr-defined] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ) + band_means = torch.tensor( # type: ignore[attr-defined] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ) + band_stds = torch.tensor( # type: ignore[attr-defined] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ) + + def __init__( + self, + root_dir: str, + bands: str = "all", + batch_size: int = 64, + num_workers: int = 4, + unsupervised_mode: bool = False, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for BigEarthNet based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the BigEarthNet Dataset classes + bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all} + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + unsupervised_mode: Makes the train dataloader return imagery from the train, + val, and test sets + val_split_pct: What percentage of the dataset to use as a validation set + test_split_pct: What percentage of the dataset to use as a test set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.bands = bands + self.batch_size = batch_size + self.num_workers = num_workers + self.unsupervised_mode = unsupervised_mode + + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + + self.norm = Normalize(self.band_means, self.band_stds) + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset.""" + sample["image"] = sample["image"].float() + # sample["image"] /= 255.0 + # sample["image"] = self.norm(sample["image"]) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + BigEarthNet(self.root_dir, bands=self.bands, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + """ + transforms = Compose([self.preprocess]) + + if not self.unsupervised_mode: + + dataset = BigEarthNet( + self.root_dir, bands=self.bands, transforms=transforms + ) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + ) + else: + self.train_dataset = BigEarthNet( + self.root_dir, bands=self.bands, transforms=transforms + ) + self.val_dataset, self.test_dataset = None, None # type: ignore[assignment] + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation.""" + if self.unsupervised_mode or self.val_split_pct == 0: + return self.train_dataloader() + else: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing.""" + if self.unsupervised_mode or self.test_split_pct == 0: + return self.train_dataloader() + else: + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) From 48d12862fd43218e43b21853a9b518bbc03c2e3d Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 28 Oct 2021 22:33:31 -0500 Subject: [PATCH 05/19] add bigearthnet and multilabelclassificationtask tests --- conf/bigearthnet.yaml | 18 +++++ conf/task_defaults/bigearthnet.yaml | 13 +++ tests/trainers/test_bigearthnet.py | 47 +++++++++++ tests/trainers/test_tasks.py | 121 +++++++++++++++++++++++++--- 4 files changed, 187 insertions(+), 12 deletions(-) create mode 100644 conf/bigearthnet.yaml create mode 100644 conf/task_defaults/bigearthnet.yaml create mode 100644 tests/trainers/test_bigearthnet.py diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml new file mode 100644 index 00000000000..7f8b52f9da2 --- /dev/null +++ b/conf/bigearthnet.yaml @@ -0,0 +1,18 @@ +trainer: + gpus: 1 # single GPU training + min_epochs: 10 + max_epochs: 40 + benchmark: True + +experiment: + task: "bigearthnet" + module: + loss: "bce" + classification_model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 14 + datamodule: + batch_size: 128 + num_workers: 6 + bands: "all" diff --git a/conf/task_defaults/bigearthnet.yaml b/conf/task_defaults/bigearthnet.yaml new file mode 100644 index 00000000000..723d4e7d954 --- /dev/null +++ b/conf/task_defaults/bigearthnet.yaml @@ -0,0 +1,13 @@ +experiment: + task: "bigearthnet" + module: + loss: "bce" + classification_model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: "random" + in_channels: 14 + datamodule: + batch_size: 128 + num_workers: 6 + bands: "all" diff --git a/tests/trainers/test_bigearthnet.py b/tests/trainers/test_bigearthnet.py new file mode 100644 index 00000000000..cc417f36548 --- /dev/null +++ b/tests/trainers/test_bigearthnet.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from typing import Tuple, cast + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.trainers import BigEarthNetDataModule + + +@pytest.fixture(scope="module", params=[("s1", 2), ("s2", 12), ("all", 14)]) +def bands(request: SubRequest) -> Tuple[str, int]: + return cast(Tuple[str, int], request.param) + + +@pytest.fixture(scope="module", params=[True, False]) +def datamodule(bands: Tuple[str, int], request: SubRequest) -> BigEarthNetDataModule: + band_set = bands[0] + unsupervised_mode = request.param + root = os.path.join("tests", "data", "bigearthnet") + batch_size = 1 + num_workers = 0 + dm = BigEarthNetDataModule( + root, + band_set, + batch_size, + num_workers, + unsupervised_mode, + val_split_pct=0.3, + test_split_pct=0.3, + ) + dm.prepare_data() + dm.setup() + return dm + + +class TestBigEarthNetDataModule: + def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/trainers/test_tasks.py b/tests/trainers/test_tasks.py index 1f913d161b6..1dae029c5e5 100644 --- a/tests/trainers/test_tasks.py +++ b/tests/trainers/test_tasks.py @@ -9,19 +9,31 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf -from torchgeo.trainers import ClassificationTask, So2SatDataModule +from torchgeo.trainers import ( + BigEarthNetDataModule, + ClassificationTask, + MultiLabelClassificationTask, + So2SatDataModule, +) from .test_utils import mocked_log @pytest.fixture(scope="module", params=[("rgb", 3), ("s2", 10)]) -def bands(request: SubRequest) -> Tuple[str, int]: +def bands_so2sat(request: SubRequest) -> Tuple[str, int]: + return cast(Tuple[str, int], request.param) + + +@pytest.fixture(scope="module", params=[("s1", 2), ("s2", 12), ("all", 14)]) +def bands_bigearthnet(request: SubRequest) -> Tuple[str, int]: return cast(Tuple[str, int], request.param) @pytest.fixture(scope="module", params=[True, False]) -def datamodule(bands: Tuple[str, int], request: SubRequest) -> So2SatDataModule: - band_set = bands[0] +def datamodule_classification( + bands_so2sat: Tuple[str, int], request: SubRequest +) -> So2SatDataModule: + band_set = bands_so2sat[0] unsupervised_mode = request.param root = os.path.join("tests", "data", "so2sat") batch_size = 2 @@ -32,15 +44,40 @@ def datamodule(bands: Tuple[str, int], request: SubRequest) -> So2SatDataModule: return dm +@pytest.fixture(scope="module", params=[True, False]) +def datamodule_multilabel( + bands_bigearthnet: Tuple[str, int], request: SubRequest +) -> BigEarthNetDataModule: + band_set = bands_bigearthnet[0] + unsupervised_mode = request.param + root = os.path.join("tests", "data", "bigearthnet") + batch_size = 1 + num_workers = 0 + dm = BigEarthNetDataModule( + root, + band_set, + batch_size, + num_workers, + unsupervised_mode, + val_split_pct=0.3, + test_split_pct=0.3, + ) + dm.prepare_data() + dm.setup() + return dm + + class TestClassificationTask: @pytest.fixture( params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) ) - def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: + def config( + self, request: SubRequest, bands_so2sat: Tuple[str, int] + ) -> Dict[str, Any]: task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) task_args = OmegaConf.to_object(task_conf.experiment.module) task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands[1] + task_args["in_channels"] = bands_so2sat[1] loss, weights = request.param task_args["loss"] = loss task_args["weights"] = weights @@ -60,21 +97,23 @@ def test_configure_optimizers(self, task: ClassificationTask) -> None: assert "lr_scheduler" in out def test_training( - self, datamodule: So2SatDataModule, task: ClassificationTask + self, datamodule_classification: So2SatDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule.train_dataloader())) + batch = next(iter(datamodule_classification.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( - self, datamodule: So2SatDataModule, task: ClassificationTask + self, datamodule_classification: So2SatDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule.val_dataloader())) + batch = next(iter(datamodule_classification.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) - def test_test(self, datamodule: So2SatDataModule, task: ClassificationTask) -> None: - batch = next(iter(datamodule.test_dataloader())) + def test_test( + self, datamodule_classification: So2SatDataModule, task: ClassificationTask + ) -> None: + batch = next(iter(datamodule_classification.test_dataloader())) task.test_step(batch, 0) task.test_epoch_end(0) @@ -110,3 +149,61 @@ def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> No error_message = "Trying to load resnet18 weights into a resnet50" with pytest.raises(ValueError, match=error_message): ClassificationTask(**config) + + +class TestMultiLabelClassificationTask: + @pytest.fixture(params=zip(["bce", "bce"], ["imagenet", "random"])) + def config( + self, request: SubRequest, bands_bigearthnet: Tuple[str, int] + ) -> Dict[str, Any]: + task_conf = OmegaConf.load( + os.path.join("conf", "task_defaults", "bigearthnet.yaml") + ) + task_args = OmegaConf.to_object(task_conf.experiment.module) + task_args = cast(Dict[str, Any], task_args) + task_args["in_channels"] = bands_bigearthnet[1] + loss, weights = request.param + task_args["loss"] = loss + task_args["weights"] = weights + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> MultiLabelClassificationTask: + task = MultiLabelClassificationTask(**config) + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_training( + self, + datamodule_multilabel: BigEarthNetDataModule, + task: MultiLabelClassificationTask, + ) -> None: + batch = next(iter(datamodule_multilabel.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, + datamodule_multilabel: BigEarthNetDataModule, + task: MultiLabelClassificationTask, + ) -> None: + batch = next(iter(datamodule_multilabel.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) + + def test_test( + self, + datamodule_multilabel: BigEarthNetDataModule, + task: MultiLabelClassificationTask, + ) -> None: + batch = next(iter(datamodule_multilabel.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) + + def test_invalid_loss(self, config: Dict[str, Any]) -> None: + config["loss"] = "invalid_loss" + error_message = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=error_message): + MultiLabelClassificationTask(**config) From 48a16d52727d4c9c200d6e80acd5485229acee91 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 28 Oct 2021 22:47:05 -0500 Subject: [PATCH 06/19] mypy and format --- torchgeo/trainers/bigearthnet.py | 2 +- torchgeo/trainers/tasks.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torchgeo/trainers/bigearthnet.py b/torchgeo/trainers/bigearthnet.py index 21ed896dd8b..9a4258246bb 100644 --- a/torchgeo/trainers/bigearthnet.py +++ b/torchgeo/trainers/bigearthnet.py @@ -110,7 +110,7 @@ def setup(self, stage: Optional[str] = None) -> None: dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) else: - self.train_dataset = BigEarthNet( + self.train_dataset = BigEarthNet( # type: ignore[assignment] self.root_dir, bands=self.bands, transforms=transforms ) self.val_dataset, self.test_dataset = None, None # type: ignore[assignment] diff --git a/torchgeo/trainers/tasks.py b/torchgeo/trainers/tasks.py index 679ee9877dc..b7a6e7b73fb 100644 --- a/torchgeo/trainers/tasks.py +++ b/torchgeo/trainers/tasks.py @@ -55,7 +55,7 @@ def config_model(self) -> None: # Update first layer if in_channels != 3: - w_old = None + w_old = torch.empty(0) # type: ignore[attr-defined] if pretrained: w_old = torch.clone( # type: ignore[attr-defined] self.model.conv1.weight @@ -76,7 +76,8 @@ def config_model(self) -> None: if in_channels > 3: w_new[:, :3, :, :] = w_old else: - w_new[:, :in_channels, :, :] = w_old[:, :in_channels, :, :] + w_old = w_old[:, :in_channels, :, :] + w_new[:, :in_channels, :, :] = w_old self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501 w_new ) @@ -334,7 +335,7 @@ def training_step( # type: ignore[override] y = batch["label"] y_hat = self.forward(x) - loss = self.loss(y_hat, y.to(torch.float)) + loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] # by default, the train step logs every `log_every_n_steps` steps where # `log_every_n_steps` is a parameter to the `Trainer` object @@ -356,7 +357,7 @@ def validation_step( # type: ignore[override] y = batch["label"] y_hat = self.forward(x) - loss = self.loss(y_hat, y.to(torch.float)) + loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat, y) @@ -374,7 +375,7 @@ def test_step( # type: ignore[override] y = batch["label"] y_hat = self.forward(x) - loss = self.loss(y_hat, y.to(torch.float)) + loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] # by default, the test and validation steps only log per *epoch* self.log("test_loss", loss, on_step=False, on_epoch=True) From 21418563cba99f40be0ce2ad99f8a174a0b3aa9a Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 28 Oct 2021 23:16:31 -0500 Subject: [PATCH 07/19] add estimated band min/max values for normalization --- torchgeo/trainers/bigearthnet.py | 68 ++++++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 12 deletions(-) diff --git a/torchgeo/trainers/bigearthnet.py b/torchgeo/trainers/bigearthnet.py index 9a4258246bb..c60bb771003 100644 --- a/torchgeo/trainers/bigearthnet.py +++ b/torchgeo/trainers/bigearthnet.py @@ -8,7 +8,7 @@ import pytorch_lightning as pl import torch from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize +from torchvision.transforms import Compose from ..datasets import BigEarthNet from ..datasets.utils import dataset_split @@ -32,17 +32,51 @@ class BigEarthNetDataModule(pl.LightningDataModule): """ # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) - band_mins = torch.tensor( # type: ignore[attr-defined] - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # min/max band statistics computed on 83k random samples + band_mins_raw = torch.tensor( # type: ignore[attr-defined] + [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] ) - band_maxs = torch.tensor( # type: ignore[attr-defined] - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + band_maxs_raw = torch.tensor( # type: ignore[attr-defined] + [ + 31.0, + 35.0, + 18556.0, + 20528.0, + 18903.0, + 17846.0, + 16593.0, + 16512.0, + 16394.0, + 16575.0, + 16124.0, + 16097.0, + 15336.0, + 15203.0, + ] ) - band_means = torch.tensor( # type: ignore[attr-defined] - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + # min/max band statistics computed by percentile clipping the + # above to samples to [2, 98] + band_mins = torch.tensor( # type: ignore[attr-defined] + [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ) - band_stds = torch.tensor( # type: ignore[attr-defined] - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + band_maxs = torch.tensor( # type: ignore[attr-defined] + [ + 6.0, + 16.0, + 9859.0, + 12874.1, + 13160.1, + 14437.3, + 12479.0, + 12564.3, + 12282.2, + 15605.0, + 12186.0, + 9453.1, + 5896.0, + 5533.0, + ] ) def __init__( @@ -78,13 +112,23 @@ def __init__( self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.norm = Normalize(self.band_means, self.band_stds) + if bands == "all": + self.mins = self.band_mins[:, None, None] + self.maxs = self.band_maxs[:, None, None] + elif bands == "s1": + self.mins = self.band_mins[:2, None, None] + self.maxs = self.band_maxs[:2, None, None] + else: + self.mins = self.band_mins[2:, None, None] + self.maxs = self.band_maxs[2:, None, None] def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset.""" sample["image"] = sample["image"].float() - # sample["image"] /= 255.0 - # sample["image"] = self.norm(sample["image"]) + sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins) + sample["image"] = torch.clip( # type: ignore[attr-defined] + sample["image"], min=0.0, max=1.0 + ) return sample def prepare_data(self) -> None: From 407b4aa21b7e66856199e5322424e07f5be1a396 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 28 Oct 2021 23:27:32 -0500 Subject: [PATCH 08/19] softmax outputs to correctly compute metrics --- torchgeo/trainers/tasks.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchgeo/trainers/tasks.py b/torchgeo/trainers/tasks.py index b7a6e7b73fb..eb36da80bf3 100644 --- a/torchgeo/trainers/tasks.py +++ b/torchgeo/trainers/tasks.py @@ -334,13 +334,14 @@ def training_step( # type: ignore[override] x = batch["image"] y = batch["label"] y_hat = self.forward(x) + y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined] loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] # by default, the train step logs every `log_every_n_steps` steps where # `log_every_n_steps` is a parameter to the `Trainer` object self.log("train_loss", loss, on_step=True, on_epoch=False) - self.train_metrics(y_hat, y) + self.train_metrics(y_hat_hard, y) return cast(Tensor, loss) @@ -356,11 +357,12 @@ def validation_step( # type: ignore[override] x = batch["image"] y = batch["label"] y_hat = self.forward(x) + y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined] loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] self.log("val_loss", loss, on_step=False, on_epoch=True) - self.val_metrics(y_hat, y) + self.val_metrics(y_hat_hard, y) def test_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int @@ -374,9 +376,10 @@ def test_step( # type: ignore[override] x = batch["image"] y = batch["label"] y_hat = self.forward(x) + y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined] loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] # by default, the test and validation steps only log per *epoch* self.log("test_loss", loss, on_step=False, on_epoch=True) - self.test_metrics(y_hat, y) + self.test_metrics(y_hat_hard, y) From 928d13ba6117520cbd62e05e15816284eff73881 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 28 Oct 2021 23:43:05 -0500 Subject: [PATCH 09/19] update min/max stats for 100k samples --- torchgeo/trainers/bigearthnet.py | 34 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/torchgeo/trainers/bigearthnet.py b/torchgeo/trainers/bigearthnet.py index c60bb771003..a0abaadd2a9 100644 --- a/torchgeo/trainers/bigearthnet.py +++ b/torchgeo/trainers/bigearthnet.py @@ -32,7 +32,7 @@ class BigEarthNetDataModule(pl.LightningDataModule): """ # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) - # min/max band statistics computed on 83k random samples + # min/max band statistics computed on 100k random samples band_mins_raw = torch.tensor( # type: ignore[attr-defined] [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] ) @@ -42,13 +42,13 @@ class BigEarthNetDataModule(pl.LightningDataModule): 35.0, 18556.0, 20528.0, - 18903.0, - 17846.0, - 16593.0, + 18976.0, + 17874.0, + 16611.0, 16512.0, 16394.0, - 16575.0, - 16124.0, + 16672.0, + 16141.0, 16097.0, 15336.0, 15203.0, @@ -65,17 +65,17 @@ class BigEarthNetDataModule(pl.LightningDataModule): 6.0, 16.0, 9859.0, - 12874.1, - 13160.1, - 14437.3, - 12479.0, - 12564.3, - 12282.2, - 15605.0, - 12186.0, - 9453.1, - 5896.0, - 5533.0, + 12872.0, + 13163.0, + 14445.0, + 12477.0, + 12563.0, + 12289.0, + 15596.0, + 12183.0, + 9458.0, + 5897.0, + 5544.0, ] ) From 41467c9972bd761def74b56e48b8af8b0fb18516 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Sat, 30 Oct 2021 15:11:37 -0500 Subject: [PATCH 10/19] organize imports in torchgeo.trainers.__init__.py --- torchgeo/trainers/__init__.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index 59e970c5483..01e9034bd7a 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -16,28 +16,29 @@ from .ucmerced import UCMercedClassificationTask, UCMercedDataModule __all__ = ( - # Tasks + # Base Classes "ClassificationTask", "MultiLabelClassificationTask", - # Trainers + # Tasks "BigEarthNetClassificationTask", - "BigEarthNetDataModule", "BYOLTask", "ChesapeakeCVPRSegmentationTask", - "ChesapeakeCVPRDataModule", - "CycloneDataModule", "CycloneSimpleRegressionTask", - "LandcoverAIDataModule", "LandcoverAISegmentationTask", - "NAIPChesapeakeDataModule", "NAIPChesapeakeSegmentationTask", "RESISC45ClassificationTask", - "RESISC45DataModule", - "SEN12MSDataModule", "SEN12MSSegmentationTask", - "So2SatDataModule", "So2SatClassificationTask", "UCMercedClassificationTask", + # DataModules + "BigEarthNetDataModule", + "ChesapeakeCVPRDataModule", + "CycloneDataModule", + "LandcoverAIDataModule", + "NAIPChesapeakeDataModule", + "RESISC45DataModule", + "SEN12MSDataModule", + "So2SatDataModule", "UCMercedDataModule", ) From 58a09bd89d2ce6eba8bde46f83e8d4c53c34374c Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Sat, 30 Oct 2021 15:21:47 -0500 Subject: [PATCH 11/19] clean up fixtures in test_tasks.py --- tests/trainers/test_tasks.py | 132 ++++++++++++++++------------------- 1 file changed, 62 insertions(+), 70 deletions(-) diff --git a/tests/trainers/test_tasks.py b/tests/trainers/test_tasks.py index 1dae029c5e5..340afe5ca0d 100644 --- a/tests/trainers/test_tasks.py +++ b/tests/trainers/test_tasks.py @@ -19,65 +19,35 @@ from .test_utils import mocked_log -@pytest.fixture(scope="module", params=[("rgb", 3), ("s2", 10)]) -def bands_so2sat(request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) - - -@pytest.fixture(scope="module", params=[("s1", 2), ("s2", 12), ("all", 14)]) -def bands_bigearthnet(request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) - - -@pytest.fixture(scope="module", params=[True, False]) -def datamodule_classification( - bands_so2sat: Tuple[str, int], request: SubRequest -) -> So2SatDataModule: - band_set = bands_so2sat[0] - unsupervised_mode = request.param - root = os.path.join("tests", "data", "so2sat") - batch_size = 2 - num_workers = 0 - dm = So2SatDataModule(root, batch_size, num_workers, band_set, unsupervised_mode) - dm.prepare_data() - dm.setup() - return dm - - -@pytest.fixture(scope="module", params=[True, False]) -def datamodule_multilabel( - bands_bigearthnet: Tuple[str, int], request: SubRequest -) -> BigEarthNetDataModule: - band_set = bands_bigearthnet[0] - unsupervised_mode = request.param - root = os.path.join("tests", "data", "bigearthnet") - batch_size = 1 - num_workers = 0 - dm = BigEarthNetDataModule( - root, - band_set, - batch_size, - num_workers, - unsupervised_mode, - val_split_pct=0.3, - test_split_pct=0.3, - ) - dm.prepare_data() - dm.setup() - return dm - - class TestClassificationTask: + @pytest.fixture(params=[("rgb", 3), ("s2", 10)]) + def bands(self, request: SubRequest) -> Tuple[str, int]: + return cast(Tuple[str, int], request.param) + + @pytest.fixture(params=[True, False]) + def datamodule( + self, bands: Tuple[str, int], request: SubRequest + ) -> So2SatDataModule: + band_set = bands[0] + unsupervised_mode = request.param + root = os.path.join("tests", "data", "so2sat") + batch_size = 2 + num_workers = 0 + dm = So2SatDataModule( + root, batch_size, num_workers, band_set, unsupervised_mode + ) + dm.prepare_data() + dm.setup() + return dm + @pytest.fixture( params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) ) - def config( - self, request: SubRequest, bands_so2sat: Tuple[str, int] - ) -> Dict[str, Any]: + def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) task_args = OmegaConf.to_object(task_conf.experiment.module) task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands_so2sat[1] + task_args["in_channels"] = bands[1] loss, weights = request.param task_args["loss"] = loss task_args["weights"] = weights @@ -97,23 +67,21 @@ def test_configure_optimizers(self, task: ClassificationTask) -> None: assert "lr_scheduler" in out def test_training( - self, datamodule_classification: So2SatDataModule, task: ClassificationTask + self, datamodule: So2SatDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule_classification.train_dataloader())) + batch = next(iter(datamodule.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( - self, datamodule_classification: So2SatDataModule, task: ClassificationTask + self, datamodule: So2SatDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule_classification.val_dataloader())) + batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) - def test_test( - self, datamodule_classification: So2SatDataModule, task: ClassificationTask - ) -> None: - batch = next(iter(datamodule_classification.test_dataloader())) + def test_test(self, datamodule: So2SatDataModule, task: ClassificationTask) -> None: + batch = next(iter(datamodule.test_dataloader())) task.test_step(batch, 0) task.test_epoch_end(0) @@ -152,16 +120,40 @@ def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> No class TestMultiLabelClassificationTask: + @pytest.fixture(params=[("s1", 2), ("s2", 12), ("all", 14)]) + def bands(self, request: SubRequest) -> Tuple[str, int]: + return cast(Tuple[str, int], request.param) + + @pytest.fixture(params=[True, False]) + def datamodule( + self, bands: Tuple[str, int], request: SubRequest + ) -> BigEarthNetDataModule: + band_set = bands[0] + unsupervised_mode = request.param + root = os.path.join("tests", "data", "bigearthnet") + batch_size = 1 + num_workers = 0 + dm = BigEarthNetDataModule( + root, + band_set, + batch_size, + num_workers, + unsupervised_mode, + val_split_pct=0.3, + test_split_pct=0.3, + ) + dm.prepare_data() + dm.setup() + return dm + @pytest.fixture(params=zip(["bce", "bce"], ["imagenet", "random"])) - def config( - self, request: SubRequest, bands_bigearthnet: Tuple[str, int] - ) -> Dict[str, Any]: + def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: task_conf = OmegaConf.load( os.path.join("conf", "task_defaults", "bigearthnet.yaml") ) task_args = OmegaConf.to_object(task_conf.experiment.module) task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands_bigearthnet[1] + task_args["in_channels"] = bands[1] loss, weights = request.param task_args["loss"] = loss task_args["weights"] = weights @@ -177,28 +169,28 @@ def task( def test_training( self, - datamodule_multilabel: BigEarthNetDataModule, + datamodule: BigEarthNetDataModule, task: MultiLabelClassificationTask, ) -> None: - batch = next(iter(datamodule_multilabel.train_dataloader())) + batch = next(iter(datamodule.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( self, - datamodule_multilabel: BigEarthNetDataModule, + datamodule: BigEarthNetDataModule, task: MultiLabelClassificationTask, ) -> None: - batch = next(iter(datamodule_multilabel.val_dataloader())) + batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) def test_test( self, - datamodule_multilabel: BigEarthNetDataModule, + datamodule: BigEarthNetDataModule, task: MultiLabelClassificationTask, ) -> None: - batch = next(iter(datamodule_multilabel.test_dataloader())) + batch = next(iter(datamodule.test_dataloader())) task.test_step(batch, 0) task.test_epoch_end(0) From 206495582842272a9f1d8760b090e09b1a252d62 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Sat, 30 Oct 2021 15:36:14 -0500 Subject: [PATCH 12/19] added bigearthnet to train.py --- train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/train.py b/train.py index bbd77c51b85..c3e6763b7d6 100755 --- a/train.py +++ b/train.py @@ -15,6 +15,8 @@ from torchgeo.trainers import ( BYOLTask, + BigEarthNetClassificationTask, + BigEarthNetDataModule, ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask, CycloneDataModule, @@ -36,6 +38,7 @@ TASK_TO_MODULES_MAPPING: Dict[ str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]] ] = { + "bigearthnet": (BigEarthNetClassificationTask, BigEarthNetDataModule), "byol": (BYOLTask, ChesapeakeCVPRDataModule), "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), "cyclone": (CycloneSimpleRegressionTask, CycloneDataModule), From a46d3e7f48231fd78cac4c534cbedc4de4eb4b66 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Sat, 30 Oct 2021 15:36:33 -0500 Subject: [PATCH 13/19] format --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index c3e6763b7d6..0386620a614 100755 --- a/train.py +++ b/train.py @@ -14,9 +14,9 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from torchgeo.trainers import ( - BYOLTask, BigEarthNetClassificationTask, BigEarthNetDataModule, + BYOLTask, ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask, CycloneDataModule, From d6e548b72bfb6ca7dcb0e8190f10ab9da6f63699 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Sun, 31 Oct 2021 21:05:47 -0500 Subject: [PATCH 14/19] move fixtures into class methods --- tests/trainers/test_bigearthnet.py | 52 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/trainers/test_bigearthnet.py b/tests/trainers/test_bigearthnet.py index cc417f36548..b3bfcd0d316 100644 --- a/tests/trainers/test_bigearthnet.py +++ b/tests/trainers/test_bigearthnet.py @@ -10,33 +10,33 @@ from torchgeo.trainers import BigEarthNetDataModule -@pytest.fixture(scope="module", params=[("s1", 2), ("s2", 12), ("all", 14)]) -def bands(request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) - - -@pytest.fixture(scope="module", params=[True, False]) -def datamodule(bands: Tuple[str, int], request: SubRequest) -> BigEarthNetDataModule: - band_set = bands[0] - unsupervised_mode = request.param - root = os.path.join("tests", "data", "bigearthnet") - batch_size = 1 - num_workers = 0 - dm = BigEarthNetDataModule( - root, - band_set, - batch_size, - num_workers, - unsupervised_mode, - val_split_pct=0.3, - test_split_pct=0.3, - ) - dm.prepare_data() - dm.setup() - return dm - - class TestBigEarthNetDataModule: + @pytest.fixture(params=[("s1", 2), ("s2", 12), ("all", 14)]) + def bands(self, request: SubRequest) -> Tuple[str, int]: + return cast(Tuple[str, int], request.param) + + @pytest.fixture(params=[True, False]) + def datamodule( + self, bands: Tuple[str, int], request: SubRequest + ) -> BigEarthNetDataModule: + band_set = bands[0] + unsupervised_mode = request.param + root = os.path.join("tests", "data", "bigearthnet") + batch_size = 1 + num_workers = 0 + dm = BigEarthNetDataModule( + root, + band_set, + batch_size, + num_workers, + unsupervised_mode, + val_split_pct=0.3, + test_split_pct=0.3, + ) + dm.prepare_data() + dm.setup() + return dm + def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None: next(iter(datamodule.train_dataloader())) From 05f3230394d06294b0c75e04f9fb30a94ea564ad Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Sun, 31 Oct 2021 21:10:45 -0500 Subject: [PATCH 15/19] consolidate bigearthnet fixtures --- tests/trainers/test_bigearthnet.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/trainers/test_bigearthnet.py b/tests/trainers/test_bigearthnet.py index b3bfcd0d316..fb1f99585bd 100644 --- a/tests/trainers/test_bigearthnet.py +++ b/tests/trainers/test_bigearthnet.py @@ -11,22 +11,24 @@ class TestBigEarthNetDataModule: - @pytest.fixture(params=[("s1", 2), ("s2", 12), ("all", 14)]) - def bands(self, request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) - - @pytest.fixture(params=[True, False]) - def datamodule( - self, bands: Tuple[str, int], request: SubRequest - ) -> BigEarthNetDataModule: - band_set = bands[0] - unsupervised_mode = request.param + @pytest.fixture( + params=[ + ("s1", True), + ("s2", True), + ("all", True), + ("s1", False), + ("s2", False), + ("all", False), + ] + ) + def datamodule(self, request: SubRequest) -> BigEarthNetDataModule: + bands, unsupervised_mode = request.param root = os.path.join("tests", "data", "bigearthnet") batch_size = 1 num_workers = 0 dm = BigEarthNetDataModule( root, - band_set, + bands, batch_size, num_workers, unsupervised_mode, From 8e2cd4f80873f67a69a270d3faae0c53a6853e85 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 1 Nov 2021 14:12:32 -0500 Subject: [PATCH 16/19] refactor tasks tests --- tests/trainers/test_bigearthnet.py | 12 +- tests/trainers/test_tasks.py | 198 ++++++++++++++++++----------- torchgeo/trainers/__init__.py | 5 +- 3 files changed, 125 insertions(+), 90 deletions(-) diff --git a/tests/trainers/test_bigearthnet.py b/tests/trainers/test_bigearthnet.py index fb1f99585bd..add7d8861b2 100644 --- a/tests/trainers/test_bigearthnet.py +++ b/tests/trainers/test_bigearthnet.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -from typing import Tuple, cast import pytest from _pytest.fixtures import SubRequest @@ -11,16 +10,7 @@ class TestBigEarthNetDataModule: - @pytest.fixture( - params=[ - ("s1", True), - ("s2", True), - ("all", True), - ("s1", False), - ("s2", False), - ("all", False), - ] - ) + @pytest.fixture(scope="class", params=zip(["s1", "s2", "all"], [True, True, False])) def datamodule(self, request: SubRequest) -> BigEarthNetDataModule: bands, unsupervised_mode = request.param root = os.path.join("tests", "data", "bigearthnet") diff --git a/tests/trainers/test_tasks.py b/tests/trainers/test_tasks.py index 340afe5ca0d..9db4156962c 100644 --- a/tests/trainers/test_tasks.py +++ b/tests/trainers/test_tasks.py @@ -2,52 +2,109 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, Tuple, cast +from typing import Any, Dict, Generator, Optional, cast import pytest +import pytorch_lightning as pl +import torch +import torch.nn.functional as F from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf +from torch import Tensor +from torch.utils.data import DataLoader, Dataset, TensorDataset -from torchgeo.trainers import ( - BigEarthNetDataModule, - ClassificationTask, - MultiLabelClassificationTask, - So2SatDataModule, -) +from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask from .test_utils import mocked_log +class DummyDataset(Dataset): # type: ignore[type-arg] + def __init__(self, num_channels: int, num_classes: int, multilabel: bool) -> None: + x = torch.randn(10, num_channels, 128, 128) # (b, c, h, w) + y = torch.randint( # type: ignore[attr-defined] + 0, num_classes, size=(10,) + ) # (b,) + + if multilabel: + y = F.one_hot(y, num_classes=num_classes) # (b, classes) + + self.dataset = TensorDataset(x, y) + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Dict[str, Tensor]: + x, y = self.dataset[idx] + sample = {"image": x, "label": y} + return sample + + +class DummyDataModule(pl.LightningDataModule): + def __init__( + self, + num_channels: int, + num_classes: int, + multilabel: bool, + batch_size: int = 1, + num_workers: int = 0, + ) -> None: + super().__init__() # type: ignore[no-untyped-call] + self.num_channels = num_channels + self.num_classes = num_classes + self.multilabel = multilabel + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: Optional[str] = None) -> None: + self.dataset = DummyDataset( + num_channels=self.num_channels, + num_classes=self.num_classes, + multilabel=self.multilabel, + ) + + def train_dataloader(self): + return DataLoader( + self.dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def val_dataloader(self): + return DataLoader( + self.dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def test_dataloader(self): + return DataLoader( + self.dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + class TestClassificationTask: - @pytest.fixture(params=[("rgb", 3), ("s2", 10)]) - def bands(self, request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) - - @pytest.fixture(params=[True, False]) - def datamodule( - self, bands: Tuple[str, int], request: SubRequest - ) -> So2SatDataModule: - band_set = bands[0] - unsupervised_mode = request.param - root = os.path.join("tests", "data", "so2sat") - batch_size = 2 - num_workers = 0 - dm = So2SatDataModule( - root, batch_size, num_workers, band_set, unsupervised_mode + @pytest.fixture(scope="class", params=[2, 3, 5]) + def datamodule(self, request: SubRequest) -> DummyDataModule: + dm = DummyDataModule( + num_channels=request.param, + num_classes=45, + multilabel=False, + batch_size=2, + num_workers=0, ) dm.prepare_data() dm.setup() return dm @pytest.fixture( - params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) + scope="class", + params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]), ) - def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: - task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands[1] + def config( + self, request: SubRequest, datamodule: DummyDataModule + ) -> Dict[str, Any]: + task_args = {} + task_args["classification_model"] = "resnet18" + task_args["learning_rate"] = 3e-4 # type: ignore[assignment] + task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment] + task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment] loss, weights = request.param task_args["loss"] = loss task_args["weights"] = weights @@ -67,21 +124,25 @@ def test_configure_optimizers(self, task: ClassificationTask) -> None: assert "lr_scheduler" in out def test_training( - self, datamodule: So2SatDataModule, task: ClassificationTask + self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule.train_dataloader())) + batch = next( + iter(datamodule.train_dataloader()) # type: ignore[no-untyped-call] + ) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( - self, datamodule: So2SatDataModule, task: ClassificationTask + self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule.val_dataloader())) + batch = next(iter(datamodule.val_dataloader())) # type: ignore[no-untyped-call] task.validation_step(batch, 0) task.validation_epoch_end(0) - def test_test(self, datamodule: So2SatDataModule, task: ClassificationTask) -> None: - batch = next(iter(datamodule.test_dataloader())) + def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: + batch = next( + iter(datamodule.test_dataloader()) # type: ignore[no-untyped-call] + ) task.test_step(batch, 0) task.test_epoch_end(0) @@ -101,6 +162,7 @@ def test_invalid_model(self, config: Dict[str, Any]) -> None: def test_invalid_loss(self, config: Dict[str, Any]) -> None: config["loss"] = "invalid_loss" + config["classification_model"] = "resnet18" error_message = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=error_message): ClassificationTask(**config) @@ -120,40 +182,28 @@ def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> No class TestMultiLabelClassificationTask: - @pytest.fixture(params=[("s1", 2), ("s2", 12), ("all", 14)]) - def bands(self, request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) - - @pytest.fixture(params=[True, False]) - def datamodule( - self, bands: Tuple[str, int], request: SubRequest - ) -> BigEarthNetDataModule: - band_set = bands[0] - unsupervised_mode = request.param - root = os.path.join("tests", "data", "bigearthnet") - batch_size = 1 - num_workers = 0 - dm = BigEarthNetDataModule( - root, - band_set, - batch_size, - num_workers, - unsupervised_mode, - val_split_pct=0.3, - test_split_pct=0.3, + @pytest.fixture + def datamodule(self, request: SubRequest) -> DummyDataModule: + dm = DummyDataModule( + num_channels=3, + num_classes=43, + multilabel=True, + batch_size=2, + num_workers=0, ) dm.prepare_data() dm.setup() return dm @pytest.fixture(params=zip(["bce", "bce"], ["imagenet", "random"])) - def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: - task_conf = OmegaConf.load( - os.path.join("conf", "task_defaults", "bigearthnet.yaml") - ) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands[1] + def config( + self, datamodule: DummyDataModule, request: SubRequest + ) -> Dict[str, Any]: + task_args = {} + task_args["classification_model"] = "resnet18" + task_args["learning_rate"] = 3e-4 # type: ignore[assignment] + task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment] + task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment] loss, weights = request.param task_args["loss"] = loss task_args["weights"] = weights @@ -168,29 +218,25 @@ def task( return task def test_training( - self, - datamodule: BigEarthNetDataModule, - task: MultiLabelClassificationTask, + self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule.train_dataloader())) + batch = next( + iter(datamodule.train_dataloader()) # type: ignore[no-untyped-call] + ) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( - self, - datamodule: BigEarthNetDataModule, - task: MultiLabelClassificationTask, + self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule.val_dataloader())) + batch = next(iter(datamodule.val_dataloader())) # type: ignore[no-untyped-call] task.validation_step(batch, 0) task.validation_epoch_end(0) - def test_test( - self, - datamodule: BigEarthNetDataModule, - task: MultiLabelClassificationTask, - ) -> None: - batch = next(iter(datamodule.test_dataloader())) + def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: + batch = next( + iter(datamodule.test_dataloader()) # type: ignore[no-untyped-call] + ) task.test_step(batch, 0) task.test_epoch_end(0) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index 01e9034bd7a..70c6e138ad6 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -16,15 +16,14 @@ from .ucmerced import UCMercedClassificationTask, UCMercedDataModule __all__ = ( - # Base Classes - "ClassificationTask", - "MultiLabelClassificationTask", # Tasks "BigEarthNetClassificationTask", "BYOLTask", "ChesapeakeCVPRSegmentationTask", + "ClassificationTask", "CycloneSimpleRegressionTask", "LandcoverAISegmentationTask", + "MultiLabelClassificationTask", "NAIPChesapeakeSegmentationTask", "RESISC45ClassificationTask", "SEN12MSSegmentationTask", From 14f6e45a996dd989e921df886f86822c90c1af5a Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 1 Nov 2021 14:15:08 -0500 Subject: [PATCH 17/19] add scope=class --- tests/trainers/test_tasks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainers/test_tasks.py b/tests/trainers/test_tasks.py index 9db4156962c..a73ea57802a 100644 --- a/tests/trainers/test_tasks.py +++ b/tests/trainers/test_tasks.py @@ -110,7 +110,7 @@ def config( task_args["weights"] = weights return task_args - @pytest.fixture + @pytest.fixture(scope="class") def task( self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] ) -> ClassificationTask: @@ -182,7 +182,7 @@ def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> No class TestMultiLabelClassificationTask: - @pytest.fixture + @pytest.fixture(scope="class") def datamodule(self, request: SubRequest) -> DummyDataModule: dm = DummyDataModule( num_channels=3, @@ -195,7 +195,7 @@ def datamodule(self, request: SubRequest) -> DummyDataModule: dm.setup() return dm - @pytest.fixture(params=zip(["bce", "bce"], ["imagenet", "random"])) + @pytest.fixture(scope="class", params=zip(["bce", "bce"], ["imagenet", "random"])) def config( self, datamodule: DummyDataModule, request: SubRequest ) -> Dict[str, Any]: @@ -209,7 +209,7 @@ def config( task_args["weights"] = weights return task_args - @pytest.fixture + @pytest.fixture(scope="class") def task( self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] ) -> MultiLabelClassificationTask: From ae16fdb5a46d591fc1d8dfb8b4d3f338461b18d3 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 1 Nov 2021 14:59:39 -0500 Subject: [PATCH 18/19] style/mypy fixes --- tests/trainers/test_tasks.py | 6 +++--- torchgeo/trainers/tasks.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/trainers/test_tasks.py b/tests/trainers/test_tasks.py index c2ce6a44755..2b34824c925 100644 --- a/tests/trainers/test_tasks.py +++ b/tests/trainers/test_tasks.py @@ -68,17 +68,17 @@ def setup(self, stage: Optional[str] = None) -> None: multilabel=self.multilabel, ) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: # type: ignore[type-arg] return DataLoader( self.dataset, batch_size=self.batch_size, num_workers=self.num_workers ) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: # type: ignore[type-arg] return DataLoader( self.dataset, batch_size=self.batch_size, num_workers=self.num_workers ) - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: # type: ignore[type-arg] return DataLoader( self.dataset, batch_size=self.batch_size, num_workers=self.num_workers ) diff --git a/torchgeo/trainers/tasks.py b/torchgeo/trainers/tasks.py index 8d7d69dc4ab..037633445a8 100644 --- a/torchgeo/trainers/tasks.py +++ b/torchgeo/trainers/tasks.py @@ -287,6 +287,7 @@ def config_task(self) -> None: def __init__(self, **kwargs: Any) -> None: """Initialize the LightningModule with a model and loss function. + Keyword Args: classification_model: Name of the classification model use loss: Name of the loss function @@ -322,6 +323,7 @@ def training_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int ) -> Tensor: """Training step. + Args: batch: Current batch batch_idx: Index of current batch @@ -346,6 +348,7 @@ def validation_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int ) -> None: """Validation step. + Args: batch: Current batch batch_idx: Index of current batch @@ -364,6 +367,7 @@ def test_step( # type: ignore[override] self, batch: Dict[str, Any], batch_idx: int ) -> None: """Test step. + Args: batch: Current batch batch_idx: Index of current batch From 503a558fabdc16fa6a014379608a6d6fafadd5ae Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Mon, 1 Nov 2021 15:17:05 -0500 Subject: [PATCH 19/19] mypy fixes --- tests/trainers/test_tasks.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/trainers/test_tasks.py b/tests/trainers/test_tasks.py index 2b34824c925..869cb230c1d 100644 --- a/tests/trainers/test_tasks.py +++ b/tests/trainers/test_tasks.py @@ -131,23 +131,19 @@ def test_configure_optimizers(self, task: ClassificationTask) -> None: def test_training( self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: - batch = next( - iter(datamodule.train_dataloader()) # type: ignore[no-untyped-call] - ) + batch = next(iter(datamodule.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule.val_dataloader())) # type: ignore[no-untyped-call] + batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: - batch = next( - iter(datamodule.test_dataloader()) # type: ignore[no-untyped-call] - ) + batch = next(iter(datamodule.test_dataloader())) task.test_step(batch, 0) task.test_epoch_end(0) @@ -225,23 +221,19 @@ def task( def test_training( self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: - batch = next( - iter(datamodule.train_dataloader()) # type: ignore[no-untyped-call] - ) + batch = next(iter(datamodule.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: - batch = next(iter(datamodule.val_dataloader())) # type: ignore[no-untyped-call] + batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: - batch = next( - iter(datamodule.test_dataloader()) # type: ignore[no-untyped-call] - ) + batch = next(iter(datamodule.test_dataloader())) task.test_step(batch, 0) task.test_epoch_end(0)