Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSL4EO-L: add download support #1424

Merged
merged 7 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/conf/ssl4eo_l_byol_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module:

datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/tm_toa"
root: "tests/data/ssl4eo/l"
split: "tm_toa"
seasons: 1
batch_size: 2
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/ssl4eo_l_byol_2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module:

datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/etm_sr"
root: "tests/data/ssl4eo/l"
split: "etm_sr"
seasons: 2
batch_size: 2
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/ssl4eo_l_moco_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module:

datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/etm_toa"
root: "tests/data/ssl4eo/l"
split: "etm_toa"
seasons: 1
batch_size: 2
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/ssl4eo_l_moco_2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module:

datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/oli_tirs_toa"
root: "tests/data/ssl4eo/l"
split: "oli_tirs_toa"
seasons: 2
batch_size: 2
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/ssl4eo_l_simclr_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module:

datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/oli_sr"
root: "tests/data/ssl4eo/l"
split: "oli_sr"
seasons: 1
batch_size: 2
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/ssl4eo_l_simclr_2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module:

datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/tm_toa"
root: "tests/data/ssl4eo/l"
split: "tm_toa"
seasons: 2
batch_size: 2
Expand Down
42 changes: 32 additions & 10 deletions tests/data/ssl4eo/l/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
from rasterio.crs import CRS

SIZE = 36
CHUNK_SIZE = 2**12

np.random.seed(0)

FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]]

