Skip to content

Commit

Permalink
move fixtures into class methods
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Nov 1, 2021
1 parent a46d3e7 commit d6e548b
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions tests/trainers/test_bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,33 @@
from torchgeo.trainers import BigEarthNetDataModule


@pytest.fixture(scope="module", params=[("s1", 2), ("s2", 12), ("all", 14)])
def bands(request: SubRequest) -> Tuple[str, int]:
return cast(Tuple[str, int], request.param)


@pytest.fixture(scope="module", params=[True, False])
def datamodule(bands: Tuple[str, int], request: SubRequest) -> BigEarthNetDataModule:
band_set = bands[0]
unsupervised_mode = request.param
root = os.path.join("tests", "data", "bigearthnet")
batch_size = 1
num_workers = 0
dm = BigEarthNetDataModule(
root,
band_set,
batch_size,
num_workers,
unsupervised_mode,
val_split_pct=0.3,
test_split_pct=0.3,
)
dm.prepare_data()
dm.setup()
return dm


class TestBigEarthNetDataModule:
@pytest.fixture(params=[("s1", 2), ("s2", 12), ("all", 14)])
def bands(self, request: SubRequest) -> Tuple[str, int]:
return cast(Tuple[str, int], request.param)

@pytest.fixture(params=[True, False])
def datamodule(
self, bands: Tuple[str, int], request: SubRequest
) -> BigEarthNetDataModule:
band_set = bands[0]
unsupervised_mode = request.param
root = os.path.join("tests", "data", "bigearthnet")
batch_size = 1
num_workers = 0
dm = BigEarthNetDataModule(
root,
band_set,
batch_size,
num_workers,
unsupervised_mode,
val_split_pct=0.3,
test_split_pct=0.3,
)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.train_dataloader()))

Expand Down

0 comments on commit d6e548b

Please sign in to comment.