Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding dataset from MOSAIKS paper #363

Merged
merged 67 commits into from
Feb 27, 2022
Merged
Show file tree
Hide file tree
Changes from 66 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
0620f6b
Adding dataset from MOSAIKS paper
iejMac Jan 20, 2022
09b4faf
Name change
iejMac Jan 21, 2022
fa869cc
implementing NAIPTileIndex in USAVars
iejMac Jan 24, 2022
501216b
lookup_point works
iejMac Jan 24, 2022
66e2c1c
Merge branch 'main' of https://github.com/iejMac/torchgeo into mosaik…
iejMac Feb 22, 2022
48f6c9b
usavars: adding extract + verify
iejMac Feb 22, 2022
be354d7
USAVars: add md5
iejMac Feb 22, 2022
8b8ca79
initial _load_files function
iejMac Feb 22, 2022
b57b7ee
adding plotting
iejMac Feb 22, 2022
ec84d89
formatting
iejMac Feb 22, 2022
b78ee80
add description
iejMac Feb 23, 2022
0311f4e
black fix
iejMac Feb 23, 2022
ebbf9d0
flake8 fix
iejMac Feb 23, 2022
b0d34c9
pydocstyle fix
iejMac Feb 23, 2022
062c9f8
mypy fix
iejMac Feb 23, 2022
7bffb9c
add DS to docs
iejMac Feb 23, 2022
168f989
black fix
iejMac Feb 23, 2022
7e940e1
fake dataset
iejMac Feb 23, 2022
8837898
add transforms arg
iejMac Feb 23, 2022
0a880b5
initial tests
iejMac Feb 23, 2022
61739b7
fix black flake8 isort
iejMac Feb 23, 2022
b5e2512
fix black flake8
iejMac Feb 23, 2022
74aaab4
fix black
iejMac Feb 23, 2022
a47e2fb
fix mypy
iejMac Feb 23, 2022
e081de8
test fixes
iejMac Feb 23, 2022
73a87f8
testing something
iejMac Feb 23, 2022
efb8360
it finds zip but not csv
iejMac Feb 23, 2022
dea3703
fake csv files didn't get added
iejMac Feb 23, 2022
81afd9a
pandas docs fix
iejMac Feb 23, 2022
bd75a00
forgot to take out here
iejMac Feb 23, 2022
bd90b5e
need to add in functions
iejMac Feb 23, 2022
d1deebf
round plot labels
iejMac Feb 23, 2022
0e6202a
Small edits
calebrob6 Feb 23, 2022
d6bc60f
remove Unnamed column
iejMac Feb 23, 2022
bf77ea3
zipfile change
iejMac Feb 23, 2022
9e63d67
i think this solves codecov?
iejMac Feb 23, 2022
18a1904
there needs to be a test
iejMac Feb 23, 2022
2820635
codecov
iejMac Feb 23, 2022
dd9b3b7
bring back UAR!
iejMac Feb 24, 2022
b45a4c3
remove intermediate directory
iejMac Feb 24, 2022
c1f9345
fix flake8
iejMac Feb 24, 2022
705ffca
No more iteration in load_files
iejMac Feb 24, 2022
8a72999
dont' use Any
iejMac Feb 24, 2022
d28a6e4
check if all csv files exist
iejMac Feb 25, 2022
faf6f21
Add docstring to init
iejMac Feb 25, 2022
5e75097
use index col = ID
iejMac Feb 25, 2022
5626062
labels in as list
iejMac Feb 25, 2022
f57f999
adjust to only 3 labels + adjust tests
iejMac Feb 25, 2022
9f24afb
citation
iejMac Feb 25, 2022
e97939a
remove testing file
iejMac Feb 25, 2022
2ba7e6b
no need to rename zipfile
iejMac Feb 25, 2022
fca13f6
adding data.py to test data + adjusting tests
iejMac Feb 25, 2022
07f942f
formatting fixes
iejMac Feb 25, 2022
681321e
style fix
iejMac Feb 25, 2022
da8040b
docstring
iejMac Feb 25, 2022
f30b50e
Docstring
calebrob6 Feb 25, 2022
eb81d7c
Fixing docstring
calebrob6 Feb 25, 2022
fb49f5d
docstring for labels
iejMac Feb 26, 2022
5e4f2b9
kjMerge branch 'mosaiks_dataset' of https://github.com/iejMac/torchge…
iejMac Feb 26, 2022
8243b30
Adding all csv files to data + checking for all 7 labels instead of j…
iejMac Feb 26, 2022
d518683
docstrings
iejMac Feb 27, 2022
72459ca
docstring
iejMac Feb 27, 2022
fdf8ee7
ensure labels are valid
iejMac Feb 27, 2022
1d0d0ae
pydocstyle fix
iejMac Feb 27, 2022
9199844
cast to list
iejMac Feb 27, 2022
c64876c
remove typos
iejMac Feb 27, 2022
0925b64
Requested changes
calebrob6 Feb 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ UC Merced

