diff --git a/.github/dependabot.yml b/.github/dependabot.yml index fd117a16013..49b4f7202bd 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,9 +9,6 @@ updates: schedule: interval: "daily" ignore: - # radiant-mlhub 0.5+ changed download behavior: - # https://github.com/radiantearth/radiant-mlhub/pull/104 - - dependency-name: "radiant-mlhub" # setuptools releases new versions almost daily - dependency-name: "setuptools" update-types: ["version-update:semver-patch"] diff --git a/environment.yml b/environment.yml index c358e36d88f..a1f235f7b12 100644 --- a/environment.yml +++ b/environment.yml @@ -38,7 +38,7 @@ dependencies: - pytorch-lightning>=1.5.1 - git+https://github.com/pytorch/pytorch_sphinx_theme - pyupgrade>=2.4 - - radiant-mlhub>=0.2.1,<0.5 + - radiant-mlhub>=0.2.1 - rtree>=1 - scikit-image>=0.18 - scikit-learn>=0.22 diff --git a/setup.cfg b/setup.cfg index cb293106b74..63481d8eba3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -85,9 +85,7 @@ datasets = pyvista>=0.25.2,<0.39 # radiant-mlhub 0.2.1+ required for api_key bugfix: # https://github.com/radiantearth/radiant-mlhub/pull/48 - # radiant-mlhub 0.5+ changed download behavior: - # https://github.com/radiantearth/radiant-mlhub/pull/104 - radiant-mlhub>=0.2.1,<0.5 + radiant-mlhub>=0.2.1,<0.6 # rarfile 4+ required for wheels rarfile>=4,<5 # scikit-image 0.18+ required for numpy 1.17+ compatibility diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py index 7181b091d19..6c42c59b534 100644 --- a/tests/datasets/test_benin_cashews.py +++ b/tests/datasets/test_benin_cashews.py @@ -16,15 +16,15 @@ from torchgeo.datasets import BeninSmallHolderCashews -class Dataset: +class Collection: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join("tests", "data", "ts_cashew_benin", "*.tar.gz") for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) -def fetch(dataset_id: str, **kwargs: str) -> Dataset: - return Dataset() +def fetch(dataset_id: str, **kwargs: str) -> Collection: + return Collection() class TestBeninSmallHolderCashews: @@ -33,7 +33,7 @@ def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> BeninSmallHolderCashews: radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1") - monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch) + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) source_md5 = "255efff0f03bc6322470949a09bc76db" labels_md5 = "ed2195d93ca6822d48eb02bc3e81c127" monkeypatch.setitem(BeninSmallHolderCashews.image_meta, "md5", source_md5) diff --git a/tests/datasets/test_cloud_cover.py b/tests/datasets/test_cloud_cover.py index 0fa9188e285..a389982c26f 100644 --- a/tests/datasets/test_cloud_cover.py +++ b/tests/datasets/test_cloud_cover.py @@ -15,7 +15,7 @@ from torchgeo.datasets import CloudCoverDetection -class Dataset: +class Collection: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join( "tests", "data", "ref_cloud_cover_detection_challenge_v1", "*.tar.gz" @@ -24,15 +24,15 @@ def download(self, output_dir: str, **kwargs: str) -> None: shutil.copy(tarball, output_dir) -def fetch(dataset_id: str, **kwargs: str) -> Dataset: - return Dataset() +def fetch(dataset_id: str, **kwargs: str) -> Collection: + return Collection() class TestCloudCoverDetection: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CloudCoverDetection: radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1") - monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch) + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) test_image_meta = { "filename": "ref_cloud_cover_detection_challenge_v1_test_source.tar.gz", diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py index 7eeeaedf0b3..37578a13a1d 100644 --- a/tests/datasets/test_cv4a_kenya_crop_type.py +++ b/tests/datasets/test_cv4a_kenya_crop_type.py @@ -16,7 +16,7 @@ from torchgeo.datasets import CV4AKenyaCropType -class Dataset: +class Collection: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join( "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz" @@ -25,15 +25,15 @@ def download(self, output_dir: str, **kwargs: str) -> None: shutil.copy(tarball, output_dir) -def fetch(dataset_id: str, **kwargs: str) -> Dataset: - return Dataset() +def fetch(dataset_id: str, **kwargs: str) -> Collection: + return Collection() class TestCV4AKenyaCropType: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType: radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1") - monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch) + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) source_md5 = "7f4dcb3f33743dddd73f453176308bfb" labels_md5 = "95fc59f1d94a85ec00931d4d1280bec9" monkeypatch.setitem(CV4AKenyaCropType.image_meta, "md5", source_md5) diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index 1452ad18225..ee29d44ecc4 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -17,14 +17,14 @@ from torchgeo.datasets import TropicalCyclone -class Dataset: +class Collection: def download(self, output_dir: str, **kwargs: str) -> None: for tarball in glob.iglob(os.path.join("tests", "data", "cyclone", "*.tar.gz")): shutil.copy(tarball, output_dir) -def fetch(collection_id: str, **kwargs: str) -> Dataset: - return Dataset() +def fetch(collection_id: str, **kwargs: str) -> Collection: + return Collection() class TestTropicalCyclone: @@ -33,7 +33,7 @@ def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> TropicalCyclone: radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1") - monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch) + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) md5s = { "train": { "source": "2b818e0a0873728dabf52c7054a0ce4c", diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index 706dcf52552..6887cc4173b 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -15,23 +15,35 @@ from torchgeo.datasets import NASAMarineDebris -class Dataset: +class Collection: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join("tests", "data", "nasa_marine_debris", "*.tar.gz") for tarball in glob.iglob(glob_path): shutil.copy(tarball, output_dir) -def fetch(dataset_id: str, **kwargs: str) -> Dataset: - return Dataset() +def fetch(collection_id: str, **kwargs: str) -> Collection: + return Collection() + + +class Collection_corrupted: + def download(self, output_dir: str, **kwargs: str) -> None: + filenames = NASAMarineDebris.filenames + for filename in filenames: + with open(os.path.join(output_dir, filename), "w") as f: + f.write("bad") + + +def fetch_corrupted(collection_id: str, **kwargs: str) -> Collection_corrupted: + return Collection_corrupted() class TestNASAMarineDebris: @pytest.fixture() def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris: radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1") - monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch) - md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"] + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch) + md5s = ["6f4f0d2313323950e45bf3fc0c09b5de", "540cf1cf4fd2c13b609d0355abe955d7"] monkeypatch.setattr(NASAMarineDebris, "md5s", md5s) root = str(tmp_path) transforms = nn.Identity() @@ -58,9 +70,25 @@ def test_already_downloaded_not_extracted( ) -> None: shutil.rmtree(dataset.root) os.makedirs(str(tmp_path), exist_ok=True) - Dataset().download(output_dir=str(tmp_path)) + Collection().download(output_dir=str(tmp_path)) NASAMarineDebris(root=str(tmp_path), download=False) + def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None: + filenames = NASAMarineDebris.filenames + for filename in filenames: + with open(os.path.join(tmp_path, filename), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset checksum mismatch."): + NASAMarineDebris(root=str(tmp_path), download=False, checksum=True) + + def test_corrupted_new_download( + self, tmp_path: Path, monkeypatch: MonkeyPatch + ) -> None: + with pytest.raises(RuntimeError, match="Dataset checksum mismatch."): + radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1") + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_corrupted) + NASAMarineDebris(root=str(tmp_path), download=True, checksum=True) + def test_not_downloaded(self, tmp_path: Path) -> None: err = "Dataset not found in `root` directory and `download=False`, " "either specify a different `root` directory or use `download=True` " diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 8cea1f29d60..34dc2e734ae 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -17,7 +17,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive +from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive # TODO: read geospatial information from stac.json files @@ -56,6 +56,7 @@ class BeninSmallHolderCashews(NonGeoDataset): """ dataset_id = "ts_cashew_benin" + collection_ids = ["ts_cashew_benin_source", "ts_cashew_benin_labels"] image_meta = { "filename": "ts_cashew_benin_source.tar.gz", "md5": "957272c86e518a925a4e0d90dab4f92d", @@ -416,7 +417,8 @@ def _download(self, api_key: Optional[str] = None) -> None: print("Files already downloaded and verified") return - download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key) + for collection_id in self.collection_ids: + download_radiant_mlhub_collection(collection_id, self.root, api_key) image_archive_path = os.path.join(self.root, self.image_meta["filename"]) target_archive_path = os.path.join(self.root, self.target_meta["filename"]) diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index 43f01e34cba..c11d2d463b9 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -14,7 +14,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive +from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive # TODO: read geospatial information from stac.json files @@ -54,7 +54,12 @@ class CloudCoverDetection(NonGeoDataset): .. versionadded:: 0.4 """ - dataset_id = "ref_cloud_cover_detection_challenge_v1" + collection_ids = [ + "ref_cloud_cover_detection_challenge_v1_train_source", + "ref_cloud_cover_detection_challenge_v1_train_labels", + "ref_cloud_cover_detection_challenge_v1_test_source", + "ref_cloud_cover_detection_challenge_v1_test_labels", + ] image_meta = { "train": { @@ -332,7 +337,8 @@ def _download(self, api_key: Optional[str] = None) -> None: print("Files already downloaded and verified") return - download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key) + for collection_id in self.collection_ids: + download_radiant_mlhub_collection(collection_id, self.root, api_key) image_archive_path = os.path.join( self.root, self.image_meta[self.split]["filename"] diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index f8b901a7b94..b6db80ea247 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive +from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive # TODO: read geospatial information from stac.json files @@ -56,7 +56,10 @@ class CV4AKenyaCropType(NonGeoDataset): imagery and labels from the Radiant Earth MLHub """ - dataset_id = "ref_african_crops_kenya_02" + collection_ids = [ + "ref_african_crops_kenya_02_labels", + "ref_african_crops_kenya_02_source", + ] image_meta = { "filename": "ref_african_crops_kenya_02_source.tar.gz", "md5": "9c2004782f6dc83abb1bf45ba4d0da46", @@ -394,7 +397,8 @@ def _download(self, api_key: Optional[str] = None) -> None: print("Files already downloaded and verified") return - download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key) + for collection_id in self.collection_ids: + download_radiant_mlhub_collection(collection_id, self.root, api_key) image_archive_path = os.path.join(self.root, self.image_meta["filename"]) target_archive_path = os.path.join(self.root, self.target_meta["filename"]) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index 45821d61f1f..f8bf4645225 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -15,7 +15,7 @@ from torch import Tensor from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive +from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive class TropicalCyclone(NonGeoDataset): @@ -45,6 +45,12 @@ class TropicalCyclone(NonGeoDataset): """ collection_id = "nasa_tropical_storm_competition" + collection_ids = [ + "nasa_tropical_storm_competition_train_source", + "nasa_tropical_storm_competition_test_source", + "nasa_tropical_storm_competition_train_labels", + "nasa_tropical_storm_competition_test_labels", + ] md5s = { "train": { "source": "97e913667a398704ea8d28196d91dad6", @@ -207,7 +213,8 @@ def _download(self, api_key: Optional[str] = None) -> None: print("Files already downloaded and verified") return - download_radiant_mlhub_dataset(self.collection_id, self.root, api_key) + for collection_id in self.collection_ids: + download_radiant_mlhub_collection(collection_id, self.root, api_key) for split, resources in self.md5s.items(): for resource_type in resources: diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 4a506e5c701..026af81123e 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -14,7 +14,7 @@ from torchvision.utils import draw_bounding_boxes from .geo import NonGeoDataset -from .utils import download_radiant_mlhub_dataset, extract_archive +from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive class NASAMarineDebris(NonGeoDataset): @@ -51,7 +51,7 @@ class NASAMarineDebris(NonGeoDataset): .. versionadded:: 0.2 """ - dataset_id = "nasa_marine_debris" + collection_ids = ["nasa_marine_debris_source", "nasa_marine_debris_labels"] directories = ["nasa_marine_debris_source", "nasa_marine_debris_labels"] filenames = ["nasa_marine_debris_source.tar.gz", "nasa_marine_debris_labels.tar.gz"] md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"] @@ -189,9 +189,11 @@ def _verify(self) -> None: # Check if zip file already exists (if so then extract) exists = [] - for filename in self.filenames: + for filename, md5 in zip(self.filenames, self.md5s): filepath = os.path.join(self.root, filename) if os.path.exists(filepath): + if self.checksum and not check_integrity(filepath, md5): + raise RuntimeError("Dataset checksum mismatch.") exists.append(True) extract_archive(filepath) else: @@ -208,11 +210,13 @@ def _verify(self) -> None: "to automatically download the dataset." ) - # TODO: need a checksum check in here post downloading # Download and extract the dataset - download_radiant_mlhub_dataset(self.dataset_id, self.root, self.api_key) - for filename in self.filenames: + for collection_id in self.collection_ids: + download_radiant_mlhub_collection(collection_id, self.root, self.api_key) + for filename, md5 in zip(self.filenames, self.md5s): filepath = os.path.join(self.root, filename) + if self.checksum and not check_integrity(filepath, md5): + raise RuntimeError("Dataset checksum mismatch.") extract_archive(filepath) def plot(