Skip to content

Commit

Permalink
Replaces test models with dummy models
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed Dec 31, 2021
1 parent b08934a commit c3e640c
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 6 deletions.
4 changes: 4 additions & 0 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from torchgeo.trainers import BYOLTask
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation

from .test_utils import ClassificationTestModel


class TestBYOL:
def test_custom_augment_fn(self) -> None:
Expand Down Expand Up @@ -53,6 +55,8 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
model_kwargs = conf_dict["module"]
model = BYOLTask(**model_kwargs)

model.encoder = ClassificationTestModel(**model_kwargs)

# Instantiate trainer
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1)
trainer.fit(model=model, datamodule=datamodule)
Expand Down
31 changes: 28 additions & 3 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
# Licensed under the MIT License.

import os
from typing import Any, Dict, Type, cast
from typing import Any, Dict, Generator, Type, cast

import pytest
import timm
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
from torch.nn.modules import Module

from torchgeo.datamodules import (
BigEarthNetDataModule,
Expand All @@ -17,6 +20,12 @@
)
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask

from .test_utils import ClassificationTestModel


def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)


class TestClassificationTask:
@pytest.mark.parametrize(
Expand All @@ -29,7 +38,12 @@ class TestClassificationTask:
("ucmerced", UCMercedDataModule),
],
)
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
def test_trainer(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
name: str,
classname: Type[LightningDataModule],
) -> None:
if name == "so2sat":
pytest.importorskip("h5py")

Expand All @@ -42,6 +56,9 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
datamodule = classname(**datamodule_kwargs)

# Instantiate model
monkeypatch.setattr( # type: ignore[attr-defined]
timm, "create_model", create_model
)
model_kwargs = conf_dict["module"]
model = ClassificationTask(**model_kwargs)

Expand Down Expand Up @@ -102,7 +119,12 @@ class TestMultiLabelClassificationTask:
("bigearthnet_s2", BigEarthNetDataModule),
],
)
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
def test_trainer(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
name: str,
classname: Type[LightningDataModule],
) -> None:
conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
Expand All @@ -112,6 +134,9 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
datamodule = classname(**datamodule_kwargs)

# Instantiate model
monkeypatch.setattr( # type: ignore[attr-defined]
timm, "create_model", create_model
)
model_kwargs = conf_dict["module"]
model = MultiLabelClassificationTask(**model_kwargs)

Expand Down
4 changes: 4 additions & 0 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from torchgeo.datamodules import COWCCountingDataModule, CycloneDataModule
from torchgeo.trainers import RegressionTask

from .test_utils import RegressionTestModel


class TestRegressionTask:
@pytest.mark.parametrize(
Expand All @@ -30,6 +32,8 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
model_kwargs = conf_dict["module"]
model = RegressionTask(**model_kwargs)

model.model = RegressionTestModel()

# Instantiate trainer
trainer = Trainer(fast_dev_run=True, log_every_n_steps=1)
trainer.fit(model=model, datamodule=datamodule)
Expand Down
22 changes: 20 additions & 2 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
# Licensed under the MIT License.

import os
from typing import Any, Dict, Type, cast
from typing import Any, Dict, Generator, Type, cast

import pytest
import segmentation_models_pytorch as smp
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
from torch.nn.modules import Module

from torchgeo.datamodules import (
ChesapeakeCVPRDataModule,
Expand All @@ -18,6 +21,12 @@
)
from torchgeo.trainers import SemanticSegmentationTask

from .test_utils import SegmentationTestModel


def create_model(**kwargs: Any) -> Module:
return SegmentationTestModel(**kwargs)


class TestSemanticSegmentationTask:
@pytest.mark.parametrize(
Expand All @@ -35,7 +44,12 @@ class TestSemanticSegmentationTask:
("sen12ms_s2_reduced", SEN12MSDataModule),
],
)
def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
def test_trainer(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
name: str,
classname: Type[LightningDataModule],
) -> None:
conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
Expand All @@ -45,6 +59,10 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
datamodule = classname(**datamodule_kwargs)

# Instantiate model
monkeypatch.setattr(smp, "Unet", create_model) # type: ignore[attr-defined]
monkeypatch.setattr( # type: ignore[attr-defined]
smp, "DeepLabV3Plus", create_model
)
model_kwargs = conf_dict["module"]
model = SemanticSegmentationTask(**model_kwargs)

Expand Down
39 changes: 38 additions & 1 deletion tests/trainers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os
from pathlib import Path
from typing import Any
from typing import Any, cast

import pytest
import torch
Expand All @@ -17,6 +17,43 @@
)


class ClassificationTestModel(Module):
def __init__(
self, in_chans: int = 3, num_classes: int = 1000, **kwargs: Any
) -> None:
super().__init__()
self.conv1 = nn.Conv2d( # type: ignore[attr-defined]
in_channels=in_chans, out_channels=1, kernel_size=1
)
self.pool = nn.AdaptiveAvgPool2d((1, 1)) # type: ignore[attr-defined]
self.fc = nn.Linear(1, num_classes) # type: ignore[attr-defined]

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.pool(x)
x = torch.flatten(x, 1) # type: ignore[attr-defined]
x = self.fc(x)
return x


class RegressionTestModel(ClassificationTestModel):
def __init__(self, **kwargs: Any) -> None:
super().__init__(in_chans=3, num_classes=1)


class SegmentationTestModel(Module):
def __init__(
self, in_channels: int = 3, classes: int = 1000, **kwargs: Any
) -> None:
super().__init__()
self.conv1 = nn.Conv2d( # type: ignore[attr-defined]
in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, self.conv1(x))


class FakeExperiment(object):
def add_figure(self, *args: Any, **kwargs: Any) -> None:
pass
Expand Down

0 comments on commit c3e640c

Please sign in to comment.