Add AI4ArcticSeaIce dataset. #2528

Draft · wants to merge 2 commits into main
4 changes: 4 additions & 0 deletions docs/api/datasets.rst
@@ -202,6 +202,10 @@ ADVANCE

.. autoclass:: ADVANCE

AI4ArcticSeaIce
^^^^^^^^^^^^^^^

.. autoclass:: AI4ArcticSeaIce

Benin Cashew Plantations
^^^^^^^^^^^^^^^^^^^^^^^^

1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
@@ -1,5 +1,6 @@
Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`ADVANCE`_,C,"Google Earth, Freesound","CC-BY-4.0","5,075",13,512x512,0.5,RGB
`AI4ArcticSeaIce`_,S,"Sentinel-1","CC-BY-4.0","520",2,"~5000x5000",80,"HH,HV"
`Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI
`BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
`BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
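For context, a minimal usage sketch of the new dataset class, based on the constructor arguments exercised in the test fixture further down; the root path and the chosen variables are illustrative assumptions, and only the 'mask' key is confirmed by test_plot:

from torchgeo.datasets import AI4ArcticSeaIce

# Hypothetical example; argument names mirror tests/datasets/test_ai4arctic_sea_ice.py
# rather than documented API behavior.
ds = AI4ArcticSeaIce(
    'data/ai4arctic_sea_ice',                    # assumed local directory
    split='train',
    target_var='SIC',                            # SOD, SIC, or FLOE per the dummy data
    amsr2_vars=('btemp_6_9h', 'btemp_6_9v'),
    weather_vars=('u10m_rotated', 'v10m_rotated'),
    download=False,
)
sample = ds[0]                                   # a dict (see test_getitem)
mask = sample['mask']                            # key used by test_plot for predictions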
247 changes: 247 additions & 0 deletions tests/data/ai4arctic_sea_ice/data.py
@@ -0,0 +1,247 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import numpy as np
import xarray as xr
import pandas as pd
import tarfile
import hashlib
import shutil
from datetime import datetime, timedelta


def create_dummy_nc_file(filepath: str, is_reference: bool = False):

Check failure (GitHub Actions / ruff): tests/data/ai4arctic_sea_ice/data.py:6:1: I001 Import block is un-sorted or un-formatted
Check failure (GitHub Actions / ruff): tests/data/ai4arctic_sea_ice/data.py:16:5: ANN201 Missing return type annotation for public function `create_dummy_nc_file`
"""Create dummy netCDF file matching original dataset structure."""

# Define dimensions
dims = {
'sar_lines': 12,
'sar_samples': 9,
'sar_sample_2dgrid_points': 3,
'sar_line_2dgrid_points': 4,
'2km_grid_lines': 5,
'2km_grid_samples': 6,
}

# Create variables with realistic dummy data
data_vars = {
# SAR variables (full resolution)
'nersc_sar_primary': (
('sar_lines', 'sar_samples'),
np.random.normal(-20, 5, (dims['sar_lines'], dims['sar_samples'])).astype(
np.float32
),
),
'nersc_sar_secondary': (
('sar_lines', 'sar_samples'),
np.random.normal(-25, 5, (dims['sar_lines'], dims['sar_samples'])).astype(
np.float32
),
),
# Grid coordinates
'sar_grid2d_latitude': (
('sar_sample_2dgrid_points', 'sar_line_2dgrid_points'),
np.random.uniform(
60,
80,
(dims['sar_sample_2dgrid_points'], dims['sar_line_2dgrid_points']),
).astype(np.float64),
),
'sar_grid2d_longitude': (
('sar_sample_2dgrid_points', 'sar_line_2dgrid_points'),
np.random.uniform(
-60,
0,
(dims['sar_sample_2dgrid_points'], dims['sar_line_2dgrid_points']),
).astype(np.float64),
),
# Weather variables (2km grid)
'u10m_rotated': (
('2km_grid_lines', '2km_grid_samples'),
np.random.normal(
0, 5, (dims['2km_grid_lines'], dims['2km_grid_samples'])
).astype(np.float32),
),
'v10m_rotated': (
('2km_grid_lines', '2km_grid_samples'),
np.random.normal(
0, 5, (dims['2km_grid_lines'], dims['2km_grid_samples'])
).astype(np.float32),
),
# AMSR2 variables (6.9, 7.3, 10.7, 23.8, 36.5, 89.0 GHz, h, v)
**{
f'btemp_{freq}{pol}': (
('2km_grid_lines', '2km_grid_samples'),
np.random.normal(
250, 20, (dims['2km_grid_lines'], dims['2km_grid_samples'])
).astype(np.float32),
)
for freq in ['6_9', '7_3']
for pol in ['h', 'v']
},
# Add distance map
'distance_map': (
('sar_lines', 'sar_samples'),
np.random.uniform(0, 10, (dims['sar_lines'], dims['sar_samples'])).astype(
np.float32
),
{
'long_name': 'Distance to land zones numbered with ids ranging from 0 to N',
'zonal_range_description': '\ndist_id; dist_range_km\n0; land\n1; 0 -> 0.5\n2; 0.5 -> 1\n3; 1 -> 2\n4; 2 -> 4\n5; 4 -> 8\n6; 8 -> 16\n7; 16 -> 32\n8; 32 -> 64\n9; 64 -> 128\n10; >128',
},
),
}

# Add target variables if reference file
if is_reference:
data_vars.update(
{
'SOD': (
('sar_lines', 'sar_samples'),
np.random.randint(
0, 6, (dims['sar_lines'], dims['sar_samples'])
).astype(np.uint8),
),
'SIC': (
('sar_lines', 'sar_samples'),
np.random.randint(
0, 11, (dims['sar_lines'], dims['sar_samples'])
).astype(np.uint8),
),
'FLOE': (
('sar_lines', 'sar_samples'),
np.random.randint(
0, 7, (dims['sar_lines'], dims['sar_samples'])
).astype(np.uint8),
),
}
)

# Create dataset with correct attributes
ds = xr.Dataset(
data_vars=data_vars,
attrs={
'scene_id': os.path.basename(filepath),
'original_id': f'S1A_EW_GRDM_1SDH_{os.path.basename(filepath)}',
'ice_service': 'dmi' if 'dmi' in filepath else 'cis',
'flip': 0,
'pixel_spacing': 80,
},
)

# Save to netCDF file
os.makedirs(os.path.dirname(filepath), exist_ok=True)
ds.to_netcdf(filepath)


def create_metadata_csv(root_dir: str, n_train: int = 3, n_test: int = 2):

Check failure (GitHub Actions / ruff): tests/data/ai4arctic_sea_ice/data.py:140:5: ANN201 Missing return type annotation for public function `create_metadata_csv`
"""Create metadata CSV file."""
records = []

# Generate dates
base_date = datetime(2021, 1, 1)
dates = [base_date + timedelta(days=i) for i in range(n_train + n_test)]

# Create train records
for i in range(n_train):
date_str = dates[i].strftime('%Y%m%dT%H%M%S')
service = 'dmi' if i % 2 == 0 else 'cis'
path = f'train/{date_str}_{service}_prep.nc'
records.append(
{
'input_path': path,
'reference_path': None,
'date': dates[i],
'ice_service': service,
'split': 'train',
'region_id': 'SGRDIFOXE' if service == 'cis' else 'North_RIC',
}
)

# Create test records
for i in range(n_test):
date_str = dates[n_train + i].strftime('%Y%m%dT%H%M%S')
service = 'dmi' if i % 2 == 0 else 'cis'
input_path = f'test/{date_str}_{service}_prep.nc'
ref_path = f'test/{date_str}_{service}_prep_reference.nc'
records.append(
{
'input_path': input_path,
'reference_path': ref_path,
'date': dates[n_train + i],
'ice_service': service,
'split': 'test',
'region_id': 'SGRDIFOXE' if service == 'cis' else 'North_RIC',
}
)

# Create DataFrame and save
df = pd.DataFrame(records)
df.to_csv(os.path.join(root_dir, 'metadata.csv'), index=False)
return df


def main():

Check failure (GitHub Actions / ruff): tests/data/ai4arctic_sea_ice/data.py:187:5: ANN201 Missing return type annotation for public function `main`
"""Create complete dummy dataset."""
root_dir = '.'
n_train = 3
n_test = 2

# Create metadata
df = create_metadata_csv(root_dir, n_train, n_test)

# Create train files
train_files = df[df['split'] == 'train']['input_path']
for f in train_files:
create_dummy_nc_file(os.path.join(root_dir, f), is_reference=True)

# Create test files
test_files = df[df['split'] == 'test']
for _, row in test_files.iterrows():
create_dummy_nc_file(
os.path.join(root_dir, row['input_path']), is_reference=False
)
create_dummy_nc_file(
os.path.join(root_dir, row['reference_path']), is_reference=True
)

# Create and split train tarball
shutil.make_archive('train', 'gztar', '.', 'train')

with open('train.tar.gz', 'rb') as f:
content = f.read()

# Split into two chunks
chunk1 = content[: len(content) // 2]
chunk2 = content[len(content) // 2 :]

with open('train.tar.gzaa', 'wb') as g:
g.write(chunk1)
with open('train.tar.gzab', 'wb') as g:
g.write(chunk2)

# Remove original tarball
os.remove('train.tar.gz')

with tarfile.open('test.tar.gz', 'w:gz') as tar:
tar.add('test')

# compute md5sum
def md5(fname: str) -> str:
hash_md5 = hashlib.md5()
with open(fname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b''):
hash_md5.update(chunk)
return hash_md5.hexdigest()

print(f'MD5 checksum train.gzaa: {md5("train.tar.gzaa")}')
print(f'MD5 checksum train.gzab: {md5("train.tar.gzab")}')
print(f'MD5 checksum test.gz: {md5("test.tar.gz")}')
print(f'MD5 checksum metadata: {md5("metadata.csv")}')


if __name__ == '__main__':
main()
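Since data.py splits train.tar.gz into two byte chunks (train.tar.gzaa and train.tar.gzab), the archive presumably has to be reassembled by plain concatenation before extraction; a minimal sketch under that assumption (the dataset class itself may do this internally):

import tarfile

# Rebuild the gzip'd tarball from the two chunks written by data.py, then
# extract it. The chunk names match the script above; the reassembly step
# is an assumption, not code copied from the dataset class.
with open('train.tar.gz', 'wb') as out:
    for part in ('train.tar.gzaa', 'train.tar.gzab'):
        with open(part, 'rb') as chunk:
            out.write(chunk.read())

with tarfile.open('train.tar.gz', 'r:gz') as tar:
    tar.extractall()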
6 changes: 6 additions & 0 deletions tests/data/ai4arctic_sea_ice/metadata.csv
@@ -0,0 +1,6 @@
input_path,reference_path,date,ice_service,split,region_id
train/20210101T000000_dmi_prep.nc,,2021-01-01,dmi,train,North_RIC
train/20210102T000000_cis_prep.nc,,2021-01-02,cis,train,SGRDIFOXE
train/20210103T000000_dmi_prep.nc,,2021-01-03,dmi,train,North_RIC
test/20210104T000000_dmi_prep.nc,test/20210104T000000_dmi_prep_reference.nc,2021-01-04,dmi,test,North_RIC
test/20210105T000000_cis_prep.nc,test/20210105T000000_cis_prep_reference.nc,2021-01-05,cis,test,SGRDIFOXE
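For reference, the metadata index can be read back with pandas and filtered by split, mirroring what create_metadata_csv writes; how the dataset class itself consumes this file is not shown in this diff:

import pandas as pd

# Load the per-scene index generated by data.py; column names come from
# the CSV header above.
df = pd.read_csv('tests/data/ai4arctic_sea_ice/metadata.csv', parse_dates=['date'])
train_scenes = df[df['split'] == 'train']['input_path'].tolist()
test_pairs = df[df['split'] == 'test'][['input_path', 'reference_path']]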
Binary file added tests/data/ai4arctic_sea_ice/test.tar.gz
Binary file added tests/data/ai4arctic_sea_ice/train.tar.gzaa
Binary file added tests/data/ai4arctic_sea_ice/train.tar.gzab
93 changes: 93 additions & 0 deletions tests/datasets/test_ai4arctic_sea_ice.py
@@ -0,0 +1,93 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import shutil

Check failure (GitHub Actions / ruff): tests/datasets/test_ai4arctic_sea_ice.py:4:8: F401 `shutil` imported but unused
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, AI4ArcticSeaIce

pytest.importorskip('xarray', minversion='2023.9')

Check failure (GitHub Actions / ruff): tests/datasets/test_ai4arctic_sea_ice.py:4:1: I001 Import block is un-sorted or un-formatted
pytest.importorskip('netCDF4', minversion='1.5.4')

valid_amsr2_vars = ('btemp_6_9h', 'btemp_6_9v', 'btemp_7_3h', 'btemp_7_3v')
valid_weather_vars = ('u10m_rotated', 'v10m_rotated')


class TestAI4ArcticSeaIce:
@pytest.fixture(
params=zip(
['train', 'train', 'test', 'test'],
['SOD', 'SIC', 'FLOE', 'SIC'],
[None, 'distance_map', None, 'distance_map'],
[valid_amsr2_vars, None, valid_amsr2_vars, None],
[valid_weather_vars, None, valid_weather_vars, None],
)
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> AI4ArcticSeaIce:
url = os.path.join('tests', 'data', 'ai4arctic_sea_ice', '{}')
monkeypatch.setattr(AI4ArcticSeaIce, 'url', url)
files = [
{'name': 'train.tar.gzaa', 'md5': '399952b2603d0d508a30909357e6956a'},
{'name': 'train.tar.gzab', 'md5': 'a998c852a2f418394f97cb1f99716489'},
{'name': 'test.tar.gz', 'md5': 'b81e53b4c402a64d53854f02f66ce938'},
{'name': 'metadata.csv', 'md5': 'd1222877af76d3fe9620678c930d70f0'},
]
monkeypatch.setattr(AI4ArcticSeaIce, 'files', files)

monkeypatch.setattr(AI4ArcticSeaIce, 'valid_amsr2_vars', valid_amsr2_vars)

monkeypatch.setattr(AI4ArcticSeaIce, 'valid_weather_vars', valid_weather_vars)
root = tmp_path
split, target_var, geo_var, amsr2_var, weather_var = request.param
transforms = nn.Identity()
return AI4ArcticSeaIce(
root,
split=split,
target_var=target_var,
geo_var=geo_var,
amsr2_vars=amsr2_var,
weather_vars=weather_var,
transforms=transforms,
download=True,
checksum=False,
)

def test_getitem(self, dataset: AI4ArcticSeaIce) -> None:
x = dataset[0]
assert isinstance(x, dict)

def test_len(self, dataset: AI4ArcticSeaIce) -> None:
if dataset.split == 'train':
assert len(dataset) == 3
else:
assert len(dataset) == 2

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
AI4ArcticSeaIce(tmp_path)

def test_already_downloaded_and_extracted(self, dataset: AI4ArcticSeaIce) -> None:
AI4ArcticSeaIce(root=dataset.root, download=False)

def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
AI4ArcticSeaIce(split='foo')

def test_plot(self, dataset: AI4ArcticSeaIce) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()

sample = dataset[0]
sample['prediction'] = torch.clone(sample['mask'])
dataset.plot(sample, suptitle='Test with prediction')
plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -6,6 +6,7 @@
from .advance import ADVANCE
from .agb_live_woody_density import AbovegroundLiveWoodyBiomassDensity
from .agrifieldnet import AgriFieldNet
from .ai4arctic_sea_ice import AI4ArcticSeaIce
from .airphen import Airphen
from .astergdem import AsterGDEM
from .benin_cashews import BeninSmallHolderCashews
@@ -149,153 +150,154 @@
from .xview import XView2
from .zuericrop import ZueriCrop

__all__ = (
'ADVANCE',
'AI4ArcticSeaIce',
'CDL',
'COWC',
'DFC2022',
'ETCI2021',
'EUDEM',
'FAIR1M',
'GBIF',
'GID15',
'LEVIRCD',
'MDAS',
'NAIP',
'NCCM',
'NLCD',
'OSCD',
'PASTIS',
'PRISMA',
'RESISC45',
'SEN12MS',
'SKIPPD',
'SSL4EO',
'SSL4EOL',
'SSL4EOS12',
'VHR10',
'AbovegroundLiveWoodyBiomassDensity',
'AgriFieldNet',
'Airphen',
'AsterGDEM',
'BeninSmallHolderCashews',
'BigEarthNet',
'BioMassters',
'BoundingBox',
'CMSGlobalMangroveCanopy',
'COWCCounting',
'COWCDetection',
'CV4AKenyaCropType',
'CaBuAr',
'CaFFe',
'CanadianBuildingFootprints',
'ChaBuD',
'Chesapeake',
'ChesapeakeCVPR',
'ChesapeakeDC',
'ChesapeakeDE',
'ChesapeakeMD',
'ChesapeakeNY',
'ChesapeakePA',
'ChesapeakeVA',
'ChesapeakeWV',
'CloudCoverDetection',
'CropHarvest',
'DatasetNotFoundError',
'DeepGlobeLandCover',
'DependencyNotFoundError',
'DigitalTyphoon',
'EDDMapS',
'EnviroAtlas',
'Esri2020',
'EuroCrops',
'EuroSAT',
'EuroSAT100',
'EuroSATSpatial',
'FieldsOfTheWorld',
'FireRisk',
'ForestDamage',
'GeoDataset',
'GeoNRW',
'GlobBiomass',
'HySpecNet11k',
'IDTReeS',
'INaturalist',
'IOBench',
'InriaAerialImageLabeling',
'IntersectionDataset',
'L7Irish',
'L8Biome',
'LEVIRCDBase',
'LEVIRCDPlus',
'LandCoverAI',
'LandCoverAI100',
'LandCoverAIBase',
'LandCoverAIGeo',
'Landsat',
'Landsat1',
'Landsat2',
'Landsat3',
'Landsat4MSS',
'Landsat4TM',
'Landsat5MSS',
'Landsat5TM',
'Landsat7',
'Landsat8',
'Landsat9',
'LoveDA',
'MMEarth',
'MapInWild',
'MillionAID',
'NASAMarineDebris',
'NonGeoClassificationDataset',
'NonGeoDataset',
'OpenBuildings',
'PatternNet',
'Potsdam2D',
'QuakeSet',
'RGBBandsMissingError',
'RasterDataset',
'ReforesTree',
'RwandaFieldBoundary',
'SSL4EOLBenchmark',
'SatlasPretrain',
'SeasoNet',
'SeasonalContrastS2',
'Sentinel',
'Sentinel1',
'Sentinel2',
'SkyScript',
'So2Sat',
'SouthAfricaCropType',
'SouthAmericaSoybean',
'SpaceNet',
'SpaceNet1',
'SpaceNet2',
'SpaceNet3',
'SpaceNet4',
'SpaceNet5',
'SpaceNet6',
'SpaceNet7',
'SpaceNet8',
'SustainBenchCropYield',
'TreeSatAI',
'TropicalCyclone',
'UCMerced',
'USAVars',
'UnionDataset',
'Vaihingen2D',
'VectorDataset',
'WesternUSALiveFuelMoisture',
'XView2',
'ZueriCrop',
'concat_samples',
'merge_samples',
'random_bbox_assignment',
'random_bbox_splitting',
'random_grid_cell_assignment',
'roi_split',
'stack_samples',
'time_series_split',
'unbind_samples',
)

Check failure (GitHub Actions / ruff): torchgeo/datasets/__init__.py:153:11: RUF022 `__all__` is not sorted