.. autoclass:: UCMerced

USAVars
^^^^^^^

.. autoclass:: USAVars

Vaihingen
^^^^^^^^^

Expand Down
77 changes: 77 additions & 0 deletions tests/data/usavars/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import hashlib
import os
import shutil

import numpy as np
import pandas as pd
import rasterio

data_dir = "uar"
labels = [
"elevation",
"population",
"treecover",
"income",
"nightlights",
"housing",
"roads",
]
SIZE = 3


def create_file(path: str, dtype: str, num_channels: int) -> None:
profile = {}
profile["driver"] = "GTiff"
profile["dtype"] = dtype
profile["count"] = num_channels
profile["crs"] = "epsg:4326"
profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
profile["height"] = SIZE
profile["width"] = SIZE
profile["compress"] = "lzw"
profile["predictor"] = 2

Z = np.random.randint(
np.iinfo(profile["dtype"]).max, size=(4, SIZE, SIZE), dtype=profile["dtype"]
)
src = rasterio.open(path, "w", **profile)
src.write(Z)


# Remove old data
filename = f"{data_dir}.zip"
csvs = glob.glob("*.csv")

for csv in csvs:
os.remove(csv)
if os.path.exists(filename):
os.remove(filename)
if os.path.exists(data_dir):
shutil.rmtree(data_dir)

# Create tifs:
os.makedirs(data_dir)
create_file(os.path.join(data_dir, "tile_0,0.tif"), np.uint8, 4)
create_file(os.path.join(data_dir, "tile_0,1.tif"), np.uint8, 4)

# Create labels:
columns = [["ID", "lon", "lat", lab] for lab in labels]
fake_vals = [["0,0", 0.0, 0.0, 0.0], ["0,1", 0.1, 0.1, 1.0]]
for lab, cols in zip(labels, columns):
df = pd.DataFrame(fake_vals, columns=cols)
df.to_csv(lab + ".csv")

# Compress data
shutil.make_archive(data_dir, "zip", ".", data_dir)

# Compute checksums
filename = f"{data_dir}.zip"
with open(filename, "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(repr(filename) + ": " + repr(md5) + ",")
3 changes: 3 additions & 0 deletions tests/data/usavars/elevation.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
,ID,lon,lat,elevation
0,"0,0",0.0,0.0,0.0
1,"0,1",0.1,0.1,1.0
3 changes: 3 additions & 0 deletions tests/data/usavars/housing.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
,ID,lon,lat,housing
0,"0,0",0.0,0.0,0.0
1,"0,1",0.1,0.1,1.0
3 changes: 3 additions & 0 deletions tests/data/usavars/income.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
,ID,lon,lat,income
0,"0,0",0.0,0.0,0.0
1,"0,1",0.1,0.1,1.0
3 changes: 3 additions & 0 deletions tests/data/usavars/nightlights.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
,ID,lon,lat,nightlights
0,"0,0",0.0,0.0,0.0
1,"0,1",0.1,0.1,1.0
3 changes: 3 additions & 0 deletions tests/data/usavars/population.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
,ID,lon,lat,population
0,"0,0",0.0,0.0,0.0
1,"0,1",0.1,0.1,1.0
3 changes: 3 additions & 0 deletions tests/data/usavars/roads.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
,ID,lon,lat,roads
0,"0,0",0.0,0.0,0.0
1,"0,1",0.1,0.1,1.0
3 changes: 3 additions & 0 deletions tests/data/usavars/treecover.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
,ID,lon,lat,treecover
0,"0,0",0.0,0.0,0.0
1,"0,1",0.1,0.1,1.0
Binary file added tests/data/usavars/uar.zip
Binary file not shown.
Binary file added tests/data/usavars/uar/tile_0,0.tif
Binary file not shown.
Binary file added tests/data/usavars/uar/tile_0,1.tif
Binary file not shown.
135 changes: 135 additions & 0 deletions tests/datasets/test_usavars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any, Generator

import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from matplotlib import pyplot as plt
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import USAVars

pytest.importorskip("pandas", minversion="0.19.1")


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)


