-
Notifications
You must be signed in to change notification settings - Fork 379
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add IDTReeS dataset * dataset loads data now * add optional laspy and pandas dependencies * fixed docs failing * format * refactor verify and resample chm/hsi to 200x200 * add open3d optional dep * overhaul * temporarily remove open3d install bc their pypi is broken * mypy fixes * fixes per suggestions * general cleanup * test passing * add min version for laspy and pandas * add open3d dependency * add open3d to mypy tests * add hard install for python 3.9 open3d to actions * attempt #2 * I think I got it now * updated tests.yaml * make open3d dep require python<3.9 * open3d has issues with macos python 3.6 * same for 3.7 * skip open3d plot test for macos * formatting * skip open3d plot test for windows * update per suggestions * update test data readme for las files * updated per suggestions * more changes per suggestions * last change per suggestion * Grammar fix in pandas dep requirement comment Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
- Loading branch information
1 parent
fcbd1ab
commit 0434f3c
Showing
60 changed files
with
758 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import builtins | ||
import glob | ||
import os | ||
import shutil | ||
import sys | ||
from pathlib import Path | ||
from typing import Any, Generator | ||
|
||
import matplotlib.pyplot as plt | ||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from _pytest.fixtures import SubRequest | ||
from _pytest.monkeypatch import MonkeyPatch | ||
|
||
import torchgeo.datasets.utils | ||
from torchgeo.datasets import IDTReeS | ||
|
||
|
||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: | ||
shutil.copy(url, root) | ||
|
||
|
||
class TestIDTReeS: | ||
@pytest.fixture(params=zip(["train", "test", "test"], ["task1", "task1", "task2"])) | ||
def dataset( | ||
self, | ||
monkeypatch: Generator[MonkeyPatch, None, None], | ||
tmp_path: Path, | ||
request: SubRequest, | ||
) -> IDTReeS: | ||
pytest.importorskip("pandas") | ||
pytest.importorskip("laspy") | ||
monkeypatch.setattr( # type: ignore[attr-defined] | ||
torchgeo.datasets.idtrees, "download_url", download_url | ||
) | ||
data_dir = os.path.join("tests", "data", "idtrees") | ||
metadata = { | ||
"train": { | ||
"url": os.path.join(data_dir, "IDTREES_competition_train_v2.zip"), | ||
"md5": "5ddfa76240b4bb6b4a7861d1d31c299c", | ||
"filename": "IDTREES_competition_train_v2.zip", | ||
}, | ||
"test": { | ||
"url": os.path.join(data_dir, "IDTREES_competition_test_v2.zip"), | ||
"md5": "b108931c84a70f2a38a8234290131c9b", | ||
"filename": "IDTREES_competition_test_v2.zip", | ||
}, | ||
} | ||
split, task = request.param | ||
monkeypatch.setattr(IDTReeS, "metadata", metadata) # type: ignore[attr-defined] | ||
root = str(tmp_path) | ||
transforms = nn.Identity() # type: ignore[attr-defined] | ||
return IDTReeS(root, split, task, transforms, download=True, checksum=True) | ||
|
||
@pytest.fixture(params=["pandas", "laspy", "open3d"]) | ||
def mock_missing_module( | ||
self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest | ||
) -> str: | ||
import_orig = builtins.__import__ | ||
package = str(request.param) | ||
|
||
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: | ||
if name == package: | ||
raise ImportError() | ||
return import_orig(name, *args, **kwargs) | ||
|
||
monkeypatch.setattr( # type: ignore[attr-defined] | ||
builtins, "__import__", mocked_import | ||
) | ||
return package | ||
|
||
def test_getitem(self, dataset: IDTReeS) -> None: | ||
x = dataset[0] | ||
assert isinstance(x, dict) | ||
assert isinstance(x["image"], torch.Tensor) | ||
assert isinstance(x["chm"], torch.Tensor) | ||
assert isinstance(x["hsi"], torch.Tensor) | ||
assert isinstance(x["las"], torch.Tensor) | ||
assert x["image"].shape == (3, 200, 200) | ||
assert x["chm"].shape == (1, 200, 200) | ||
assert x["hsi"].shape == (369, 200, 200) | ||
assert x["las"].ndim == 2 | ||
assert x["las"].shape[0] == 3 | ||
|
||
if "label" in x: | ||
assert isinstance(x["label"], torch.Tensor) | ||
if "boxes" in x: | ||
assert isinstance(x["boxes"], torch.Tensor) | ||
if x["boxes"].ndim != 1: | ||
assert x["boxes"].ndim == 2 | ||
assert x["boxes"].shape[-1] == 4 | ||
|
||
def test_len(self, dataset: IDTReeS) -> None: | ||
assert len(dataset) == 3 | ||
|
||
def test_already_downloaded(self, dataset: IDTReeS) -> None: | ||
IDTReeS(root=dataset.root, download=True) | ||
|
||
def test_not_downloaded(self, tmp_path: Path) -> None: | ||
err = "Dataset not found in `root` directory and `download=False`, " | ||
"either specify a different `root` directory or use `download=True` " | ||
"to automaticaly download the dataset." | ||
with pytest.raises(RuntimeError, match=err): | ||
IDTReeS(str(tmp_path)) | ||
|
||
def test_not_extracted(self, tmp_path: Path) -> None: | ||
pathname = os.path.join("tests", "data", "idtrees", "*.zip") | ||
root = str(tmp_path) | ||
for zipfile in glob.iglob(pathname): | ||
shutil.copy(zipfile, root) | ||
IDTReeS(root) | ||
|
||
def test_mock_missing_module( | ||
self, dataset: IDTReeS, mock_missing_module: str | ||
) -> None: | ||
package = mock_missing_module | ||
|
||
if package in ["pandas", "laspy"]: | ||
with pytest.raises( | ||
ImportError, | ||
match=f"{package} is not installed and is required to use this dataset", | ||
): | ||
IDTReeS(dataset.root, download=True, checksum=True) | ||
else: | ||
with pytest.raises( | ||
ImportError, | ||
match=f"{package} is not installed and is required to use this dataset", | ||
): | ||
dataset.plot_las(0) | ||
|
||
def test_plot(self, dataset: IDTReeS) -> None: | ||
x = dataset[0].copy() | ||
dataset.plot(x, suptitle="Test") | ||
plt.close() | ||
dataset.plot(x, show_titles=False) | ||
plt.close() | ||
|
||
if "boxes" in x: | ||
x["prediction_boxes"] = x["boxes"] | ||
dataset.plot(x, show_titles=True) | ||
plt.close() | ||
if "label" in x: | ||
x["prediction_label"] = x["label"] | ||
dataset.plot(x, show_titles=False) | ||
plt.close() | ||
|
||
@pytest.mark.skipif( | ||
sys.platform in ["darwin", "win32"], | ||
reason="segmentation fault on macOS and windows", | ||
) | ||
def test_plot_las(self, dataset: IDTReeS) -> None: | ||
pytest.importorskip("open3d") | ||
vis = dataset.plot_las(index=0, colormap="BrBG") | ||
vis.close() | ||
vis = dataset.plot_las(index=0, colormap=None) | ||
vis.close() | ||
vis = dataset.plot_las(index=1, colormap=None) | ||
vis.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.