Skip to content

Commit

Permalink
Fix resisc45 testing
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Dec 19, 2021
1 parent 3d40ecb commit 13804e0
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions tests/trainers/test_resisc45.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
# Licensed under the MIT License.

import os
from typing import Any, Dict
from typing import Any, Dict, Generator

import pytest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import RESISC45DataModule
from torchgeo.trainers.resisc45 import RESISC45ClassificationTask

from .test_utils import FakeTrainer, mocked_log


class TestRESISC45ClassificationTask:
@pytest.fixture(scope="class")
Expand All @@ -34,8 +37,13 @@ def config(self) -> Dict[str, Any]:
return task_args

@pytest.fixture
def task(self, config: Dict[str, Any]) -> RESISC45ClassificationTask:
def task(
self, config: Dict[str, Any], monkeypatch: Generator[MonkeyPatch, None, None]
) -> RESISC45ClassificationTask:
task = RESISC45ClassificationTask(**config)
trainer = FakeTrainer()
monkeypatch.setattr(task, "trainer", trainer) # type: ignore[attr-defined]
monkeypatch.setattr(task, "log", mocked_log) # type: ignore[attr-defined]
return task

def test_training(
Expand Down

0 comments on commit 13804e0

Please sign in to comment.