Skip to content

Commit

Permalink
Rename VisionDataset to NonGeoDataset (#627)
Browse files Browse the repository at this point in the history
* Rename VisionDataset to NonGeoDataset

* Keep VisionDataset but add DeprecationWarning

* mypy fix

* More fixes

* More fixes

* cast types

* Undo cast

* Fix usage in test

* No idea why...

* Update more datasets
  • Loading branch information
adamjstewart authored Jul 10, 2022
1 parent 2d14883 commit ee657ba
Show file tree
Hide file tree
Showing 48 changed files with 198 additions and 138 deletions.
10 changes: 5 additions & 5 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Sentinel
Non-geospatial Datasets
-----------------------

:class:`VisionDataset` is designed for datasets that lack geospatial information. These datasets can still be combined using :class:`ConcatDataset <torch.utils.data.ConcatDataset>`.
:class:`NonGeoDataset` is designed for datasets that lack geospatial information. These datasets can still be combined using :class:`ConcatDataset <torch.utils.data.ConcatDataset>`.

.. csv-table:: C = classification, R = regression, S = semantic segmentation, I = instance segmentation, T = time series, CD = change detection, OD = object detection
:widths: 15 7 15 12 11 12 15 13
Expand Down Expand Up @@ -342,15 +342,15 @@ VectorDataset

.. autoclass:: VectorDataset

VisionDataset
NonGeoDataset
^^^^^^^^^^^^^

.. autoclass:: VisionDataset
.. autoclass:: NonGeoDataset

VisionClassificationDataset
NonGeoClassificationDataset
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: VisionClassificationDataset
.. autoclass:: NonGeoClassificationDataset

IntersectionDataset
^^^^^^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion docs/api/samplers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ torchgeo.samplers
Samplers
--------

Samplers are used to index a dataset, retrieving a single query at a time. For :class:`~torchgeo.datasets.VisionDataset`, dataset objects can be indexed with integers, and PyTorch's builtin samplers are sufficient. For :class:`~torchgeo.datasets.GeoDataset`, dataset objects require a bounding box for indexing. For this reason, we define our own :class:`GeoSampler` implementations below. These can be used like so:
Samplers are used to index a dataset, retrieving a single query at a time. For :class:`~torchgeo.datasets.NonGeoDataset`, dataset objects can be indexed with integers, and PyTorch's builtin samplers are sufficient. For :class:`~torchgeo.datasets.GeoDataset`, dataset objects require a bounding box for indexing. For this reason, we define our own :class:`GeoSampler` implementations below. These can be used like so:

.. code-block:: python
Expand Down
2 changes: 1 addition & 1 deletion docs/user/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ Datasets

A major component of TorchGeo is the large collection of :mod:`torchgeo.datasets` that have been implemented. Adding new datasets to this list is a great way to contribute to the library. A brief checklist to follow when implementing a new dataset:

* Implement the dataset extending either :class:`~torchgeo.datasets.GeoDataset` or :class:`~torchgeo.datasets.VisionDataset`
* Implement the dataset extending either :class:`~torchgeo.datasets.GeoDataset` or :class:`~torchgeo.datasets.NonGeoDataset`
* Add the dataset definition to ``torchgeo/datasets/__init__.py``
* Add a ``data.py`` script to ``tests/data/<new dataset>/`` that generates test data with the same directory structure/file naming conventions as the new dataset
* Add appropriate tests with 100% test coverage to ``tests/datasets/``
Expand Down
4 changes: 2 additions & 2 deletions tests/data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ rec = {"type": "Feature", "id": "0", "properties": OrderedDict(), "geometry": {"
dst.write(rec)
```

## VisionDataset
## NonGeoDataset

VisionDataset data can be created like so.
NonGeoDataset data can be created like so.

### RGB images

Expand Down
110 changes: 68 additions & 42 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
BoundingBox,
GeoDataset,
IntersectionDataset,
NonGeoClassificationDataset,
NonGeoDataset,
RasterDataset,
Sentinel2,
UnionDataset,
Expand Down Expand Up @@ -47,6 +49,14 @@ class CustomVectorDataset(VectorDataset):
filename_glob = "*.geojson"


class CustomNonGeoDataset(NonGeoDataset):
def __getitem__(self, index: int) -> Dict[str, int]:
return {"index": index}

def __len__(self) -> int:
return 2


class CustomVisionDataset(VisionDataset):
def __getitem__(self, index: int) -> Dict[str, int]:
return {"index": index}
Expand Down Expand Up @@ -137,8 +147,8 @@ def test_abstract(self) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoDataset() # type: ignore[abstract]

def test_and_vision(self, dataset: GeoDataset) -> None:
ds2 = CustomVisionDataset()
def test_and_nongeo(self, dataset: GeoDataset) -> None:
ds2 = CustomNonGeoDataset()
with pytest.raises(
ValueError, match="IntersectionDataset only supports GeoDatasets"
):
Expand Down Expand Up @@ -227,100 +237,116 @@ def test_no_data(self, tmp_path: Path) -> None:
VectorDataset(str(tmp_path))


class TestVisionDataset:
class TestNonGeoDataset:
@pytest.fixture(scope="class")
def dataset(self) -> VisionDataset:
return CustomVisionDataset()
def dataset(self) -> NonGeoDataset:
return CustomNonGeoDataset()

def test_getitem(self, dataset: VisionDataset) -> None:
def test_getitem(self, dataset: NonGeoDataset) -> None:
assert dataset[0] == {"index": 0}

def test_len(self, dataset: VisionDataset) -> None:
def test_len(self, dataset: NonGeoDataset) -> None:
assert len(dataset) == 2

def test_add_two(self) -> None:
ds1 = CustomVisionDataset()
ds2 = CustomVisionDataset()
ds1 = CustomNonGeoDataset()
ds2 = CustomNonGeoDataset()
dataset = ds1 + ds2
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 4

def test_add_three(self) -> None:
ds1 = CustomVisionDataset()
ds2 = CustomVisionDataset()
ds3 = CustomVisionDataset()
ds1 = CustomNonGeoDataset()
ds2 = CustomNonGeoDataset()
ds3 = CustomNonGeoDataset()
dataset = ds1 + ds2 + ds3
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 6

def test_add_four(self) -> None:
ds1 = CustomVisionDataset()
ds2 = CustomVisionDataset()
ds3 = CustomVisionDataset()
ds4 = CustomVisionDataset()
ds1 = CustomNonGeoDataset()
ds2 = CustomNonGeoDataset()
ds3 = CustomNonGeoDataset()
ds4 = CustomNonGeoDataset()
dataset = (ds1 + ds2) + (ds3 + ds4)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 8

def test_str(self, dataset: VisionDataset) -> None:
assert "type: VisionDataset" in str(dataset)
def test_str(self, dataset: NonGeoDataset) -> None:
assert "type: NonGeoDataset" in str(dataset)
assert "size: 2" in str(dataset)

def test_abstract(self) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
VisionDataset() # type: ignore[abstract]
NonGeoDataset() # type: ignore[abstract]


class TestVisionClassificationDataset:
class TestVisionDataset:
def test_deprecation(self) -> None:
match = "VisionDataset is deprecated, use NonGeoDataset instead."
with pytest.warns(DeprecationWarning, match=match):
CustomVisionDataset()


class TestNonGeoClassificationDataset:
@pytest.fixture(scope="class")
def dataset(self, root: str) -> VisionClassificationDataset:
def dataset(self, root: str) -> NonGeoClassificationDataset:
transforms = nn.Identity()
return VisionClassificationDataset(root, transforms=transforms)
return NonGeoClassificationDataset(root, transforms=transforms)

@pytest.fixture(scope="class")
def root(self) -> str:
root = os.path.join("tests", "data", "visionclassificationdataset")
root = os.path.join("tests", "data", "nongeoclassification")
return root

def test_getitem(self, dataset: VisionClassificationDataset) -> None:
def test_getitem(self, dataset: NonGeoClassificationDataset) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert x["image"].shape[0] == 3

def test_len(self, dataset: VisionClassificationDataset) -> None:
def test_len(self, dataset: NonGeoClassificationDataset) -> None:
assert len(dataset) == 2

def test_add_two(self, root: str) -> None:
ds1 = VisionClassificationDataset(root)
ds2 = VisionClassificationDataset(root)
ds1 = NonGeoClassificationDataset(root)
ds2 = NonGeoClassificationDataset(root)
dataset = ds1 + ds2
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 4

def test_add_three(self, root: str) -> None:
ds1 = VisionClassificationDataset(root)
ds2 = VisionClassificationDataset(root)
ds3 = VisionClassificationDataset(root)
ds1 = NonGeoClassificationDataset(root)
ds2 = NonGeoClassificationDataset(root)
ds3 = NonGeoClassificationDataset(root)
dataset = ds1 + ds2 + ds3
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 6

def test_add_four(self, root: str) -> None:
ds1 = VisionClassificationDataset(root)
ds2 = VisionClassificationDataset(root)
ds3 = VisionClassificationDataset(root)
ds4 = VisionClassificationDataset(root)
ds1 = NonGeoClassificationDataset(root)
ds2 = NonGeoClassificationDataset(root)
ds3 = NonGeoClassificationDataset(root)
ds4 = NonGeoClassificationDataset(root)
dataset = (ds1 + ds2) + (ds3 + ds4)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 8

def test_str(self, dataset: VisionClassificationDataset) -> None:
assert "type: VisionDataset" in str(dataset)
def test_str(self, dataset: NonGeoClassificationDataset) -> None:
assert "type: NonGeoDataset" in str(dataset)
assert "size: 2" in str(dataset)


class TestVisionClassificationDataset:
def test_deprecation(self) -> None:
root = os.path.join("tests", "data", "nongeoclassification")
match = "VisionClassificationDataset is deprecated, "
match += "use NonGeoClassificationDataset instead."
with pytest.warns(DeprecationWarning, match=match):
VisionClassificationDataset(root)


class TestIntersectionDataset:
@pytest.fixture(scope="class")
def dataset(self) -> IntersectionDataset:
Expand All @@ -341,9 +367,9 @@ def test_str(self, dataset: IntersectionDataset) -> None:
assert "bbox: BoundingBox" in out
assert "size: 1" in out

def test_vision_dataset(self) -> None:
ds1 = CustomVisionDataset()
ds2 = CustomVisionDataset()
def test_nongeo_dataset(self) -> None:
ds1 = CustomNonGeoDataset()
ds2 = CustomNonGeoDataset()
with pytest.raises(
ValueError, match="IntersectionDataset only supports GeoDatasets"
):
Expand Down Expand Up @@ -395,9 +421,9 @@ def test_str(self, dataset: UnionDataset) -> None:
assert "bbox: BoundingBox" in out
assert "size: 2" in out

def test_vision_dataset(self) -> None:
ds1 = CustomVisionDataset()
ds2 = CustomVisionDataset()
def test_nongeo_dataset(self) -> None:
ds1 = CustomNonGeoDataset()
ds2 = CustomNonGeoDataset()
with pytest.raises(ValueError, match="UnionDataset only supports GeoDatasets"):
UnionDataset(ds1, ds2) # type: ignore[arg-type]

Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
# GeoDataset
"ChesapeakeCVPRDataModule",
"NAIPChesapeakeDataModule",
# VisionDataset
# NonGeoDataset
"BigEarthNetDataModule",
"COWCCountingDataModule",
"DeepGlobeLandCoverDataModule",
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from torch.utils.data import Subset, TensorDataset, random_split

from ..datasets import VisionDataset
from ..datasets import NonGeoDataset


def dataset_split(
dataset: Union[TensorDataset, VisionDataset],
dataset: Union[TensorDataset, NonGeoDataset],
val_pct: float,
test_pct: Optional[float] = None,
) -> List[Subset[Any]]:
Expand Down
8 changes: 6 additions & 2 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from .geo import (
GeoDataset,
IntersectionDataset,
NonGeoClassificationDataset,
NonGeoDataset,
RasterDataset,
UnionDataset,
VectorDataset,
Expand Down Expand Up @@ -143,7 +145,7 @@
"OpenBuildings",
"Sentinel",
"Sentinel2",
# VisionDataset
# NonGeoDataset
"ADVANCE",
"BeninSmallHolderCashews",
"BigEarthNet",
Expand Down Expand Up @@ -191,11 +193,13 @@
# Base classes
"GeoDataset",
"IntersectionDataset",
"NonGeoClassificationDataset",
"NonGeoDataset",
"RasterDataset",
"UnionDataset",
"VectorDataset",
"VisionDataset",
"VisionClassificationDataset",
"VisionDataset",
# Utilities
"BoundingBox",
"concat_samples",
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/advance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from PIL import Image
from torch import Tensor

from .geo import VisionDataset
from .geo import NonGeoDataset
from .utils import download_and_extract_archive


class ADVANCE(VisionDataset):
class ADVANCE(NonGeoDataset):
"""ADVANCE dataset.
The `ADVANCE <https://akchen.github.io/ADVANCE-DATASET/>`__
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/benin_cashews.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from rasterio.crs import CRS
from torch import Tensor

from .geo import VisionDataset
from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive


# TODO: read geospatial information from stac.json files
class BeninSmallHolderCashews(VisionDataset):
class BeninSmallHolderCashews(NonGeoDataset):
r"""Smallholder Cashew Plantations in Benin dataset.
This dataset contains labels for cashew plantations in a 120 km\ :sup:`2`\ area
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from rasterio.enums import Resampling
from torch import Tensor

from .geo import VisionDataset
from .geo import NonGeoDataset
from .utils import download_url, extract_archive, sort_sentinel2_bands


class BigEarthNet(VisionDataset):
class BigEarthNet(NonGeoDataset):
"""BigEarthNet dataset.
The `BigEarthNet <https://bigearth.net/>`__
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/cowc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from PIL import Image
from torch import Tensor

from .geo import VisionDataset
from .geo import NonGeoDataset
from .utils import check_integrity, download_and_extract_archive


class COWC(VisionDataset, abc.ABC):
class COWC(NonGeoDataset, abc.ABC):
"""Abstract base class for the COWC dataset.
The `Cars Overhead With Context (COWC) <https://gdo152.llnl.gov/cowc/>`_ data set
Expand Down
Loading

0 comments on commit ee657ba

Please sign in to comment.