Support for radiant-mlhub 0.5+ #1102

Merged
3 changes: 0 additions & 3 deletions .github/dependabot.yml
@@ -9,9 +9,6 @@ updates:
schedule:
interval: "daily"
ignore:
# radiant-mlhub 0.5+ changed download behavior:
# https://github.com/radiantearth/radiant-mlhub/pull/104
- dependency-name: "radiant-mlhub"
# setuptools releases new versions almost daily
- dependency-name: "setuptools"
update-types: ["version-update:semver-patch"]
2 changes: 1 addition & 1 deletion environment.yml
@@ -38,7 +38,7 @@ dependencies:
- pytorch-lightning>=1.5.1
- git+https://github.com/pytorch/pytorch_sphinx_theme
- pyupgrade>=2.4
- radiant-mlhub>=0.2.1,<0.5
- radiant-mlhub>=0.2.1
- rtree>=1
- scikit-image>=0.18
- scikit-learn>=0.22
4 changes: 1 addition & 3 deletions setup.cfg
@@ -85,9 +85,7 @@ datasets =
pyvista>=0.25.2,<0.39
# radiant-mlhub 0.2.1+ required for api_key bugfix:
# https://github.com/radiantearth/radiant-mlhub/pull/48
# radiant-mlhub 0.5+ changed download behavior:
# https://github.com/radiantearth/radiant-mlhub/pull/104
radiant-mlhub>=0.2.1,<0.5
radiant-mlhub>=0.2.1,<0.6
# rarfile 4+ required for wheels
rarfile>=4,<5
# scikit-image 0.18+ required for numpy 1.17+ compatibility
8 changes: 4 additions & 4 deletions tests/datasets/test_benin_cashews.py
@@ -16,15 +16,15 @@
from torchgeo.datasets import BeninSmallHolderCashews


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
Collaborator: When radiant-mlhub 0.5 was first released, we actually didn't notice when things broke because we replace their download method with our own. Do you know if it's possible to monkeypatch something else farther internal so that we can actually test their download method but still use a local file?

Contributor Author: I've actually never used monkeypatch, so I'm not sure how that would work. It looks like the download call just creates a requests.Session and downloads chunks of the dataset in a thread pool. We would probably have to override something in requests, but then our test data would need to be chunked, I think.
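A minimal, untested sketch of that idea: patch at the requests layer instead of Collection.fetch, so radiant-mlhub's own download code still runs but every request is answered from a local tarball. The test name and structure here are hypothetical, and it assumes the client issues ordinary GET requests through a requests.Session; if the real flow uses ranged, multi-threaded chunk downloads as described above, the stub would also have to honor the Range header.

import glob
import os
from pathlib import Path

import requests
from _pytest.monkeypatch import MonkeyPatch


def fake_send(
    self: requests.Session, request: requests.PreparedRequest, **kwargs: object
) -> requests.Response:
    # Answer every HTTP request with the bytes of a local test tarball.
    path = next(glob.iglob(os.path.join("tests", "data", "ts_cashew_benin", "*.tar.gz")))
    response = requests.Response()
    response.status_code = 200
    with open(path, "rb") as f:
        response._content = f.read()
    response.url = request.url or ""
    response.request = request
    return response


def test_download_through_http_stub(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
    # Collection.fetch/download would run unpatched here; only the network is faked.
    monkeypatch.setattr(requests.Session, "send", fake_send)
    ...

Whether that is worth the extra moving parts depends on how stable radiant-mlhub's internal request flow is; patching Collection.fetch, as these tests do, is simpler but means their download path is never exercised.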

glob_path = os.path.join("tests", "data", "ts_cashew_benin", "*.tar.gz")
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(dataset_id: str, **kwargs: str) -> Collection:
return Collection()


class TestBeninSmallHolderCashews:
@@ -33,7 +33,7 @@ def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path
) -> BeninSmallHolderCashews:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
source_md5 = "255efff0f03bc6322470949a09bc76db"
labels_md5 = "ed2195d93ca6822d48eb02bc3e81c127"
monkeypatch.setitem(BeninSmallHolderCashews.image_meta, "md5", source_md5)
8 changes: 4 additions & 4 deletions tests/datasets/test_cloud_cover.py
@@ -15,7 +15,7 @@
from torchgeo.datasets import CloudCoverDetection


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(
"tests", "data", "ref_cloud_cover_detection_challenge_v1", "*.tar.gz"
@@ -24,15 +24,15 @@ def download(self, output_dir: str, **kwargs: str) -> None:
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(dataset_id: str, **kwargs: str) -> Collection:
return Collection()


class TestCloudCoverDetection:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CloudCoverDetection:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)

test_image_meta = {
"filename": "ref_cloud_cover_detection_challenge_v1_test_source.tar.gz",
8 changes: 4 additions & 4 deletions tests/datasets/test_cv4a_kenya_crop_type.py
@@ -16,7 +16,7 @@
from torchgeo.datasets import CV4AKenyaCropType


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(
"tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
@@ -25,15 +25,15 @@ def download(self, output_dir: str, **kwargs: str) -> None:
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(dataset_id: str, **kwargs: str) -> Collection:
return Collection()


class TestCV4AKenyaCropType:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
source_md5 = "7f4dcb3f33743dddd73f453176308bfb"
labels_md5 = "95fc59f1d94a85ec00931d4d1280bec9"
monkeypatch.setitem(CV4AKenyaCropType.image_meta, "md5", source_md5)
8 changes: 4 additions & 4 deletions tests/datasets/test_cyclone.py
@@ -17,14 +17,14 @@
from torchgeo.datasets import TropicalCyclone


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
for tarball in glob.iglob(os.path.join("tests", "data", "cyclone", "*.tar.gz")):
shutil.copy(tarball, output_dir)


def fetch(collection_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(collection_id: str, **kwargs: str) -> Collection:
return Collection()


class TestTropicalCyclone:
@@ -33,7 +33,7 @@ def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> TropicalCyclone:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
md5s = {
"train": {
"source": "2b818e0a0873728dabf52c7054a0ce4c",
40 changes: 34 additions & 6 deletions tests/datasets/test_nasa_marine_debris.py
@@ -15,23 +15,35 @@
from torchgeo.datasets import NASAMarineDebris


class Dataset:
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join("tests", "data", "nasa_marine_debris", "*.tar.gz")
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch(collection_id: str, **kwargs: str) -> Collection:
return Collection()


class Collection_corrupted:
def download(self, output_dir: str, **kwargs: str) -> None:
filenames = NASAMarineDebris.filenames
for filename in filenames:
with open(os.path.join(output_dir, filename), "w") as f:
f.write("bad")


def fetch_corrupted(collection_id: str, **kwargs: str) -> Collection_corrupted:
return Collection_corrupted()


class TestNASAMarineDebris:
@pytest.fixture()
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"]
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
md5s = ["6f4f0d2313323950e45bf3fc0c09b5de", "540cf1cf4fd2c13b609d0355abe955d7"]
monkeypatch.setattr(NASAMarineDebris, "md5s", md5s)
root = str(tmp_path)
transforms = nn.Identity()
@@ -58,9 +70,25 @@ def test_already_downloaded_not_extracted(
) -> None:
shutil.rmtree(dataset.root)
os.makedirs(str(tmp_path), exist_ok=True)
Dataset().download(output_dir=str(tmp_path))
Collection().download(output_dir=str(tmp_path))
NASAMarineDebris(root=str(tmp_path), download=False)

def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None:
filenames = NASAMarineDebris.filenames
for filename in filenames:
with open(os.path.join(tmp_path, filename), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset checksum mismatch."):
NASAMarineDebris(root=str(tmp_path), download=False, checksum=True)

def test_corrupted_new_download(
self, tmp_path: Path, monkeypatch: MonkeyPatch
) -> None:
with pytest.raises(RuntimeError, match="Dataset checksum mismatch."):
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_corrupted)
NASAMarineDebris(root=str(tmp_path), download=True, checksum=True)

def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
6 changes: 4 additions & 2 deletions torchgeo/datasets/benin_cashews.py
@@ -17,7 +17,7 @@
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


# TODO: read geospatial information from stac.json files
@@ -56,6 +56,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
"""

dataset_id = "ts_cashew_benin"
collection_ids = ["ts_cashew_benin_source", "ts_cashew_benin_labels"]
image_meta = {
"filename": "ts_cashew_benin_source.tar.gz",
"md5": "957272c86e518a925a4e0d90dab4f92d",
@@ -416,7 +417,8 @@ def _download(self, api_key: Optional[str] = None) -> None:
print("Files already downloaded and verified")
return

download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, api_key)

image_archive_path = os.path.join(self.root, self.image_meta["filename"])
target_archive_path = os.path.join(self.root, self.target_meta["filename"])
12 changes: 9 additions & 3 deletions torchgeo/datasets/cloud_cover.py
@@ -14,7 +14,7 @@
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


# TODO: read geospatial information from stac.json files
@@ -54,7 +54,12 @@ class CloudCoverDetection(NonGeoDataset):
.. versionadded:: 0.4
"""

dataset_id = "ref_cloud_cover_detection_challenge_v1"
collection_ids = [
"ref_cloud_cover_detection_challenge_v1_train_source",
"ref_cloud_cover_detection_challenge_v1_train_labels",
"ref_cloud_cover_detection_challenge_v1_test_source",
"ref_cloud_cover_detection_challenge_v1_test_labels",
]

image_meta = {
"train": {
@@ -332,7 +337,8 @@ def _download(self, api_key: Optional[str] = None) -> None:
print("Files already downloaded and verified")
return

download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, api_key)

image_archive_path = os.path.join(
self.root, self.image_meta[self.split]["filename"]
10 changes: 7 additions & 3 deletions torchgeo/datasets/cv4a_kenya_crop_type.py
@@ -15,7 +15,7 @@
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


# TODO: read geospatial information from stac.json files
@@ -56,7 +56,10 @@ class CV4AKenyaCropType(NonGeoDataset):
imagery and labels from the Radiant Earth MLHub
"""

dataset_id = "ref_african_crops_kenya_02"
collection_ids = [
"ref_african_crops_kenya_02_labels",
"ref_african_crops_kenya_02_source",
]
image_meta = {
"filename": "ref_african_crops_kenya_02_source.tar.gz",
"md5": "9c2004782f6dc83abb1bf45ba4d0da46",
@@ -394,7 +397,8 @@ def _download(self, api_key: Optional[str] = None) -> None:
print("Files already downloaded and verified")
return

download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, api_key)

image_archive_path = os.path.join(self.root, self.image_meta["filename"])
target_archive_path = os.path.join(self.root, self.target_meta["filename"])
11 changes: 9 additions & 2 deletions torchgeo/datasets/cyclone.py
@@ -15,7 +15,7 @@
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


class TropicalCyclone(NonGeoDataset):
@@ -45,6 +45,12 @@ class TropicalCyclone(NonGeoDataset):
"""

collection_id = "nasa_tropical_storm_competition"
collection_ids = [
"nasa_tropical_storm_competition_train_source",
"nasa_tropical_storm_competition_test_source",
"nasa_tropical_storm_competition_train_labels",
"nasa_tropical_storm_competition_test_labels",
]
md5s = {
"train": {
"source": "97e913667a398704ea8d28196d91dad6",
@@ -207,7 +213,8 @@ def _download(self, api_key: Optional[str] = None) -> None:
print("Files already downloaded and verified")
return

download_radiant_mlhub_dataset(self.collection_id, self.root, api_key)
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, api_key)

for split, resources in self.md5s.items():
for resource_type in resources:
16 changes: 10 additions & 6 deletions torchgeo/datasets/nasa_marine_debris.py
@@ -14,7 +14,7 @@
from torchvision.utils import draw_bounding_boxes

from .geo import NonGeoDataset
from .utils import download_radiant_mlhub_dataset, extract_archive
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive


class NASAMarineDebris(NonGeoDataset):
@@ -51,7 +51,7 @@ class NASAMarineDebris(NonGeoDataset):
.. versionadded:: 0.2
"""

dataset_id = "nasa_marine_debris"
collection_ids = ["nasa_marine_debris_source", "nasa_marine_debris_labels"]
directories = ["nasa_marine_debris_source", "nasa_marine_debris_labels"]
filenames = ["nasa_marine_debris_source.tar.gz", "nasa_marine_debris_labels.tar.gz"]
md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"]
@@ -189,9 +189,11 @@ def _verify(self) -> None:

# Check if zip file already exists (if so then extract)
exists = []
for filename in self.filenames:
for filename, md5 in zip(self.filenames, self.md5s):
filepath = os.path.join(self.root, filename)
if os.path.exists(filepath):
if self.checksum and not check_integrity(filepath, md5):
raise RuntimeError("Dataset checksum mismatch.")
exists.append(True)
extract_archive(filepath)
else:
@@ -208,11 +210,13 @@ def _verify(self) -> None:
"to automatically download the dataset."
)

# TODO: need a checksum check in here post downloading
# Download and extract the dataset
download_radiant_mlhub_dataset(self.dataset_id, self.root, self.api_key)
for filename in self.filenames:
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
for filename, md5 in zip(self.filenames, self.md5s):
filepath = os.path.join(self.root, filename)
if self.checksum and not check_integrity(filepath, md5):
raise RuntimeError("Dataset checksum mismatch.")
extract_archive(filepath)

def plot(
Expand Down