Refactor SegmentationTask testing
adamjstewart committed Dec 27, 2021
1 parent a6cdc38 commit a102194
Showing 14 changed files with 100 additions and 121 deletions.
14 changes: 6 additions & 8 deletions conf/defaults.yaml
@@ -1,7 +1,7 @@
config_file: null # This lets the user pass a config filename to load other arguments from

program: # These are the arguments that define how the train.py script works
-seed: 1337
+seed: 0
output_dir: output
data_dir: data
log_dir: logs
@@ -16,16 +16,17 @@ experiment: # These are arguments specific to the experiment we are running
root_dir: ${program.data_dir}
seed: ${program.seed}
batch_size: 32
-num_workers: 4
+num_workers: 0


# The values here are taken from the defaults at https://pytorch-lightning.readthedocs.io/en/1.3.8/common/trainer.html#init
# This should probably be made into a schema, e.g. as shown at https://omegaconf.readthedocs.io/en/2.0_branch/structured_config.html#merging-with-other-configs
trainer: # These are the parameters passed to the pytorch lightning Trainer object
logger: True
-checkpoint_callback: True
callbacks: null
default_root_dir: null
+detect_anomaly: False
+enable_checkpointing: True
gradient_clip_val: 0.0
gradient_clip_algorithm: 'norm'
process_position: 0
@@ -43,16 +44,15 @@ trainer: # These are the parameters passed to the pytorch lightning Trainer object
accumulate_grad_batches: 1
max_epochs: null
min_epochs: null
-max_steps: null
+max_steps: -1
min_steps: null
max_time: null
limit_train_batches: 1.0
limit_val_batches: 1.0
limit_test_batches: 1.0
limit_predict_batches: 1.0
val_check_interval: 1.0
-flush_logs_every_n_steps: 100
-log_every_n_steps: 50
+log_every_n_steps: 1
accelerator: null
sync_batchnorm: False
precision: 32
@@ -66,9 +66,7 @@ trainer: # These are the parameters passed to the pytorch lightning Trainer object
reload_dataloaders_every_epoch: False
auto_lr_find: False
replace_sampler_ddp: True
-terminate_on_nan: False
auto_scale_batch_size: False
-prepare_data_per_node: True
plugins: null
amp_backend: 'native'
move_metrics_to_cpu: False
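The schema comment at the top of the trainer block hints at a cleaner design. A minimal sketch of what that could look like with OmegaConf structured configs, with field names copied from the block above; the schema class and merge call are illustrative, not part of this commit:

from dataclasses import dataclass

from omegaconf import OmegaConf

@dataclass
class TrainerConf:
    # A few representative fields from the trainer block above.
    logger: bool = True
    enable_checkpointing: bool = True
    detect_anomaly: bool = False
    gradient_clip_val: float = 0.0
    max_steps: int = -1
    log_every_n_steps: int = 1

# Merging overrides against the schema type-checks them at load time.
schema = OmegaConf.structured(TrainerConf)
conf = OmegaConf.merge(schema, {"max_steps": 100, "log_every_n_steps": 50})
assert conf.max_steps == 100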
2 changes: 1 addition & 1 deletion conf/task_defaults/byol.yaml
@@ -16,5 +16,5 @@ experiment:
- "de-test"
test_splits:
- "de-test"
-batch_size: 64
+batch_size: 1
num_workers: 0
8 changes: 4 additions & 4 deletions conf/task_defaults/chesapeake_cvpr.yaml
@@ -10,7 +10,7 @@ experiment:
learning_rate_schedule_patience: 6
in_channels: 4
num_classes: 7
-num_filters: 256
+num_filters: 1
ignore_zeros: False
datamodule:
root_dir: "tests/data/chesapeake/cvpr"
@@ -20,8 +20,8 @@ experiment:
- "de-test"
test_splits:
- "de-test"
-patches_per_tile: 200
-patch_size: 256
-batch_size: 64
+patches_per_tile: 1
+patch_size: 64
+batch_size: 1
num_workers: 0
class_set: ${experiment.module.num_classes}
2 changes: 1 addition & 1 deletion conf/task_defaults/etci2021.yaml
@@ -12,5 +12,5 @@ experiment:
ignore_zeros: True
datamodule:
root_dir: "tests/data/etci2021"
-batch_size: 32
+batch_size: 1
num_workers: 0
4 changes: 2 additions & 2 deletions conf/task_defaults/landcoverai.yaml
@@ -10,9 +10,9 @@ experiment:
verbose: false
in_channels: 3
num_classes: 6
-num_filters: 256
+num_filters: 1
ignore_zeros: False
datamodule:
root_dir: "tests/data/landcoverai"
-batch_size: 32
+batch_size: 1
num_workers: 0
4 changes: 2 additions & 2 deletions conf/task_defaults/naipchesapeake.yaml
@@ -10,11 +10,11 @@ experiment:
learning_rate_schedule_patience: 2
in_channels: 4
num_classes: 13
-num_filters: 64
+num_filters: 1
ignore_zeros: False
datamodule:
naip_root_dir: "tests/data/naip"
chesapeake_root_dir: "tests/data/chesapeake/BAYWIDE"
-batch_size: 32
+batch_size: 1
num_workers: 0
patch_size: 32
8 changes: 4 additions & 4 deletions conf/task_defaults/oscd.yaml
@@ -4,18 +4,18 @@ experiment:
loss: "jaccard"
segmentation_model: "unet"
encoder_name: "resnet18"
-encoder_weights: null
+encoder_weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
verbose: false
in_channels: 26
num_classes: 2
-num_filters: 256
+num_filters: 1
ignore_zeros: True
datamodule:
root_dir: "tests/data/oscd"
-batch_size: 32
+batch_size: 1
num_workers: 0
val_split_pct: 0.1
bands: "all"
-num_patches_per_tile: 128
+num_patches_per_tile: 1
3 changes: 2 additions & 1 deletion conf/task_defaults/sen12ms.yaml
@@ -13,5 +13,6 @@ experiment:
ignore_zeros: False
datamodule:
root_dir: "tests/data/sen12ms"
-batch_size: 32
+batch_size: 1
num_workers: 0
+seed: 0
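All of the task_defaults files above are shrunk to trivial sizes (batch_size 1, num_workers 0, single filters and patches) so the trainer tests run quickly. For reference, a short sketch of the loading pattern those tests use, mirroring the OmegaConf calls visible in tests/trainers/test_classification.py below; sen12ms is just one example filename:

import os
from typing import Any, Dict, cast

from omegaconf import OmegaConf

conf = OmegaConf.load(os.path.join("conf", "task_defaults", "sen12ms.yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)

# The datamodule section now carries the tiny test-time values set above.
datamodule_kwargs = conf_dict["datamodule"]  # e.g. batch_size=1, num_workers=0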
9 changes: 7 additions & 2 deletions tests/test_train.py
@@ -134,19 +134,24 @@ def test_config_file(tmp_path: Path) -> None:
subprocess.run(args, check=True)


+# TODO: these tests can probably be removed now
@pytest.mark.parametrize(
"task",
[
"bigearthnet",
"bigearthnet_all",
"bigearthnet_s1",
"bigearthnet_s2",
"byol",
"chesapeake_cvpr",
"cowc_counting",
"cyclone",
"landcoverai",
"naipchesapeake",
"oscd",
"resisc45",
"sen12ms",
"so2sat",
"so2sat_supervised",
"so2sat_unsupervised",
"ucmerced",
],
)
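Each parametrized name above corresponds to one conf/task_defaults/<name>.yaml file, and each case drives train.py in a subprocess via the subprocess.run(args, check=True) call shown at the top of the hunk. A hypothetical sketch of one such invocation; the exact command-line keys are assumptions, not taken from this commit:

import subprocess
import sys

def run_smoke_test(task: str, output_dir: str) -> None:
    # Assumed CLI shape: train.py accepts OmegaConf key=value overrides,
    # including a config_file pointing at the tiny task_defaults config.
    args = [
        sys.executable,
        "train.py",
        f"config_file=conf/task_defaults/{task}.yaml",
        f"program.output_dir={output_dir}",
    ]
    subprocess.run(args, check=True)

# Example: run_smoke_test("sen12ms", "/tmp/output")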
24 changes: 11 additions & 13 deletions tests/trainers/test_classification.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.

import os
-from typing import Any, Dict, Union, cast
+from typing import Any, Dict, Type, cast

import pytest
from omegaconf import OmegaConf
@@ -29,7 +29,7 @@ class TestClassificationTask:
("ucmerced", UCMercedDataModule),
],
)
-def test_trainer(self, name: str, classname: LightningDataModule) -> None:
+def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
if name == "so2sat":
pytest.importorskip("h5py")

@@ -51,7 +51,7 @@ def test_trainer(self, name: str, classname: LightningDataModule) -> None:
trainer.test(model=model, datamodule=datamodule)

@pytest.fixture
-def model_kwargs(self) -> Dict[str, Union[str, int]]:
+def model_kwargs(self) -> Dict[Any, Any]:
return {
"classification_model": "resnet18",
"in_channels": 1,
@@ -60,35 +60,33 @@ def model_kwargs(self) -> Dict[str, Union[str, int]]:
"weights": "random",
}

-def test_pretrained(
-    self, model_kwargs: Dict[str, Union[str, int]], checkpoint: str
-) -> None:
+def test_pretrained(self, model_kwargs: Dict[Any, Any], checkpoint: str) -> None:
model_kwargs["weights"] = checkpoint
with pytest.warns(UserWarning):
ClassificationTask(**model_kwargs)

def test_invalid_pretrained(
-self, model_kwargs: Dict[str, Union[str, int]], checkpoint: str
+self, model_kwargs: Dict[Any, Any], checkpoint: str
) -> None:
model_kwargs["weights"] = checkpoint
model_kwargs["classification_model"] = "resnet50"
match = "Trying to load resnet18 weights into a resnet50"
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)

-def test_invalid_loss(self, model_kwargs: Dict[str, Union[str, int]]) -> None:
+def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["loss"] = "invalid_loss"
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)

-def test_invalid_model(self, model_kwargs: Dict[str, Union[str, int]]) -> None:
+def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["classification_model"] = "invalid_model"
match = "Model type 'invalid_model' is not a valid timm model."
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)

-def test_invalid_weights(self, model_kwargs: Dict[str, Union[str, int]]) -> None:
+def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["weights"] = "invalid_weights"
match = "Weight type 'invalid_weights' is not valid."
with pytest.raises(ValueError, match=match):
@@ -104,7 +102,7 @@ class TestMultiLabelClassificationTask:
("bigearthnet_s2", BigEarthNetDataModule),
],
)
-def test_trainer(self, name: str, classname: LightningDataModule) -> None:
+def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
Expand All @@ -123,7 +121,7 @@ def test_trainer(self, name: str, classname: LightningDataModule) -> None:
trainer.test(model=model, datamodule=datamodule)

@pytest.fixture
-def model_kwargs(self) -> Dict[str, Union[str, int]]:
+def model_kwargs(self) -> Dict[Any, Any]:
return {
"classification_model": "resnet18",
"in_channels": 1,
Expand All @@ -132,7 +130,7 @@ def model_kwargs(self) -> Dict[str, Union[str, int]]:
"weights": "random",
}

-def test_invalid_loss(self, model_kwargs: Dict[str, Union[str, int]]) -> None:
+def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["loss"] = "invalid_loss"
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
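The recurring edit in this file is a typing fix: the parametrize lists hold datamodule classes, not instances, so the parameter annotation becomes Type[LightningDataModule] and the kwargs fixtures widen to Dict[Any, Any]. A self-contained sketch of the class-vs-instance point; FakeDataModule is illustrative, not part of the commit:

from typing import Type

import pytest
from pytorch_lightning import LightningDataModule

class FakeDataModule(LightningDataModule):
    """Illustrative stand-in for BigEarthNetDataModule and friends."""

@pytest.mark.parametrize("classname", [FakeDataModule])
def test_trainer(classname: Type[LightningDataModule]) -> None:
    # classname is the class object itself; instantiation happens inside
    # the test, so annotating it as a LightningDataModule instance was wrong.
    datamodule = classname()
    assert isinstance(datamodule, LightningDataModule)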
4 changes: 2 additions & 2 deletions tests/trainers/test_regression.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.

import os
-from typing import Any, Dict, cast
+from typing import Any, Dict, Type, cast

import pytest
from omegaconf import OmegaConf
@@ -17,7 +17,7 @@ class TestRegressionTask:
"name,classname",
[("cowc_counting", COWCCountingDataModule), ("cyclone", CycloneDataModule)],
)
-def test_trainer(self, name: str, classname: LightningDataModule) -> None:
+def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
conf = OmegaConf.load(os.path.join("conf", "task_defaults", name + ".yaml"))
conf_dict = OmegaConf.to_object(conf.experiment)
conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict)
