Skip to content

Commit

Permalink
Test datamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed Nov 13, 2021
1 parent 64027c5 commit 41502e1
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion tests/datasets/test_etci2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import ETCI2021
from torchgeo.datasets import ETCI2021, ETCI2021DataModule


def download_url(url: str, root: str, *args: str) -> None:
Expand Down Expand Up @@ -91,3 +91,25 @@ def test_plot(self, dataset: ETCI2021) -> None:
ETCI2021.plot(x, show_titles=False)
x["prediction"] = x["mask"][0].clone()
ETCI2021.plot(x)


class TestETCI2021DataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> ETCI2021DataModule:
root = os.path.join("tests", "data", "etci2021")
seed = 0
batch_size = 2
num_workers = 0
dm = ETCI2021DataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(self, datamodule: ETCI2021DataModule) -> None:
next(iter(datamodule.train_dataloader()))

def test_val_dataloader(self, datamodule: ETCI2021DataModule) -> None:
next(iter(datamodule.val_dataloader()))

def test_test_dataloader(self, datamodule: ETCI2021DataModule) -> None:
next(iter(datamodule.test_dataloader()))

0 comments on commit 41502e1

Please sign in to comment.