filenames: FILENAME_HIERARCHY = {
"tm_toa": {
"ssl4eo_l_tm_toa": {
"0000002": {
"LT05_172034_20010526": ["all_bands.tif"],
"LT05_172034_20020310": ["all_bands.tif"],
Expand All @@ -34,7 +35,7 @@
"LT5_223084_20020923": ["all_bands.tif"],
},
},
"etm_sr": {
"ssl4eo_l_etm_toa": {
"0000002": {
"LE07_172034_20010526": ["all_bands.tif"],
"LE07_172034_20020310": ["all_bands.tif"],
Expand All @@ -48,7 +49,7 @@
"LE07_223084_20020923": ["all_bands.tif"],
},
},
"etm_toa": {
"ssl4eo_l_etm_sr": {
"0000002": {
"LE07_172034_20010526": ["all_bands.tif"],
"LE07_172034_20020310": ["all_bands.tif"],
Expand All @@ -62,7 +63,7 @@
"LE07_223084_20020923": ["all_bands.tif"],
},
},
"oli_tirs_toa": {
"ssl4eo_l_oli_tirs_toa": {
"0000002": {
"LC08_172034_20210306": ["all_bands.tif"],
"LC08_172034_20210829": ["all_bands.tif"],
Expand All @@ -76,7 +77,7 @@
"LC08_223084_20221211": ["all_bands.tif"],
},
},
"oli_sr": {
"ssl4eo_l_oli_sr": {
"0000002": {
"LC08_172034_20210306": ["all_bands.tif"],
"LC08_172034_20210829": ["all_bands.tif"],
Expand All @@ -92,7 +93,13 @@
},
}

num_bands = {"tm_toa": 7, "etm_sr": 6, "etm_toa": 9, "oli_tirs_toa": 11, "oli_sr": 7}
num_bands = {
"ssl4eo_l_tm_toa": 7,
"ssl4eo_l_etm_toa": 9,
"ssl4eo_l_etm_sr": 6,
"ssl4eo_l_oli_tirs_toa": 11,
"ssl4eo_l_oli_sr": 7,
}


def create_file(path: str) -> None:
Expand Down Expand Up @@ -141,10 +148,25 @@ def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:

directories = filenames.keys()
for directory in directories:
# Create tarballs
# Create tarball
shutil.make_archive(directory, "gztar", ".", directory)

# Split tarball
path = f"{directory}.tar.gz"
paths = []
with open(path, "rb") as f:
suffix = "a"
while chunk := f.read(CHUNK_SIZE):
split = f"{path}a{suffix}"
with open(split, "wb") as g:
g.write(chunk)
suffix = chr(ord(suffix) + 1)
paths.append(split)

os.remove(path)

# Compute checksums
with open(f"{directory}.tar.gz", "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(directory, md5)
for path in paths:
with open(path, "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(path, md5)
Binary file removed tests/data/ssl4eo/l/etm_sr.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed tests/data/ssl4eo/l/etm_toa.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed tests/data/ssl4eo/l/oli_sr.tar.gz
Binary file not shown.
Binary file removed tests/data/ssl4eo/l/oli_tirs_toa.tar.gz
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_etm_sr.tar.gzaa
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_etm_sr.tar.gzab
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_etm_sr.tar.gzac
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_etm_toa.tar.gzaa
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_etm_toa.tar.gzab
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_etm_toa.tar.gzac
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_oli_sr.tar.gzaa
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_oli_sr.tar.gzab
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_oli_sr.tar.gzac
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_tm_toa.tar.gzaa
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_tm_toa.tar.gzab
Binary file not shown.
Binary file added tests/data/ssl4eo/l/ssl4eo_l_tm_toa.tar.gzac
Binary file not shown.
Binary file removed tests/data/ssl4eo/l/tm_toa.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
62 changes: 59 additions & 3 deletions tests/datasets/test_ssl4eo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path
Expand All @@ -13,16 +14,57 @@
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset

import torchgeo
from torchgeo.datasets import SSL4EOL, SSL4EOS12


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    """Copy the file at ``url`` into ``root`` instead of downloading it.

    Stand-in with the same call shape as a real downloader; any extra
    positional or keyword arguments are accepted and ignored.

    Args:
        url: path of the local file to copy.
        root: destination passed to :func:`shutil.copy`.
        *args: ignored.
        **kwargs: ignored.
    """
    shutil.copy(src=url, dst=root)


class TestSSL4EOL:
@pytest.fixture(params=zip(SSL4EOL.metadata.keys(), [1, 1, 2, 2, 4]))
def dataset(self, request: SubRequest) -> SSL4EOL:
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> SSL4EOL:
monkeypatch.setattr(torchgeo.datasets.ssl4eo, "download_url", download_url)

url = os.path.join("tests", "data", "ssl4eo", "l", "ssl4eo_l_{0}.tar.gz{1}")
monkeypatch.setattr(SSL4EOL, "url", url)

checksums = {
"tm_toa": {
"aa": "010b9d72b476e0e30741c17725f84e5c",
"ab": "39171bd7bca8a56a8cb339a0f88da9d3",
"ac": "3cfc407ce3f4f4d6e3c5fdb457bb87da",
},
"etm_toa": {
"aa": "87e47278f5a30acd3b696b6daaa4713b",
"ab": "59295e1816e08a5acd3a18ae56b6f32e",
"ac": "f3ff76eb6987501000228ce15684e09f",
},
"etm_sr": {
"aa": "fd61a4154eafaeb350dbb01a2551a818",
"ab": "0c3117bc7682ba9ffdc6871e6c364b36",
"ac": "93d3385e47de4578878ca5c4fa6a628d",
},
"oli_tirs_toa": {
"aa": "defb9e91a73b145b2dbe347649bded06",
"ab": "97f7edaa4e288fc14ec7581dccea766f",
"ac": "7472fad9929a0dc96ccf4dc6c804b92f",
},
"oli_sr": {
"aa": "8fd3aa6b581d024299f44457956faa05",
"ab": "7eb4d761ce1afd89cae9c6142ca17882",
"ac": "a3210da9fcc71e3a4efde71c30d78c59",
},
}
monkeypatch.setattr(SSL4EOL, "checksums", checksums)

root = str(tmp_path)
split, seasons = request.param
root = os.path.join("tests", "data", "ssl4eo", "l", split)
transforms = nn.Identity()
return SSL4EOL(root, split, seasons, transforms)
return SSL4EOL(root, split, seasons, transforms, download=True, checksum=True)

def test_getitem(self, dataset: SSL4EOL) -> None:
x = dataset[0]
Expand All @@ -41,6 +83,20 @@ def test_add(self, dataset: SSL4EOL) -> None:
assert isinstance(ds, ConcatDataset)
assert len(ds) == 2 * 2

def test_already_extracted(self, dataset: SSL4EOL) -> None:
SSL4EOL(dataset.root, dataset.split, dataset.seasons)

def test_already_downloaded(self, dataset: SSL4EOL, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "ssl4eo", "l", "*.tar.gz*")
root = str(tmp_path)
for tarfile in glob.iglob(pathname):
shutil.copy(tarfile, root)
SSL4EOL(root)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
SSL4EOL(str(tmp_path))

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
SSL4EOL(split="foo")
Expand Down
Loading