class TestUSAVars:
@pytest.fixture()
def dataset(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
request: SubRequest,
) -> USAVars:

monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.usavars, "download_url", download_url
)

md5 = "b504580a00bdc27097d5421dec50481b"
monkeypatch.setattr(USAVars, "md5", md5) # type: ignore[attr-defined]

data_url = os.path.join("tests", "data", "usavars", "uar.zip")
monkeypatch.setattr(USAVars, "data_url", data_url) # type: ignore[attr-defined]

label_urls = {
"elevation": os.path.join("tests", "data", "usavars", "elevation.csv"),
"population": os.path.join("tests", "data", "usavars", "population.csv"),
"treecover": os.path.join("tests", "data", "usavars", "treecover.csv"),
"income": os.path.join("tests", "data", "usavars", "income.csv"),
"nightlights": os.path.join("tests", "data", "usavars", "nightlights.csv"),
"roads": os.path.join("tests", "data", "usavars", "roads.csv"),
"housing": os.path.join("tests", "data", "usavars", "housing.csv"),
}
monkeypatch.setattr( # type: ignore[attr-defined]
USAVars, "label_urls", label_urls
)

root = str(tmp_path)
transforms = nn.Identity() # type: ignore[attr-defined]

return USAVars(root, transforms=transforms, download=True, checksum=True)

def test_getitem(self, dataset: USAVars) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert x["image"].ndim == 3
assert len(x.keys()) == 2 # image, elevation, population, treecover
assert x["image"].shape[0] == 4 # R, G, B, Inf

def test_len(self, dataset: USAVars) -> None:
assert len(dataset) == 2

def test_add(self, dataset: USAVars) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)

def test_already_extracted(self, dataset: USAVars) -> None:
USAVars(root=dataset.root, download=True)

def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "usavars", "uar.zip")
root = str(tmp_path)
shutil.copy(pathname, root)
csvs = [
"elevation.csv",
"population.csv",
"treecover.csv",
"income.csv",
"nightlights.csv",
"roads.csv",
"housing.csv",
]
for csv in csvs:
shutil.copy(os.path.join("tests", "data", "usavars", csv), root)

USAVars(root)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
USAVars(str(tmp_path))

@pytest.fixture(params=["pandas"])
def mock_missing_module(
self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
) -> str:
import_orig = builtins.__import__
package = str(request.param)

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == package:
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr( # type: ignore[attr-defined]
builtins, "__import__", mocked_import
)
return package

def test_mock_missing_module(
self, dataset: USAVars, mock_missing_module: str
) -> None:
package = mock_missing_module
if package == "pandas":
with pytest.raises(
ImportError,
match=f"{package} is not installed and is required to use this dataset",
):
USAVars(dataset.root)

def test_plot(self, dataset: USAVars) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from .so2sat import So2Sat
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
BoundingBox,
concat_samples,
Expand Down Expand Up @@ -149,6 +150,7 @@
"SpaceNet7",
"TropicalCycloneWindEstimation",
"UCMerced",
"USAVars",
"Vaihingen2D",
"VHR10",
"XView2",
Expand Down
Loading