From e04e1a53fd6a21506693d53f8a8519dbf4261817 Mon Sep 17 00:00:00 2001 From: Dimitris Mantas <75796651+DimitrisMantas@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:45:52 +0100 Subject: [PATCH] Implement deterministic GeoDataset (#1908) --- tests/datasets/test_geo.py | 15 +++++++++++++++ torchgeo/datasets/geo.py | 5 +++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index 0e1a4ce17e3..8af3333a8c3 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -177,6 +177,21 @@ def test_files_property_for_virtual_files(self) -> None: ] assert len(CustomGeoDataset(paths=paths).files) == len(paths) + def test_files_property_ordered(self) -> None: + """Ensure that the list of files is ordered.""" + paths = ["file://file3.tif", "file://file1.tif", "file://file2.tif"] + assert CustomGeoDataset(paths=paths).files == sorted(paths) + + def test_files_property_deterministic(self) -> None: + """Ensure that the list of files is consistent regardless of their original + order. + """ + paths1 = ["file://file3.tif", "file://file1.tif", "file://file2.tif"] + paths2 = ["file://file2.tif", "file://file3.tif", "file://file1.tif"] + assert ( + CustomGeoDataset(paths=paths1).files == CustomGeoDataset(paths=paths2).files + ) + class TestRasterDataset: @pytest.fixture(params=zip([["R", "G", "B"], None], [True, False])) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 1e2382db0fe..1f1ede5890d 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -287,7 +287,7 @@ def res(self, new_res: float) -> None: self._res = new_res @property - def files(self) -> set[str]: + def files(self) -> list[str]: """A list of all files in the dataset. Returns: @@ -316,7 +316,8 @@ def files(self) -> set[str]: UserWarning, ) - return files + # Sort the output to enforce deterministic behavior. + return sorted(files) class RasterDataset(GeoDataset):