diff --git a/.gitattributes b/.gitattributes index 6f6b1096141..ed7c1989f2d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,2 @@ # Do not change line endings on test data, it will change the MD5 -/tests/data/** binary +/tests/data/*/** binary diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index c529da285e5..ad63dd2f034 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -62,7 +62,7 @@ jobs: if: ${{ runner.os == 'Windows' }} - name: Install conda dependencies (Windows) run: | - conda install h5py 'rasterio>=1.0' + conda install fiona h5py 'rasterio>=1.0' conda list conda info if: ${{ runner.os == 'Windows' }} diff --git a/docs/datasets.rst b/docs/datasets.rst index 3ec879edb3a..0d138ca5dd9 100644 --- a/docs/datasets.rst +++ b/docs/datasets.rst @@ -10,6 +10,11 @@ Geospatial Datasets :class:`GeoDataset` is designed for datasets that contain geospatial information, like latitude, longitude, coordinate system, and projection. Datasets containing this kind of information can be combined using :class:`ZipDataset`. +Canadian Building Footprints +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autoclass:: CanadianBuildingFootprints + Chesapeake Bay High-Resolution Land Cover Project ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/notebooks/Canadian Building Footprints Dataset.ipynb b/docs/notebooks/Canadian Building Footprints Dataset.ipynb new file mode 100644 index 00000000000..1208f4d2972 --- /dev/null +++ b/docs/notebooks/Canadian Building Footprints Dataset.ipynb @@ -0,0 +1,276 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "pediatric-slovakia", + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "import sys\n", + "sys.path.append(\"../..\")\n", + "import os\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from torch.utils.data import DataLoader\n", + "from torchgeo.datasets import CanadianBuildingFootprints\n", + "from torchgeo.samplers import RandomGeoSampler\n", + "from torchgeo.datasets.utils import BoundingBox" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "threatened-coverage", + "metadata": {}, + "outputs": [], + "source": [ + "ROOT_DIR = os.path.expanduser(\"~/ssdprivate/cbf/\")" + ] + }, + { + "cell_type": "markdown", + "id": "minimal-spread", + "metadata": {}, + "source": [ + "## Visualization example" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "framed-voice", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3min 5s, sys: 1.68 s, total: 3min 7s\n", + "Wall time: 3min 6s\n" + ] + } + ], + "source": [ + "%%time\n", + "ds = CanadianBuildingFootprints(\n", + " ROOT_DIR,\n", + " download=False,\n", + " checksum=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "governmental-lesson", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BoundingBox(minx=-141.008104, maxx=-52.621493, miny=41.73535, maxy=74.773117, mint=0.0, 
maxt=9.223372036854776e+18)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.bounds" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "sustainable-visitor", + "metadata": {}, + "outputs": [], + "source": [ + "bounds = BoundingBox(-79.69096183776855,-79.68220710754395,43.78839898848133,43.79482711775757,0,1)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "eight-teach", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.bounds.intersects(bounds)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "massive-pipeline", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1min, sys: 597 ms, total: 1min\n", + "Wall time: 1min\n" + ] + } + ], + "source": [ + "%%time\n", + "sample = ds[bounds]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "assumed-engine", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([642, 875])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample[\"masks\"].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "special-catch", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "ds.plot(sample[\"masks\"])" + ] + }, + { + "cell_type": "markdown", + "id": "special-officer", + "metadata": {}, + "source": [ + "## DataLoader example" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "outstanding-avenue", + "metadata": {}, + "outputs": [], + "source": [ + "sampler = RandomGeoSampler(\n", + " roi=ds.bounds,\n", + " size=256,\n", + " length=48\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "unexpected-europe", + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = DataLoader(ds, sampler=sampler, batch_size=32)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "brutal-shade", + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "empty range for randrange() (-141, -307, -166)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_867320/547832430.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"masks\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/envs/torchgeo1/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in 
\u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/envs/torchgeo1/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 559\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 560\u001b[0;31m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise 
StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 561\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 562\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/envs/torchgeo1/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_index\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 510\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 511\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 512\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 513\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 514\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/envs/torchgeo1/lib/python3.8/site-packages/torch/utils/data/sampler.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m 
\u001b[0;34m->\u001b[0m \u001b[0mIterator\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 226\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 227\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/azurefiles/torchgeo/docs/notebooks/../../torchgeo/samplers/samplers.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 66\u001b[0m \"\"\"\n\u001b[1;32m 67\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 68\u001b[0;31m \u001b[0mminx\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mminx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmaxx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 69\u001b[0m \u001b[0mmaxx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mminx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/envs/torchgeo1/lib/python3.8/random.py\u001b[0m in \u001b[0;36mrandint\u001b[0;34m(self, a, b)\u001b[0m\n\u001b[1;32m 246\u001b[0m \"\"\"\n\u001b[1;32m 247\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 248\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 249\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_randbelow_with_getrandbits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/opt/conda/envs/torchgeo1/lib/python3.8/random.py\u001b[0m in \u001b[0;36mrandrange\u001b[0;34m(self, start, stop, step, 
_int)\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mistart\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_randbelow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwidth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 226\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"empty range for randrange() (%d, %d, %d)\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mistart\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mistop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 227\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;31m# Non-unit step argument supplied.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: empty range for randrange() (-141, -307, -166)" + ] + } + ], + "source": [ + "for batch_idx, batch in enumerate(dataloader):\n", + " print(batch[\"masks\"].shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "peripheral-princeton", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:torchgeo1]", + "language": "python", + "name": "conda-env-torchgeo1-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/environment.yml 
b/environment.yml index bbe99bdfe9e..b0e28bd4e34 100644 --- a/environment.yml +++ b/environment.yml @@ -3,6 +3,7 @@ channels: - conda-forge dependencies: - cudatoolkit + - fiona - h5py - numpy - pip @@ -17,6 +18,7 @@ dependencies: - black[colorama]>=21b - flake8 - isort[colors]>=4.3.5 + - jupyterlab - mypy>=0.900 - omegaconf - opencv-python diff --git a/requirements.txt b/requirements.txt index aa646de2275..1d68ade3087 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,10 @@ affine black[colorama]>=21b +fiona flake8 h5py isort[colors]>=4.3.5 +jupyterlab matplotlib mypy>=0.900 numpy diff --git a/setup.cfg b/setup.cfg index 8684a4811cb..ef8b9734dab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ setup_requires = setuptools>=42 install_requires = affine + fiona matplotlib numpy pillow @@ -50,6 +51,7 @@ datasets = # Optional developer requirements docs = + jupyterlab sphinx pydocstyle[toml]>=6.1 pytorch-sphinx-theme diff --git a/spack.yaml b/spack.yaml index 27b63ce59ab..bb758ba033b 100644 --- a/spack.yaml +++ b/spack.yaml @@ -5,9 +5,11 @@ spack: - "python@3.7:+bz2" - py-affine - "py-black@21:+colorama" + - py-fiona - py-flake8 - py-h5py - "py-isort@4.3.5:+colors" + - py-jupyterlab - py-matplotlib - "py-mypy@0.900:" - py-numpy diff --git a/tests/data/README.md b/tests/data/README.md index e1a44a55e02..21713df0d94 100644 --- a/tests/data/README.md +++ b/tests/data/README.md @@ -1,9 +1,11 @@ -This directory contains fake data used to test torchgeo. Depending on the type of dataset, fake data can be created in one of two ways: +This directory contains fake data used to test torchgeo. Depending on the type of dataset, fake data can be created in multiple ways: ## GeoDataset GeoDataset data can be created like so. We first open an existing data example and use it to copy the driver/CRS/transform to the fake data. 
+### Raster data + ```python import os @@ -25,7 +27,21 @@ cmap = src.colormap(1) dst.write_colormap(1, cmap) ``` -If the dataset expects multiple files, you can simply copy and rename the file you created. +### Vector data + +```python +import os + +import fiona + +ROOT = "/mnt/blobfuse/adam-scratch/cbf" +FILENAME = "Ontario.geojson" + +src = fiona.open(os.path.join(ROOT, FILENAME)) +dst = fiona.open(FILENAME, "w", **src.meta) +rec = next(iter(src)) +dst.write(rec) +``` ## VisionDataset diff --git a/tests/data/cbf/Alberta.zip b/tests/data/cbf/Alberta.zip new file mode 100644 index 00000000000..88f6f05a905 Binary files /dev/null and b/tests/data/cbf/Alberta.zip differ diff --git a/tests/data/cbf/BritishColumbia.zip b/tests/data/cbf/BritishColumbia.zip new file mode 100644 index 00000000000..d7697543adf Binary files /dev/null and b/tests/data/cbf/BritishColumbia.zip differ diff --git a/tests/data/cbf/Manitoba.zip b/tests/data/cbf/Manitoba.zip new file mode 100644 index 00000000000..cd811f0ad8f Binary files /dev/null and b/tests/data/cbf/Manitoba.zip differ diff --git a/tests/data/cbf/NewBrunswick.zip b/tests/data/cbf/NewBrunswick.zip new file mode 100644 index 00000000000..c74ba8d2f5e Binary files /dev/null and b/tests/data/cbf/NewBrunswick.zip differ diff --git a/tests/data/cbf/NewfoundlandAndLabrador.zip b/tests/data/cbf/NewfoundlandAndLabrador.zip new file mode 100644 index 00000000000..775afc70d69 Binary files /dev/null and b/tests/data/cbf/NewfoundlandAndLabrador.zip differ diff --git a/tests/data/cbf/NorthwestTerritories.zip b/tests/data/cbf/NorthwestTerritories.zip new file mode 100644 index 00000000000..2f717fa9f7c Binary files /dev/null and b/tests/data/cbf/NorthwestTerritories.zip differ diff --git a/tests/data/cbf/NovaScotia.zip b/tests/data/cbf/NovaScotia.zip new file mode 100644 index 00000000000..33d5b191f5f Binary files /dev/null and b/tests/data/cbf/NovaScotia.zip differ diff --git a/tests/data/cbf/Nunavut.zip b/tests/data/cbf/Nunavut.zip new file 
mode 100644 index 00000000000..2e43bec937c Binary files /dev/null and b/tests/data/cbf/Nunavut.zip differ diff --git a/tests/data/cbf/Ontario.zip b/tests/data/cbf/Ontario.zip new file mode 100644 index 00000000000..b669cba3636 Binary files /dev/null and b/tests/data/cbf/Ontario.zip differ diff --git a/tests/data/cbf/PrinceEdwardIsland.zip b/tests/data/cbf/PrinceEdwardIsland.zip new file mode 100644 index 00000000000..55d24681478 Binary files /dev/null and b/tests/data/cbf/PrinceEdwardIsland.zip differ diff --git a/tests/data/cbf/Quebec.zip b/tests/data/cbf/Quebec.zip new file mode 100644 index 00000000000..05a7094f9db Binary files /dev/null and b/tests/data/cbf/Quebec.zip differ diff --git a/tests/data/cbf/Saskatchewan.zip b/tests/data/cbf/Saskatchewan.zip new file mode 100644 index 00000000000..6ea3decf76f Binary files /dev/null and b/tests/data/cbf/Saskatchewan.zip differ diff --git a/tests/data/cbf/YukonTerritory.zip b/tests/data/cbf/YukonTerritory.zip new file mode 100644 index 00000000000..d32c2fadbe9 Binary files /dev/null and b/tests/data/cbf/YukonTerritory.zip differ diff --git a/tests/datasets/test_cbf.py b/tests/datasets/test_cbf.py new file mode 100644 index 00000000000..15323edb1b6 --- /dev/null +++ b/tests/datasets/test_cbf.py @@ -0,0 +1,92 @@ +import os +import shutil +from pathlib import Path +from typing import Generator + +import matplotlib.pyplot as plt +import pytest +import torch +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +from rasterio.crs import CRS + +import torchgeo.datasets.utils +from torchgeo.datasets import BoundingBox, CanadianBuildingFootprints, ZipDataset +from torchgeo.transforms import Identity + + +def download_url(url: str, root: str, *args: str) -> None: + shutil.copy(url, root) + + +class TestCanadianBuildingFootprints: + @pytest.fixture + def dataset( + self, + monkeypatch: Generator[MonkeyPatch, None, None], + tmp_path: Path, + request: SubRequest, + ) -> CanadianBuildingFootprints: + 
monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.utils, "download_url", download_url + ) + md5s = [ + "aef9a3deb3297f225d6cdb221cb48527", + "2b7872c4121116fda8f96490daf89d29", + "c71ded923e22569b62b00da2d2a61076", + "75a8f652531790c3c3aefc0655400d6d", + "89ff9c6257efa99365a8b709dde9579b", + "d4d6a36ed834df5cbf5254effca78a4d", + "cce85f6183427e3034704cf35919c985", + "0149c7ec5101c0309c79b7e695dcb394", + "b05216155725f48937804371b945f8ae", + "72d0e6d7196345ca520c825697cc4947", + "77e1c6c71ff0efbdd221b7e7d4a5f2df", + "86e32374f068c7bbb76aa81af0736733", + "5e453a3426b0bb986b2837b85e8b8850", + ] + monkeypatch.setattr( # type: ignore[attr-defined] + CanadianBuildingFootprints, "md5s", md5s + ) + url = os.path.join("tests", "data", "cbf") + os.sep + monkeypatch.setattr( # type: ignore[attr-defined] + CanadianBuildingFootprints, "url", url + ) + monkeypatch.setattr( # type: ignore[attr-defined] + plt, "show", lambda *args: None + ) + (tmp_path / "cbf").mkdir() + root = str(tmp_path) + transforms = Identity() + return CanadianBuildingFootprints( + root, transforms=transforms, download=True, checksum=True + ) + + def test_getitem(self, dataset: CanadianBuildingFootprints) -> None: + x = dataset[dataset.bounds] + assert isinstance(x, dict) + assert isinstance(x["crs"], CRS) + assert isinstance(x["masks"], torch.Tensor) + + def test_add(self, dataset: CanadianBuildingFootprints) -> None: + ds = dataset + dataset + assert isinstance(ds, ZipDataset) + + def test_already_downloaded(self, dataset: CanadianBuildingFootprints) -> None: + CanadianBuildingFootprints(root=dataset.root, download=True) + + def test_plot(self, dataset: CanadianBuildingFootprints) -> None: + query = dataset.bounds + x = dataset[query] + dataset.plot(x["masks"]) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + CanadianBuildingFootprints(str(tmp_path)) + + def test_invalid_query(self, dataset: 
CanadianBuildingFootprints) -> None: + query = BoundingBox(0, 0, 0, 0, 0, 0) + with pytest.raises( + IndexError, match="query: .* is not within bounds of the index:" + ): + dataset[query] diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 63fb85e0cf2..6a7b8e9afd2 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -1,6 +1,7 @@ """TorchGeo datasets.""" from .benin_cashews import BeninSmallHolderCashews +from .cbf import CanadianBuildingFootprints from .cdl import CDL from .chesapeake import ( Chesapeake, @@ -42,6 +43,7 @@ __all__ = ( "BeninSmallHolderCashews", "BoundingBox", + "CanadianBuildingFootprints", "CDL", "collate_dict", "Chesapeake", diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py new file mode 100644 index 00000000000..64c45ba05b1 --- /dev/null +++ b/torchgeo/datasets/cbf.py @@ -0,0 +1,216 @@ +"""Canadian Building Footprints dataset.""" + +import glob +import os +import sys +from typing import Any, Callable, Dict, Optional + +import fiona +import fiona.transform +import matplotlib.pyplot as plt +import rasterio +import torch +from rasterio.crs import CRS +from rtree.index import Index, Property +from torch import Tensor + +from .geo import GeoDataset +from .utils import BoundingBox, check_integrity, download_and_extract_archive + +_crs = CRS.from_epsg(4326) + + +class CanadianBuildingFootprints(GeoDataset): + """Canadian Building Footprints dataset. + + The `Canadian Building Footprints + <https://github.com/microsoft/CanadianBuildingFootprints>`_ dataset contains + 11,842,186 computer-generated building footprints in all Canadian provinces and + territories in GeoJSON format. This data is freely available for download and use. + """ + + # TODO: how does one cite this dataset?
+ # https://github.com/microsoft/CanadianBuildingFootprints/issues/11 + + url = "https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/" + provinces_territories = [ + "Alberta", + "BritishColumbia", + "Manitoba", + "NewBrunswick", + "NewfoundlandAndLabrador", + "NorthwestTerritories", + "NovaScotia", + "Nunavut", + "Ontario", + "PrinceEdwardIsland", + "Quebec", + "Saskatchewan", + "YukonTerritory", + ] + md5s = [ + "8b4190424e57bb0902bd8ecb95a9235b", + "fea05d6eb0006710729c675de63db839", + "adf11187362624d68f9c69aaa693c46f", + "44269d4ec89521735389ef9752ee8642", + "65dd92b1f3f5f7222ae5edfad616d266", + "346d70a682b95b451b81b47f660fd0e2", + "bd57cb1a7822d72610215fca20a12602", + "c1f29b73cdff9a6a9dd7d086b31ef2cf", + "76ba4b7059c5717989ce34977cad42b2", + "2e4a3fa47b3558503e61572c59ac5963", + "9ff4417ae00354d39a0cf193c8df592c", + "a51078d8e60082c7d3a3818240da6dd5", + "c11f3bd914ecabd7cac2cb2871ec0261", + ] + + def __init__( + self, + root: str = "data", + crs: CRS = _crs, + res: float = 0.00001, + transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Canadian Building Footprints dataset. 
+ + Args: + root: root directory where dataset can be found + crs: :term:`coordinate reference system (CRS)` to project to + res: resolution to use when rasterizing features + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` and data is not found, or + ``checksum=True`` and checksums don't match + """ + self.root = root + self.crs = crs + self.res = res + self.transforms = transforms + self.checksum = checksum + + if download: + self._download() + + if not self._check_integrity(): + raise RuntimeError( + "Dataset not found or corrupted. " + + "You can use download=True to download it" + ) + + # Create an R-tree to index the dataset + self.index = Index(interleaved=False, properties=Property(dimension=3)) + fileglob = os.path.join(root, "**.geojson") + for i, filename in enumerate(glob.iglob(fileglob, recursive=True)): + with fiona.open(filename) as src: + minx, miny, maxx, maxy = src.bounds + (minx, maxx), (miny, maxy) = fiona.transform.transform( + src.crs, crs.to_dict(), [minx, maxx], [miny, maxy] + ) + mint = 0 + maxt = sys.maxsize + coords = (minx, maxx, miny, maxy, mint, maxt) + self.index.insert(i, coords, filename) + + def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + """Retrieve image and metadata indexed by query. 
+ + Args: + query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index + + Returns: + sample of labels and metadata at that index + + Raises: + IndexError: if query is not within bounds of the index + """ + if not query.intersects(self.bounds): + raise IndexError( + f"query: {query} is not within bounds of the index: {self.bounds}" + ) + + hits = self.index.intersection(query, objects=True) + filename = next(hits).object # TODO: this assumes there is only a single hit + shapes = [] + with fiona.open(filename) as src: + # We need to know the bounding box of the query in the source CRS + (minx, maxx), (miny, maxy) = fiona.transform.transform( + self.crs.to_dict(), + src.crs, + [query.minx, query.maxx], + [query.miny, query.maxy], + ) + + # Filter geometries to those that intersect with the bounding box + for feature in src.filter(bbox=(minx, miny, maxx, maxy)): + # Warp geometries to requested CRS + shape = fiona.transform.transform_geom( + src.crs, self.crs.to_dict(), feature["geometry"] + ) + shapes.append(shape) + + # Rasterize geometries + width = (query.maxx - query.minx) / self.res + height = (query.maxy - query.miny) / self.res + transform = rasterio.transform.from_bounds( + query.minx, query.miny, query.maxx, query.maxy, width, height + ) + masks = rasterio.features.rasterize( + shapes, out_shape=(int(height), int(width)), transform=transform + ) + + sample = { + "masks": torch.tensor(masks), # type: ignore[attr-defined] + "crs": self.crs, + "bbox": query, + } + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _check_integrity(self) -> bool: + """Check integrity of dataset. 
+ + Returns: + True if dataset files are found and/or MD5s match, else False + """ + for prov_terr, md5 in zip(self.provinces_territories, self.md5s): + filepath = os.path.join(self.root, prov_terr + ".zip") + if not check_integrity(filepath, md5 if self.checksum else None): + return False + return True + + def _download(self) -> None: + """Download the dataset and extract it.""" + if self._check_integrity(): + print("Files already downloaded and verified") + return + + for prov_terr, md5 in zip(self.provinces_territories, self.md5s): + download_and_extract_archive( + self.url + prov_terr + ".zip", + self.root, + md5=md5 if self.checksum else None, + ) + + def plot(self, image: Tensor) -> None: + """Plot an image on a map. + + Args: + image: the image to plot + """ + array = image.squeeze().numpy() + + # Plot the image + ax = plt.axes() + ax.imshow(array) + ax.axis("off") + plt.show() + plt.close()