Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Apr 23, 2023
1 parent 20831d9 commit b69e484
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 27 deletions.
50 changes: 29 additions & 21 deletions tests/datasets/test_fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from torchgeo.datasets import FAIR1M


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
os.makedirs(root, exist_ok=True)
shutil.copy(url, os.path.join(root, filename))


class TestFAIR1M:
Expand Down Expand Up @@ -73,11 +74,11 @@ def test_getitem(self, dataset: FAIR1M) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["boxes"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert x["image"].shape[0] == 3

if dataset.split != "test":
assert isinstance(x["boxes"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert x["boxes"].shape[-2:] == (5, 2)
assert x["label"].ndim == 1

Expand All @@ -88,27 +89,32 @@ def test_len(self, dataset: FAIR1M) -> None:
assert len(dataset) == 4

def test_already_downloaded(self, dataset: FAIR1M, tmp_path: Path) -> None:
shutil.rmtree(str(tmp_path))
shutil.copytree(self.test_root, str(tmp_path))
FAIR1M(root=str(tmp_path), split=dataset.split)
FAIR1M(root=str(tmp_path), split=dataset.split, download=True)

def test_already_downloaded_not_extracted(
self, dataset: FAIR1M, tmp_path: Path
) -> None:
for path in dataset.paths[dataset.split]:
filepath = os.path.join(self.test_root, path)
output = os.path.join(str(tmp_path), os.path.dirname(filepath))
os.makedirs(os.path.dirname(output))
shutil.copy(filepath, output)
shutil.rmtree(dataset.root)
for filepath, url in zip(
dataset.paths[dataset.split], dataset.urls[dataset.split]
):
output = os.path.join(str(tmp_path), filepath)
os.makedirs(os.path.dirname(output), exist_ok=True)
download_url(url, root=os.path.dirname(output), filename=output)

FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)

def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None:
paths = dataset.paths[dataset.split]
for path in paths:
filepath = os.path.join(tmp_path, path)
os.makedirs(os.path.dirname(filepath), exist_ok=True)
with open(filepath, "w") as f:
f.write("bad")
md5s = tuple(["randomhash"] * len(FAIR1M.md5s[dataset.split]))
FAIR1M.md5s[dataset.split] = md5s
shutil.rmtree(dataset.root)
for filepath, url in zip(
dataset.paths[dataset.split], dataset.urls[dataset.split]
):
output = os.path.join(str(tmp_path), filepath)
os.makedirs(os.path.dirname(output), exist_ok=True)
download_url(url, root=os.path.dirname(output), filename=output)

with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)

Expand All @@ -123,6 +129,8 @@ def test_plot(self, dataset: FAIR1M) -> None:
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction_boxes"] = x["boxes"].clone()
dataset.plot(x)
plt.close()

if dataset.split != "test":
x["prediction_boxes"] = x["boxes"].clone()
dataset.plot(x)
plt.close()
14 changes: 8 additions & 6 deletions torchgeo/datasets/fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,12 +415,14 @@ def plot(

axs[0].imshow(image)
axs[0].axis("off")
polygons = [
patches.Polygon(points, color="r", fill=False)
for points in sample["boxes"].numpy()
]
for polygon in polygons:
axs[0].add_patch(polygon)

if "boxes" in sample:
polygons = [
patches.Polygon(points, color="r", fill=False)
for points in sample["boxes"].numpy()
]
for polygon in polygons:
axs[0].add_patch(polygon)

if show_titles:
axs[0].set_title("Ground Truth")
Expand Down

0 comments on commit b69e484

Please sign in to comment.