diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py index f70f29f07a5..feb94f733cc 100644 --- a/tests/datasets/test_etci2021.py +++ b/tests/datasets/test_etci2021.py @@ -13,7 +13,7 @@ from _pytest.monkeypatch import MonkeyPatch import torchgeo.datasets.utils -from torchgeo.datasets import ETCI2021 +from torchgeo.datasets import ETCI2021, ETCI2021DataModule def download_url(url: str, root: str, *args: str) -> None: @@ -91,3 +91,25 @@ def test_plot(self, dataset: ETCI2021) -> None: ETCI2021.plot(x, show_titles=False) x["prediction"] = x["mask"][0].clone() ETCI2021.plot(x) + + +class TestETCI2021DataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> ETCI2021DataModule: + root = os.path.join("tests", "data", "etci2021") + seed = 0 + batch_size = 2 + num_workers = 0 + dm = ETCI2021DataModule(root, seed, batch_size, num_workers) + dm.prepare_data() + dm.setup() + return dm + + def test_train_dataloader(self, datamodule: ETCI2021DataModule) -> None: + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: ETCI2021DataModule) -> None: + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: ETCI2021DataModule) -> None: + next(iter(datamodule.test_dataloader()))