Skip to content

Commit

Permalink
Merge branch 'geosampler_prechipping' into vers_working_branch
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena committed Sep 24, 2024
2 parents 514745d + 25ce0e1 commit 7046e89
Show file tree
Hide file tree
Showing 15 changed files with 499 additions and 56 deletions.
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
('py:class', 'torchvision.models._api.WeightsEnum'),
('py:class', 'torchvision.models.resnet.ResNet'),
('py:class', 'torchvision.models.swin_transformer.SwinTransformer'),
('py:class', 'geopandas.GeoDataFrame'),
]


Expand Down Expand Up @@ -122,6 +123,7 @@
'torch': ('https://pytorch.org/docs/stable', None),
'torchmetrics': ('https://lightning.ai/docs/torchmetrics/stable/', None),
'torchvision': ('https://pytorch.org/vision/stable', None),
'geopandas': ('https://geopandas.org/en/stable/', None),
}

# nbsphinx
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ torchgeo
:caption: Tutorials

tutorials/getting_started
tutorials/visualizing_samples
tutorials/custom_raster_dataset
tutorials/transforms
tutorials/indices
Expand Down
159 changes: 159 additions & 0 deletions docs/tutorials/visualizing_samples.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualizing Samples\n",
"\n",
"This tutorial shows how to visualize and save the extent of your samples before and during training. In this particular example, we compare a vanilla RandomGeoSampler with one bounded by multiple ROI's and show how easy it is to gain insight on the distribution of your samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from torchgeo.datasets import NAIP, stack_samples\n",
"from torchgeo.datasets.utils import download_url\n",
"from torchgeo.samplers import RandomGeoSampler\n",
"\n",
"\n",
"def run_epochs(dataset, sampler):\n",
" dataloader = DataLoader(\n",
" dataset, sampler=sampler, batch_size=1, collate_fn=stack_samples, num_workers=0\n",
" )\n",
" fig, ax = plt.subplots()\n",
" num_epochs = 5\n",
" for epoch in range(num_epochs):\n",
" color = plt.cm.viridis(epoch / num_epochs)\n",
" # sampler.chips.to_file(f'naip_chips_epoch_{epoch}') # Optional: save chips to file for display in GIS software\n",
" ax = sampler.chips.plot(ax=ax, color=color)\n",
" for sample in dataloader:\n",
" pass\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"naip_root = os.path.join(tempfile.gettempdir(), 'naip')\n",
"naip_url = (\n",
" 'https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/'\n",
")\n",
"tiles = ['m_3807511_ne_18_060_20181104.tif', 'm_3807512_sw_18_060_20180815.tif']\n",
"for tile in tiles:\n",
" download_url(naip_url + tile, naip_root)\n",
"\n",
"naip = NAIP(naip_root)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we create the default sampler for our dataset (3 samples) and run it for 5 epochs and plot its results. Each color displays a different epoch, so we can see how the RandomGeoSampler has distributed it's samples for every epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sampler = RandomGeoSampler(naip, size=1000, length=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run_epochs(naip, sampler)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we split our dataset by two bounding boxes and re-inspect the samples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from torchgeo.datasets import roi_split\n",
"from torchgeo.datasets.utils import BoundingBox\n",
"\n",
"rois = [\n",
" BoundingBox(440854, 442938, 4299766, 4301731, 0, np.inf),\n",
" BoundingBox(449070, 451194, 4289463, 4291746, 0, np.inf),\n",
"]\n",
"datasets = roi_split(naip, rois)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"combined = datasets[0] | datasets[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sampler = RandomGeoSampler(combined, size=1000, length=3)\n",
"run_epochs(combined, sampler)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cca",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ dependencies = [
"einops>=0.3",
# fiona 1.8.21+ required for Python 3.10 wheels
"fiona>=1.8.21",
# geopandas 0.13.2 is the last version to support pandas 1.3, but has feather support
"geopandas>=0.13.2",
# kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
"kornia>=0.7.3",
# lightly 1.4.5+ required for LARS optimizer
Expand All @@ -58,6 +60,8 @@ dependencies = [
"pandas>=1.3.3",
# pillow 8.4+ required for Python 3.10 wheels
"pillow>=8.4",
# pyarrow 12.0+ required for feather support
"pyarrow>=17.0.0",
# pyproj 3.3+ required for Python 3.10 wheels
"pyproj>=3.3",
# rasterio 1.3+ required for Python 3.10 wheels
Expand Down
2 changes: 2 additions & 0 deletions requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ setuptools==61.0.0
# install
einops==0.3.0
fiona==1.8.21
geopandas==0.13.2
kornia==0.7.3
lightly==1.4.5
lightning[pytorch-extra]==2.0.0
matplotlib==3.5.0
numpy==1.21.2
pandas==1.3.3
pillow==8.4.0
pyarrow==17.0.0
pyproj==3.3.0
rasterio==1.3.0.post1
rtree==1.0.0
Expand Down
2 changes: 2 additions & 0 deletions requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ setuptools==75.1.0
# install
einops==0.8.0
fiona==1.10.1
geopandas==0.14.4
kornia==0.7.3
lightly==1.5.12
lightning[pytorch-extra]==2.4.0
matplotlib==3.9.2
numpy==2.1.1
pandas==2.2.3
pillow==10.4.0
pyarrow==17.0.0
pyproj==3.6.1
rasterio==1.3.11
rtree==1.3.0
Expand Down
Binary file added tests/data/samplers/filtering_4x4.feather
Binary file not shown.
1 change: 1 addition & 0 deletions tests/data/samplers/filtering_4x4/filtering_4x4.cpg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ISO-8859-1
Binary file added tests/data/samplers/filtering_4x4/filtering_4x4.dbf
Binary file not shown.
1 change: 1 addition & 0 deletions tests/data/samplers/filtering_4x4/filtering_4x4.prj
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
PROJCS["NAD_1983_BC_Environment_Albers",GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Albers"],PARAMETER["False_Easting",1000000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-126.0],PARAMETER["Standard_Parallel_1",50.0],PARAMETER["Standard_Parallel_2",58.5],PARAMETER["Latitude_Of_Origin",45.0],UNIT["Meter",1.0]]
Binary file added tests/data/samplers/filtering_4x4/filtering_4x4.shp
Binary file not shown.
Binary file added tests/data/samplers/filtering_4x4/filtering_4x4.shx
Binary file not shown.
3 changes: 2 additions & 1 deletion tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import torch
from _pytest.fixtures import SubRequest
from geopandas import GeoDataFrame
from lightning.pytorch import Trainer
from matplotlib.figure import Figure
from rasterio.crs import CRS
Expand Down Expand Up @@ -182,7 +183,7 @@ def test_zero_length_sampler(self) -> None:
dm = CustomGeoDataModule()
dm.dataset = CustomGeoDataset()
dm.sampler = RandomGeoSampler(dm.dataset, 1, 1)
dm.sampler.length = 0
dm.sampler.chips = GeoDataFrame()
msg = r'CustomGeoDataModule\.sampler has length 0.'
with pytest.raises(MisconfigurationException, match=msg):
dm.train_dataloader()
Expand Down
Loading

0 comments on commit 7046e89

Please sign in to comment.