Skip to content

Commit

Permalink
fix tests finally
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Apr 15, 2024
1 parent ccda859 commit c8dae1c
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions tests/datasets/test_quakeset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
Expand All @@ -25,7 +27,7 @@ class TestQuakeSet:
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> QuakeSet:
monkeypatch.setattr(torchgeo.datasets.fire_risk, "download_url", download_url)
monkeypatch.setattr(torchgeo.datasets.quakeset, "download_url", download_url)
url = os.path.join("tests", "data", "quakeset", "earthquakes.h5")
md5 = "127d0d6a1f82d517129535f50053a4c9"
monkeypatch.setattr(QuakeSet, "md5", md5)
Expand All @@ -37,6 +39,17 @@ def dataset(
root, split, transforms=transforms, download=True, checksum=True
)

@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == "h5py":
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, "__import__", mocked_import)

def test_getitem(self, dataset: QuakeSet) -> None:
x = dataset[0]
assert isinstance(x, dict)
Expand All @@ -50,13 +63,6 @@ def test_len(self, dataset: QuakeSet) -> None:
def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None:
QuakeSet(root=str(tmp_path), download=True)

def test_already_downloaded_not_extracted(
self, dataset: QuakeSet, tmp_path: Path
) -> None:
shutil.rmtree(os.path.dirname(dataset.root))
download_url(dataset.url, root=str(tmp_path))
QuakeSet(root=str(tmp_path), download=False)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
QuakeSet(str(tmp_path))
Expand All @@ -68,5 +74,6 @@ def test_plot(self, dataset: QuakeSet) -> None:
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["label"].clone()
x["magnitude"] = torch.tensor(0.0)
dataset.plot(x)
plt.close()

0 comments on commit c8dae1c

Please sign in to comment.