diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index f625e3a5cc9..b41634084cc 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -1,9 +1,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import builtins import os import shutil from pathlib import Path +from typing import Any import matplotlib.pyplot as plt import pytest @@ -25,7 +27,7 @@ class TestQuakeSet: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> QuakeSet: - monkeypatch.setattr(torchgeo.datasets.fire_risk, "download_url", download_url) + monkeypatch.setattr(torchgeo.datasets.quakeset, "download_url", download_url) url = os.path.join("tests", "data", "quakeset", "earthquakes.h5") md5 = "127d0d6a1f82d517129535f50053a4c9" monkeypatch.setattr(QuakeSet, "md5", md5) @@ -37,6 +39,17 @@ def dataset( root, split, transforms=transforms, download=True, checksum=True ) + @pytest.fixture + def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: + import_orig = builtins.__import__ + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "h5py": + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + def test_getitem(self, dataset: QuakeSet) -> None: x = dataset[0] assert isinstance(x, dict) @@ -50,13 +63,6 @@ def test_len(self, dataset: QuakeSet) -> None: def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None: QuakeSet(root=str(tmp_path), download=True) - def test_already_downloaded_not_extracted( - self, dataset: QuakeSet, tmp_path: Path - ) -> None: - shutil.rmtree(os.path.dirname(dataset.root)) - download_url(dataset.url, root=str(tmp_path)) - QuakeSet(root=str(tmp_path), download=False) - def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match="Dataset not found"): QuakeSet(str(tmp_path)) @@ -68,5 +74,6 @@ def test_plot(self, dataset: QuakeSet) -> None: dataset.plot(x, show_titles=False) plt.close() x["prediction"] = x["label"].clone() + x["magnitude"] = torch.tensor(0.0) dataset.plot(x) plt.close()