diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml new file mode 100644 index 00000000000..7f8b52f9da2 --- /dev/null +++ b/conf/bigearthnet.yaml @@ -0,0 +1,18 @@ +trainer: + gpus: 1 # single GPU training + min_epochs: 10 + max_epochs: 40 + benchmark: True + +experiment: + task: "bigearthnet" + module: + loss: "bce" + classification_model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 14 + datamodule: + batch_size: 128 + num_workers: 6 + bands: "all" diff --git a/conf/task_defaults/bigearthnet.yaml b/conf/task_defaults/bigearthnet.yaml new file mode 100644 index 00000000000..723d4e7d954 --- /dev/null +++ b/conf/task_defaults/bigearthnet.yaml @@ -0,0 +1,13 @@ +experiment: + task: "bigearthnet" + module: + loss: "bce" + classification_model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: "random" + in_channels: 14 + datamodule: + batch_size: 128 + num_workers: 6 + bands: "all" diff --git a/tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz b/tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz index 9169a8a932c..d9df455f105 100644 Binary files a/tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz and b/tests/data/bigearthnet/BigEarthNet-S1-v1.0.tar.gz differ diff --git a/tests/data/bigearthnet/BigEarthNet-S2-v1.0.tar.gz b/tests/data/bigearthnet/BigEarthNet-S2-v1.0.tar.gz index e8937e8746a..5f56c81fcb1 100644 Binary files a/tests/data/bigearthnet/BigEarthNet-S2-v1.0.tar.gz and b/tests/data/bigearthnet/BigEarthNet-S2-v1.0.tar.gz differ diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 89416828fa9..f30d7c7d182 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -71,7 +71,7 @@ def test_getitem(self, dataset: BigEarthNet) -> None: assert x["image"].shape == (12, 120, 120) def test_len(self, dataset: BigEarthNet) -> None: - assert len(dataset) == 2 + assert len(dataset) == 4 def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None: BigEarthNet(root=str(tmp_path), bands=dataset.bands, download=True) diff --git a/tests/trainers/test_bigearthnet.py b/tests/trainers/test_bigearthnet.py new file mode 100644 index 00000000000..add7d8861b2 --- /dev/null +++ b/tests/trainers/test_bigearthnet.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import pytest +from _pytest.fixtures import SubRequest + +from torchgeo.trainers import BigEarthNetDataModule + + +class TestBigEarthNetDataModule: + @pytest.fixture(scope="class", params=zip(["s1", "s2", "all"], [True, True, False])) + def datamodule(self, request: SubRequest) -> BigEarthNetDataModule: + bands, unsupervised_mode = request.param + root = os.path.join("tests", "data", "bigearthnet") + batch_size = 1 + num_workers = 0 + dm = BigEarthNetDataModule( + root, + bands, + batch_size, + num_workers, + unsupervised_mode, + val_split_pct=0.3, + test_split_pct=0.3, + ) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None: + next(iter(datamodule.test_dataloader())) diff --git a/tests/trainers/test_tasks.py b/tests/trainers/test_tasks.py index 0f1973cb01f..869cb230c1d 100644 --- a/tests/trainers/test_tasks.py +++ b/tests/trainers/test_tasks.py @@ -2,50 +2,114 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, Generator, Tuple, cast +from typing import Any, Dict, Generator, Optional, cast import pytest +import pytorch_lightning as pl +import torch +import torch.nn.functional as F from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf +from torch import Tensor +from torch.utils.data import DataLoader, Dataset, TensorDataset from torchgeo.trainers import ( ClassificationTask, CycloneDataModule, + MultiLabelClassificationTask, RegressionTask, - So2SatDataModule, ) from .test_utils import mocked_log -@pytest.fixture(scope="module", params=[("rgb", 3), ("s2", 10)]) -def bands(request: SubRequest) -> Tuple[str, int]: - return cast(Tuple[str, int], request.param) +class DummyDataset(Dataset): # type: ignore[type-arg] + def __init__(self, num_channels: int, num_classes: int, multilabel: bool) -> None: + x = torch.randn(10, num_channels, 128, 128) # (b, c, h, w) + y = torch.randint( # type: ignore[attr-defined] + 0, num_classes, size=(10,) + ) # (b,) + if multilabel: + y = F.one_hot(y, num_classes=num_classes) # (b, classes) -@pytest.fixture(scope="module", params=[True, False]) -def datamodule(bands: Tuple[str, int], request: SubRequest) -> So2SatDataModule: - band_set = bands[0] - unsupervised_mode = request.param - root = os.path.join("tests", "data", "so2sat") - batch_size = 2 - num_workers = 0 - dm = So2SatDataModule(root, batch_size, num_workers, band_set, unsupervised_mode) - dm.prepare_data() - dm.setup() - return dm + self.dataset = TensorDataset(x, y) + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Dict[str, Tensor]: + x, y = self.dataset[idx] + sample = {"image": x, "label": y} + return sample + + +class DummyDataModule(pl.LightningDataModule): + def __init__( + self, + num_channels: int, + num_classes: int, + multilabel: bool, + batch_size: int = 1, + num_workers: int = 0, + ) -> None: + super().__init__() # type: ignore[no-untyped-call] + self.num_channels = num_channels + self.num_classes = num_classes + self.multilabel = multilabel + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: Optional[str] = None) -> None: + self.dataset = DummyDataset( + num_channels=self.num_channels, + num_classes=self.num_classes, + multilabel=self.multilabel, + ) + + def train_dataloader(self) -> DataLoader: # type: ignore[type-arg] + return DataLoader( + self.dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def val_dataloader(self) -> DataLoader: # type: ignore[type-arg] + return DataLoader( + self.dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def test_dataloader(self) -> DataLoader: # type: ignore[type-arg] + return DataLoader( + self.dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) class TestClassificationTask: + @pytest.fixture(scope="class", params=[2, 3, 5]) + def datamodule(self, request: SubRequest) -> DummyDataModule: + dm = DummyDataModule( + num_channels=request.param, + num_classes=45, + multilabel=False, + batch_size=2, + num_workers=0, + ) + dm.prepare_data() + dm.setup() + return dm + @pytest.fixture( - params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]) + scope="class", + params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]), ) - def config(self, request: SubRequest, bands: Tuple[str, int]) -> Dict[str, Any]: - task_conf = OmegaConf.load(os.path.join("conf", "task_defaults", "so2sat.yaml")) - task_args = OmegaConf.to_object(task_conf.experiment.module) - task_args = cast(Dict[str, Any], task_args) - task_args["in_channels"] = bands[1] + def config( + self, request: SubRequest, datamodule: DummyDataModule + ) -> Dict[str, Any]: + task_args = {} + task_args["classification_model"] = "resnet18" + task_args["learning_rate"] = 3e-4 # type: ignore[assignment] + task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment] + task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment] loss, weights = request.param task_args["loss"] = loss task_args["weights"] = weights @@ -65,20 +129,20 @@ def test_configure_optimizers(self, task: ClassificationTask) -> None: assert "lr_scheduler" in out def test_training( - self, datamodule: So2SatDataModule, task: ClassificationTask + self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: batch = next(iter(datamodule.train_dataloader())) task.training_step(batch, 0) task.training_epoch_end(0) def test_validation( - self, datamodule: So2SatDataModule, task: ClassificationTask + self, datamodule: DummyDataModule, task: ClassificationTask ) -> None: batch = next(iter(datamodule.val_dataloader())) task.validation_step(batch, 0) task.validation_epoch_end(0) - def test_test(self, datamodule: So2SatDataModule, task: ClassificationTask) -> None: + def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: batch = next(iter(datamodule.test_dataloader())) task.test_step(batch, 0) task.test_epoch_end(0) @@ -99,6 +163,7 @@ def test_invalid_model(self, config: Dict[str, Any]) -> None: def test_invalid_loss(self, config: Dict[str, Any]) -> None: config["loss"] = "invalid_loss" + config["classification_model"] = "resnet18" error_message = "Loss type 'invalid_loss' is not valid." with pytest.raises(ValueError, match=error_message): ClassificationTask(**config) @@ -117,6 +182,68 @@ def test_invalid_pretrained(self, checkpoint: str, config: Dict[str, Any]) -> No ClassificationTask(**config) +class TestMultiLabelClassificationTask: + @pytest.fixture(scope="class") + def datamodule(self, request: SubRequest) -> DummyDataModule: + dm = DummyDataModule( + num_channels=3, + num_classes=43, + multilabel=True, + batch_size=2, + num_workers=0, + ) + dm.prepare_data() + dm.setup() + return dm + + @pytest.fixture(scope="class", params=zip(["bce", "bce"], ["imagenet", "random"])) + def config( + self, datamodule: DummyDataModule, request: SubRequest + ) -> Dict[str, Any]: + task_args = {} + task_args["classification_model"] = "resnet18" + task_args["learning_rate"] = 3e-4 # type: ignore[assignment] + task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment] + task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment] + loss, weights = request.param + task_args["loss"] = loss + task_args["weights"] = weights + return task_args + + @pytest.fixture + def task( + self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None] + ) -> MultiLabelClassificationTask: + task = MultiLabelClassificationTask(**config) + monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined] + return task + + def test_training( + self, datamodule: DummyDataModule, task: ClassificationTask + ) -> None: + batch = next(iter(datamodule.train_dataloader())) + task.training_step(batch, 0) + task.training_epoch_end(0) + + def test_validation( + self, datamodule: DummyDataModule, task: ClassificationTask + ) -> None: + batch = next(iter(datamodule.val_dataloader())) + task.validation_step(batch, 0) + task.validation_epoch_end(0) + + def test_test(self, datamodule: DummyDataModule, task: ClassificationTask) -> None: + batch = next(iter(datamodule.test_dataloader())) + task.test_step(batch, 0) + task.test_epoch_end(0) + + def test_invalid_loss(self, config: Dict[str, Any]) -> None: + config["loss"] = "invalid_loss" + error_message = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=error_message): + MultiLabelClassificationTask(**config) + + class TestRegressionTask: @pytest.fixture(scope="class") def datamodule(self) -> CycloneDataModule: diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index a07eee5ef2c..20148f7521e 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -3,6 +3,7 @@ """TorchGeo trainers.""" +from .bigearthnet import BigEarthNetClassificationTask, BigEarthNetDataModule from .byol import BYOLTask from .chesapeake import ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask from .cyclone import CycloneDataModule @@ -11,29 +12,35 @@ from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask from .so2sat import So2SatClassificationTask, So2SatDataModule -from .tasks import ClassificationTask, RegressionTask +from .tasks import ClassificationTask, MultiLabelClassificationTask, RegressionTask from .ucmerced import UCMercedClassificationTask, UCMercedDataModule __all__ = ( # Tasks - "ClassificationTask", - "RegressionTask", - # Trainers + "BigEarthNetClassificationTask", "BYOLTask", "ChesapeakeCVPRSegmentationTask", "ChesapeakeCVPRDataModule", + "ClassificationTask", "CycloneDataModule", "LandcoverAIDataModule", "LandcoverAISegmentationTask", - "NAIPChesapeakeDataModule", + "MultiLabelClassificationTask", "NAIPChesapeakeSegmentationTask", "RESISC45ClassificationTask", - "RESISC45DataModule", - "SEN12MSDataModule", + "RegressionTask", "SEN12MSSegmentationTask", - "So2SatDataModule", "So2SatClassificationTask", "UCMercedClassificationTask", + # DataModules + "BigEarthNetDataModule", + "ChesapeakeCVPRDataModule", + "CycloneDataModule", + "LandcoverAIDataModule", + "NAIPChesapeakeDataModule", + "RESISC45DataModule", + "SEN12MSDataModule", + "So2SatDataModule", "UCMercedDataModule", ) diff --git a/torchgeo/trainers/bigearthnet.py b/torchgeo/trainers/bigearthnet.py new file mode 100644 index 00000000000..a0abaadd2a9 --- /dev/null +++ b/torchgeo/trainers/bigearthnet.py @@ -0,0 +1,193 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""BigEarthNet trainer.""" + +from typing import Any, Dict, Optional + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader +from torchvision.transforms import Compose + +from ..datasets import BigEarthNet +from ..datasets.utils import dataset_split +from .tasks import MultiLabelClassificationTask + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +DataLoader.__module__ = "torch.utils.data" + + +class BigEarthNetClassificationTask(MultiLabelClassificationTask): + """LightningModule for training models on the BigEarthNet Dataset.""" + + num_classes = 43 + + +class BigEarthNetDataModule(pl.LightningDataModule): + """LightningDataModule implementation for the BigEarthNet dataset. + + Uses the train/val/test splits from the dataset. + """ + + # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) + # min/max band statistics computed on 100k random samples + band_mins_raw = torch.tensor( # type: ignore[attr-defined] + [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] + ) + band_maxs_raw = torch.tensor( # type: ignore[attr-defined] + [ + 31.0, + 35.0, + 18556.0, + 20528.0, + 18976.0, + 17874.0, + 16611.0, + 16512.0, + 16394.0, + 16672.0, + 16141.0, + 16097.0, + 15336.0, + 15203.0, + ] + ) + + # min/max band statistics computed by percentile clipping the + # above to samples to [2, 98] + band_mins = torch.tensor( # type: ignore[attr-defined] + [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ) + band_maxs = torch.tensor( # type: ignore[attr-defined] + [ + 6.0, + 16.0, + 9859.0, + 12872.0, + 13163.0, + 14445.0, + 12477.0, + 12563.0, + 12289.0, + 15596.0, + 12183.0, + 9458.0, + 5897.0, + 5544.0, + ] + ) + + def __init__( + self, + root_dir: str, + bands: str = "all", + batch_size: int = 64, + num_workers: int = 4, + unsupervised_mode: bool = False, + val_split_pct: float = 0.2, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a LightningDataModule for BigEarthNet based DataLoaders. + + Args: + root_dir: The ``root`` arugment to pass to the BigEarthNet Dataset classes + bands: load Sentinel-1 bands, Sentinel-2, or both. one of {s1, s2, all} + batch_size: The batch size to use in all created DataLoaders + num_workers: The number of workers to use in all created DataLoaders + unsupervised_mode: Makes the train dataloader return imagery from the train, + val, and test sets + val_split_pct: What percentage of the dataset to use as a validation set + test_split_pct: What percentage of the dataset to use as a test set + """ + super().__init__() # type: ignore[no-untyped-call] + self.root_dir = root_dir + self.bands = bands + self.batch_size = batch_size + self.num_workers = num_workers + self.unsupervised_mode = unsupervised_mode + + self.val_split_pct = val_split_pct + self.test_split_pct = test_split_pct + + if bands == "all": + self.mins = self.band_mins[:, None, None] + self.maxs = self.band_maxs[:, None, None] + elif bands == "s1": + self.mins = self.band_mins[:2, None, None] + self.maxs = self.band_maxs[:2, None, None] + else: + self.mins = self.band_mins[2:, None, None] + self.maxs = self.band_maxs[2:, None, None] + + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset.""" + sample["image"] = sample["image"].float() + sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins) + sample["image"] = torch.clip( # type: ignore[attr-defined] + sample["image"], min=0.0, max=1.0 + ) + return sample + + def prepare_data(self) -> None: + """Make sure that the dataset is downloaded. + + This method is only called once per run. + """ + BigEarthNet(self.root_dir, bands=self.bands, checksum=False) + + def setup(self, stage: Optional[str] = None) -> None: + """Initialize the main ``Dataset`` objects. + + This method is called once per GPU per run. + """ + transforms = Compose([self.preprocess]) + + if not self.unsupervised_mode: + + dataset = BigEarthNet( + self.root_dir, bands=self.bands, transforms=transforms + ) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + ) + else: + self.train_dataset = BigEarthNet( # type: ignore[assignment] + self.root_dir, bands=self.bands, transforms=transforms + ) + self.val_dataset, self.test_dataset = None, None # type: ignore[assignment] + + def train_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for training.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for validation.""" + if self.unsupervised_mode or self.val_split_pct == 0: + return self.train_dataloader() + else: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Return a DataLoader for testing.""" + if self.unsupervised_mode or self.test_split_pct == 0: + return self.train_dataloader() + else: + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + ) diff --git a/torchgeo/trainers/tasks.py b/torchgeo/trainers/tasks.py index 3554ce83219..037633445a8 100644 --- a/torchgeo/trainers/tasks.py +++ b/torchgeo/trainers/tasks.py @@ -57,7 +57,7 @@ def config_model(self) -> None: # Update first layer if in_channels != 3: - w_old = None + w_old = torch.empty(0) # type: ignore[attr-defined] if pretrained: w_old = torch.clone( # type: ignore[attr-defined] self.model.conv1.weight @@ -75,7 +75,11 @@ def config_model(self) -> None: w_new = torch.clone( # type: ignore[attr-defined] self.model.conv1.weight ).detach() - w_new[:, :3, :, :] = w_old + if in_channels > 3: + w_new[:, :3, :, :] = w_old + else: + w_old = w_old[:, :in_channels, :, :] + w_new[:, :in_channels, :, :] = w_old self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501 w_new ) @@ -266,6 +270,120 @@ def configure_optimizers(self) -> Dict[str, Any]: } +class MultiLabelClassificationTask(ClassificationTask): + """Abstract base class for multi label image classification LightningModules.""" + + #: number of classes in dataset + num_classes: int = 43 + + def config_task(self) -> None: + """Configures the task based on kwargs parameters passed to the constructor.""" + self.config_model() + + if self.hparams["loss"] == "bce": + self.loss = nn.BCEWithLogitsLoss() # type: ignore[attr-defined] + else: + raise ValueError(f"Loss type '{self.hparams['loss']}' is not valid.") + + def __init__(self, **kwargs: Any) -> None: + """Initialize the LightningModule with a model and loss function. + + Keyword Args: + classification_model: Name of the classification model use + loss: Name of the loss function + weights: Either "random", "imagenet_only", "imagenet_and_random", or + "random_rgb" + """ + super().__init__(**kwargs) + self.save_hyperparameters() # creates `self.hparams` from kwargs + + self.config_task() + + self.train_metrics = MetricCollection( + { + "OverallAccuracy": Accuracy( + num_classes=self.num_classes, average="micro", multiclass=False + ), + "AverageAccuracy": Accuracy( + num_classes=self.num_classes, average="macro", multiclass=False + ), + "F1Score": FBeta( + num_classes=self.num_classes, + beta=1.0, + average="micro", + multiclass=False, + ), + }, + prefix="train_", + ) + self.val_metrics = self.train_metrics.clone(prefix="val_") + self.test_metrics = self.train_metrics.clone(prefix="test_") + + def training_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> Tensor: + """Training step. + + Args: + batch: Current batch + batch_idx: Index of current batch + Returns: + training loss + """ + x = batch["image"] + y = batch["label"] + y_hat = self.forward(x) + y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined] + + loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] + + # by default, the train step logs every `log_every_n_steps` steps where + # `log_every_n_steps` is a parameter to the `Trainer` object + self.log("train_loss", loss, on_step=True, on_epoch=False) + self.train_metrics(y_hat_hard, y) + + return cast(Tensor, loss) + + def validation_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Validation step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["label"] + y_hat = self.forward(x) + y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined] + + loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] + + self.log("val_loss", loss, on_step=False, on_epoch=True) + self.val_metrics(y_hat_hard, y) + + def test_step( # type: ignore[override] + self, batch: Dict[str, Any], batch_idx: int + ) -> None: + """Test step. + + Args: + batch: Current batch + batch_idx: Index of current batch + """ + x = batch["image"] + y = batch["label"] + y_hat = self.forward(x) + y_hat_hard = torch.softmax(y_hat, dim=-1) # type: ignore[attr-defined] + + loss = self.loss(y_hat, y.to(torch.float)) # type: ignore[attr-defined] + + # by default, the test and validation steps only log per *epoch* + self.log("test_loss", loss, on_step=False, on_epoch=True) + self.test_metrics(y_hat_hard, y) + + class RegressionTask(pl.LightningModule): """LightningModule for training models on regression datasets.""" diff --git a/train.py b/train.py index 2d646b30b83..dc17a1b1f72 100755 --- a/train.py +++ b/train.py @@ -14,6 +14,8 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from torchgeo.trainers import ( + BigEarthNetClassificationTask, + BigEarthNetDataModule, BYOLTask, ChesapeakeCVPRDataModule, ChesapeakeCVPRSegmentationTask, @@ -36,6 +38,7 @@ TASK_TO_MODULES_MAPPING: Dict[ str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]] ] = { + "bigearthnet": (BigEarthNetClassificationTask, BigEarthNetDataModule), "byol": (BYOLTask, ChesapeakeCVPRDataModule), "chesapeake_cvpr": (ChesapeakeCVPRSegmentationTask, ChesapeakeCVPRDataModule), "cyclone": (RegressionTask, CycloneDataModule),