Speeding up and reducing the memory footprint of the trainer tests #344

Merged (2 commits, Jan 1, 2022)
4 changes: 4 additions & 0 deletions tests/trainers/test_byol.py
@@ -14,6 +14,8 @@
 from torchgeo.trainers import BYOLTask
 from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
 
+from .test_utils import ClassificationTestModel
+
 
 class TestBYOL:
     def test_custom_augment_fn(self) -> None:
@@ -53,6 +55,8 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
         model_kwargs = conf_dict["module"]
         model = BYOLTask(**model_kwargs)
 
+        model.encoder = ClassificationTestModel(**model_kwargs)
+
         # Instantiate trainer
         trainer = Trainer(fast_dev_run=True, log_every_n_steps=1)
         trainer.fit(model=model, datamodule=datamodule)
31 changes: 28 additions & 3 deletions tests/trainers/test_classification.py
@@ -2,11 +2,14 @@
 # Licensed under the MIT License.
 
 import os
-from typing import Any, Dict, Type, cast
+from typing import Any, Dict, Generator, Type, cast
 
 import pytest
+import timm
+from _pytest.monkeypatch import MonkeyPatch
 from omegaconf import OmegaConf
 from pytorch_lightning import LightningDataModule, Trainer
+from torch.nn.modules import Module
 
 from torchgeo.datamodules import (
     BigEarthNetDataModule,
@@ -17,6 +20,12 @@
 )
 from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
 
+from .test_utils import ClassificationTestModel
+
+
+def create_model(*args: Any, **kwargs: Any) -> Module:
+    return ClassificationTestModel(**kwargs)
+
 
 class TestClassificationTask:
     @pytest.mark.parametrize(
@@ -29,7 +38,12 @@ class TestClassificationTask:
             ("ucmerced", UCMercedDataModule),
         ],
     )
-    def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
+    def test_trainer(
+        self,
+        monkeypatch: Generator[MonkeyPatch, None, None],
+        name: str,
+        classname: Type[LightningDataModule],
+    ) -> None:
         if name.startswith("so2sat"):
             pytest.importorskip("h5py")
 
@@ -42,6 +56,9 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
         datamodule = classname(**datamodule_kwargs)
 
         # Instantiate model
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            timm, "create_model", create_model
+        )
         model_kwargs = conf_dict["module"]
         model = ClassificationTask(**model_kwargs)
 
@@ -119,7 +136,12 @@ class TestMultiLabelClassificationTask:
             ("bigearthnet_s2", BigEarthNetDataModule),
         ],
    )
-    def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
+    def test_trainer(
+        self,
+        monkeypatch: Generator[MonkeyPatch, None, None],
+        name: str,
+        classname: Type[LightningDataModule],
+    ) -> None:
         conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
         conf_dict = OmegaConf.to_object(conf.experiment)
         conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
@@ -129,6 +151,9 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
         datamodule = classname(**datamodule_kwargs)
 
         # Instantiate model
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            timm, "create_model", create_model
+        )
         model_kwargs = conf_dict["module"]
         model = MultiLabelClassificationTask(**model_kwargs)
 
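The monkeypatch fixture used above swaps timm's model factory for the duration of each test and restores it afterwards, so ClassificationTask builds the one-layer stub instead of instantiating a full backbone. A minimal, self-contained sketch of the pattern (TinyModel and test_uses_stub are illustrative names, not part of this PR):

import timm
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from typing import Any


class TinyModel(nn.Module):
    def __init__(self, in_chans: int = 3, num_classes: int = 10, **kwargs: Any) -> None:
        super().__init__()
        self.fc = nn.Linear(in_chans, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def test_uses_stub(monkeypatch: MonkeyPatch) -> None:
    # pytest restores the real factory automatically after the test
    monkeypatch.setattr(timm, "create_model", lambda *args, **kwargs: TinyModel(**kwargs))
    model = timm.create_model("resnet18", in_chans=3, num_classes=10)
    assert isinstance(model, TinyModel)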
4 changes: 4 additions & 0 deletions tests/trainers/test_regression.py
@@ -11,6 +11,8 @@
 from torchgeo.datamodules import COWCCountingDataModule, CycloneDataModule
 from torchgeo.trainers import RegressionTask
 
+from .test_utils import RegressionTestModel
+
 
 class TestRegressionTask:
     @pytest.mark.parametrize(
@@ -30,6 +32,8 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
         model_kwargs = conf_dict["module"]
         model = RegressionTask(**model_kwargs)
 
+        model.model = RegressionTestModel()
+
Review comment (Member Author):
I could monkeypatch torchvision.models.resnet18 here to avoid the ResNet18 object actually being created in the constructor, but it seems like that wouldn't save much time/RAM, and it would need to change when we upgrade RegressionTask.
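A rough sketch of what that alternative could look like, assuming RegressionTask builds its backbone by calling torchvision.models.resnet18 in its constructor, as the comment suggests (the create_model helper below is hypothetical, mirroring the pattern used in the classification and segmentation tests):

import torchvision


def create_model(*args: Any, **kwargs: Any) -> Module:
    # hypothetical stand-in: return the stub regardless of what is requested
    return RegressionTestModel()


monkeypatch.setattr(torchvision.models, "resnet18", create_model)
model = RegressionTask(**model_kwargs)  # would now construct the stub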


         # Instantiate trainer
         trainer = Trainer(fast_dev_run=True, log_every_n_steps=1)
         trainer.fit(model=model, datamodule=datamodule)
22 changes: 20 additions & 2 deletions tests/trainers/test_segmentation.py
@@ -2,11 +2,14 @@
 # Licensed under the MIT License.
 
 import os
-from typing import Any, Dict, Type, cast
+from typing import Any, Dict, Generator, Type, cast
 
 import pytest
+import segmentation_models_pytorch as smp
+from _pytest.monkeypatch import MonkeyPatch
 from omegaconf import OmegaConf
 from pytorch_lightning import LightningDataModule, Trainer
+from torch.nn.modules import Module
 
 from torchgeo.datamodules import (
     ChesapeakeCVPRDataModule,
@@ -18,6 +21,12 @@
 )
 from torchgeo.trainers import SemanticSegmentationTask
 
+from .test_utils import SegmentationTestModel
+
+
+def create_model(**kwargs: Any) -> Module:
+    return SegmentationTestModel(**kwargs)
+
 
 class TestSemanticSegmentationTask:
     @pytest.mark.parametrize(
@@ -35,7 +44,12 @@ class TestSemanticSegmentationTask:
             ("sen12ms_s2_reduced", SEN12MSDataModule),
         ],
     )
-    def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
+    def test_trainer(
+        self,
+        monkeypatch: Generator[MonkeyPatch, None, None],
+        name: str,
+        classname: Type[LightningDataModule],
+    ) -> None:
         conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
         conf_dict = OmegaConf.to_object(conf.experiment)
         conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
@@ -45,6 +59,10 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
         datamodule = classname(**datamodule_kwargs)
 
         # Instantiate model
+        monkeypatch.setattr(smp, "Unet", create_model)  # type: ignore[attr-defined]
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            smp, "DeepLabV3Plus", create_model
+        )
         model_kwargs = conf_dict["module"]
         model = SemanticSegmentationTask(**model_kwargs)
 
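Note that SemanticSegmentationTask presumably forwards encoder-related settings (for example encoder_name and encoder_weights) to smp.Unet or smp.DeepLabV3Plus; the stub's **kwargs simply absorbs everything beyond in_channels and classes. A quick illustrative check of that behavior (not part of the PR):

import torch

model = create_model(
    encoder_name="resnet18", encoder_weights=None, in_channels=4, classes=7
)
out = model(torch.randn(2, 4, 32, 32))
assert out.shape == (2, 7, 32, 32)  # the extra kwargs were ignored by the stub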
38 changes: 38 additions & 0 deletions tests/trainers/test_utils.py
@@ -3,6 +3,7 @@
 
 import os
 from pathlib import Path
+from typing import Any, cast
 
 import pytest
 import torch
@@ -16,6 +17,43 @@
 )
 
 
+class ClassificationTestModel(Module):
+    def __init__(
+        self, in_chans: int = 3, num_classes: int = 1000, **kwargs: Any
+    ) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(  # type: ignore[attr-defined]
+            in_channels=in_chans, out_channels=1, kernel_size=1
+        )
+        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # type: ignore[attr-defined]
+        self.fc = nn.Linear(1, num_classes)  # type: ignore[attr-defined]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.conv1(x)
+        x = self.pool(x)
+        x = torch.flatten(x, 1)  # type: ignore[attr-defined]
+        x = self.fc(x)
+        return x
+
+
+class RegressionTestModel(ClassificationTestModel):
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(in_chans=3, num_classes=1)
+
+
+class SegmentationTestModel(Module):
+    def __init__(
+        self, in_channels: int = 3, classes: int = 1000, **kwargs: Any
+    ) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(  # type: ignore[attr-defined]
+            in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return cast(torch.Tensor, self.conv1(x))
+
+
 def test_extract_encoder_unsupported_model(tmp_path: Path) -> None:
     checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}}
     path = os.path.join(str(tmp_path), "dummy.ckpt")
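To see why these stubs are so cheap, note that each one reduces to a single 1x1 convolution with a handful of parameters. A quick shape check, assuming the classes are importable from tests/trainers/test_utils.py:

import torch

from tests.trainers.test_utils import (
    ClassificationTestModel,
    RegressionTestModel,
    SegmentationTestModel,
)

x = torch.randn(2, 3, 64, 64)
assert ClassificationTestModel(in_chans=3, num_classes=10)(x).shape == (2, 10)
assert RegressionTestModel()(x).shape == (2, 1)
assert SegmentationTestModel(in_channels=3, classes=5)(x).shape == (2, 5, 64, 64)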