diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index b3380612b17..b71b08eca93 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import os +import pickle from pathlib import Path from typing import Dict @@ -121,6 +122,14 @@ def test_str(self, dataset: GeoDataset) -> None: assert "bbox: BoundingBox" in out assert "size: 1" in out + def test_picklable(self, dataset: GeoDataset) -> None: + x = pickle.dumps(dataset) + y = pickle.loads(x) + assert dataset.crs == y.crs + assert dataset.res == y.res + assert len(dataset) == len(y) + assert dataset.bounds == y.bounds + def test_abstract(self) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): GeoDataset() # type: ignore[abstract] diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 54509ce70f7..5c09a2d8930 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -172,6 +172,41 @@ def __str__(self) -> str: bbox: {self.bounds} size: {len(self)}""" + # NOTE: This hack should be removed once the following issue is fixed: + # https://github.com/Toblerity/rtree/issues/87 + + def __getstate__( + self, + ) -> Tuple[ + Dict[Any, Any], + List[Tuple[int, Tuple[float, float, float, float, float, float], str]], + ]: + """Define how instances are pickled. + + Returns: + the state necessary to unpickle the instance + """ + objects = self.index.intersection(self.index.bounds, objects=True) + tuples = [(item.id, item.bounds, item.object) for item in objects] + return self.__dict__, tuples + + def __setstate__( + self, + state: Tuple[ + Dict[Any, Any], + List[Tuple[int, Tuple[float, float, float, float, float, float], str]], + ], + ) -> None: + """Define how to unpickle an instance. + + Args: + state: the state of the instance when it was pickled + """ + attrs, tuples = state + self.__dict__.update(attrs) + for item in tuples: + self.index.insert(*item) + @property def bounds(self) -> BoundingBox: """Bounds of the index.