Skip to content

Commit

Permalink
Merge branch 'main' into geosampler_prechipping
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena authored Sep 30, 2024
2 parents e7ecb86 + b2f9936 commit 5347560
Show file tree
Hide file tree
Showing 24 changed files with 613 additions and 19 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ jobs:
latest:
name: latest
runs-on: ${{ matrix.os }}
env:
MPLBACKEND: Agg
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
Expand Down Expand Up @@ -55,8 +53,6 @@ jobs:
minimum:
name: minimum
runs-on: ubuntu-latest
env:
MPLBACKEND: Agg
steps:
- name: Clone repo
uses: actions/checkout@v4.1.7
Expand Down Expand Up @@ -90,8 +86,6 @@ jobs:
datasets:
name: datasets
runs-on: ubuntu-latest
env:
MPLBACKEND: Agg
steps:
- name: Clone repo
uses: actions/checkout@v4.1.7
Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,11 @@ FAIR1M

.. autoclass:: FAIR1M

Fields Of The World
^^^^^^^^^^^^^^^^^^^

.. autoclass:: FieldsOfTheWorld

FireRisk
^^^^^^^^

Expand Down
2 changes: 1 addition & 1 deletion docs/api/datasets/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m)
`L8 Biome`_,"Imagery, Masks",Landsat,"CC0-1.0","8,900x8,900","15, 30"
`LandCover.ai Geo`_,"Imagery, Masks",Aerial,"CC-BY-NC-SA-4.0","4,200--9,500",0.25--0.5
`Landsat`_,Imagery,Landsat,"public domain","8,900x8,900",30
`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",1
`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",0.3--2
`NCCM`_,Masks,Sentinel-2,"CC-BY-4.0",-,10
`NLCD`_,Masks,Landsat,"public domain",-,30
`Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,-
Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR
`EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI
`FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB
`Fields Of The World`_,"S,I",Sentinel-2,"Various","70795","2,3",256x256,10,MSI
`FireRisk`_,C,NAIP Aerial,"CC-BY-NC-4.0","91,872",7,"320x320",1,RGB
`Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB
`GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM"
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ datasets = [
"laspy>=2",
# opencv-python 4.5.4+ required for Python 3.10 wheels
"opencv-python>=4.5.4",
# pandas 2+ required for parquet extra
"pandas[parquet]>=2",
# pycocotools 2.0.7+ required for wheels
"pycocotools>=2.0.7",
# pyvista 0.34.2+ required to avoid ImportError in CI
Expand Down
3 changes: 2 additions & 1 deletion requirements/datasets.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# datasets
h5py==3.11.0
h5py==3.12.1
laspy==2.5.4
opencv-python==4.10.0.84
pandas[parquet]==2.2.3
pycocotools==2.0.8
pyvista==0.44.1
scikit-image==0.24.0
Expand Down
1 change: 1 addition & 0 deletions requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ h5py==3.6.0
laspy==2.0.0
opencv-python==4.5.4.58
pycocotools==2.0.7
pyarrow==15.0.0 # Remove when we upgrade min verison of pandas to `pandas[parquet]>=2`
pyvista==0.34.2
scikit-image==0.19.0
scipy==1.7.2
Expand Down
2 changes: 1 addition & 1 deletion requirements/style.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# style
mypy==1.11.2
ruff==0.6.7
ruff==0.6.8
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Any

import matplotlib
import pytest
import torch
import torchvision
Expand All @@ -19,6 +20,11 @@ def load_state_dict_from_url(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load)


@pytest.fixture(autouse=True, scope='session')
def matplotlib_backend() -> None:
matplotlib.use('agg')


@pytest.fixture(autouse=True)
def torch_hub(tmp_path: Path) -> None:
torch.hub.set_dir(tmp_path) # type: ignore[no-untyped-call]
Binary file added tests/data/ftw/austria.zip
Binary file not shown.
110 changes: 110 additions & 0 deletions tests/data/ftw/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
import zipfile

import numpy as np
import pandas as pd
import rasterio
from affine import Affine

np.random.seed(0)

country = 'austria'
SIZE = 32
num_samples = {'train': 2, 'val': 2, 'test': 2}
BASE_PROFILE = {
'driver': 'GTiff',
'dtype': 'uint16',
'nodata': None,
'width': SIZE,
'height': SIZE,
'count': 4,
'crs': 'EPSG:4326',
'transform': Affine(5.4e-05, 0.0, 0, 0.0, 5.4e-05, 0),
'blockxsize': SIZE,
'blockysize': SIZE,
'tiled': True,
'interleave': 'pixel',
}


def create_image(fn: str) -> None:
os.makedirs(os.path.dirname(fn), exist_ok=True)

profile = BASE_PROFILE.copy()

data = np.random.randint(0, 20000, size=(4, SIZE, SIZE), dtype=np.uint16)
with rasterio.open(fn, 'w', **profile) as dst:
dst.write(data)


def create_mask(fn: str, min_val: int, max_val: int) -> None:
os.makedirs(os.path.dirname(fn), exist_ok=True)

profile = BASE_PROFILE.copy()
profile['dtype'] = 'uint8'
profile['nodata'] = 0
profile['count'] = 1

data = np.random.randint(min_val, max_val, size=(1, SIZE, SIZE), dtype=np.uint8)
with rasterio.open(fn, 'w', **profile) as dst:
dst.write(data)


if __name__ == '__main__':
i = 0
cols = {'aoi_id': [], 'split': []}
for split, n in num_samples.items():
for j in range(n):
aoi = f'g_{i}'
cols['aoi_id'].append(aoi)
cols['split'].append(split)

create_image(os.path.join(country, 's2_images', 'window_a', f'{aoi}.tif'))
create_image(os.path.join(country, 's2_images', 'window_b', f'{aoi}.tif'))

create_mask(
os.path.join(country, 'label_masks', 'semantic_2class', f'{aoi}.tif'),
0,
1,
)
create_mask(
os.path.join(country, 'label_masks', 'semantic_3class', f'{aoi}.tif'),
0,
2,
)
create_mask(
os.path.join(country, 'label_masks', 'instance', f'{aoi}.tif'), 0, 100
)

i += 1

# Create an extra train file to test for missing other files
aoi = f'g_{i}'
cols['aoi_id'].append(aoi)
cols['split'].append(split)
create_image(os.path.join(country, 's2_images', 'window_a', f'{aoi}.tif'))

# Write parquet index
df = pd.DataFrame(cols)
df.to_parquet(os.path.join(country, f'chips_{country}.parquet'))

# archive to zip
with zipfile.ZipFile(f'{country}.zip', 'w') as zipf:
for root, _, files in os.walk(country):
for file in files:
output_fn = os.path.join(root, file)
zipf.write(output_fn, os.path.relpath(output_fn, country))

shutil.rmtree(country)

# Compute checksums
with open(f'{country}.zip', 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'{md5}')
90 changes: 90 additions & 0 deletions tests/datasets/test_ftw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset

from torchgeo.datasets import DatasetNotFoundError, FieldsOfTheWorld

pytest.importorskip('pyarrow')


class TestFieldsOfTheWorld:
@pytest.fixture(
params=product(['train', 'val', 'test'], ['2-class', '3-class', 'instance'])
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> FieldsOfTheWorld:
split, task = request.param

monkeypatch.setattr(FieldsOfTheWorld, 'valid_countries', ['austria'])
monkeypatch.setattr(
FieldsOfTheWorld,
'country_to_md5',
{'austria': '1cf9593c9bdceeaba21bbcb24d35816c'},
)
base_url = os.path.join('tests', 'data', 'ftw') + '/'
monkeypatch.setattr(FieldsOfTheWorld, 'base_url', base_url)
root = tmp_path
transforms = nn.Identity()
return FieldsOfTheWorld(
root,
split,
task,
countries='austria',
transforms=transforms,
download=True,
checksum=True,
)

def test_getitem(self, dataset: FieldsOfTheWorld) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)

def test_len(self, dataset: FieldsOfTheWorld) -> None:
assert len(dataset) == 2

def test_add(self, dataset: FieldsOfTheWorld) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
assert len(ds) == 4

def test_already_extracted(self, dataset: FieldsOfTheWorld) -> None:
FieldsOfTheWorld(root=dataset.root, download=True)

def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
url = os.path.join('tests', 'data', 'ftw', 'austria.zip')
root = tmp_path
shutil.copy(url, root)
FieldsOfTheWorld(root)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
FieldsOfTheWorld(tmp_path)

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
FieldsOfTheWorld(split='foo')

def test_plot(self, dataset: FieldsOfTheWorld) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x['prediction'] = x['mask'].clone()
dataset.plot(x)
plt.close()
5 changes: 5 additions & 0 deletions tests/models/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,8 @@ def test_get_weight(enum: WeightsEnum) -> None:
def test_list_models() -> None:
models = [builder.__name__ for builder in builders]
assert set(models) == set(list_models())


def test_invalid_model() -> None:
with pytest.raises(ValueError, match='bad_model is not a valid WeightsEnum'):
get_weight('bad_model')
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .fair1m import FAIR1M
from .fire_risk import FireRisk
from .forestdamage import ForestDamage
from .ftw import FieldsOfTheWorld
from .gbif import GBIF
from .geo import (
GeoDataset,
Expand Down Expand Up @@ -217,6 +218,7 @@
'EuroSATSpatial',
'EuroSAT100',
'FAIR1M',
'FieldsOfTheWorld',
'FireRisk',
'ForestDamage',
'GeoNRW',
Expand Down
Loading

0 comments on commit 5347560

Please sign in to comment.