Skip to content

Commit

Permalink
Remove dataset-specific trainers (microsoft#286)
Browse files Browse the repository at this point in the history
* Remove dataset-specific trainers

* Collation functions will be new in 0.2.0

* Clarify arg docstring

* Style fixes

* Remove files forgotten in rebase

* Fix bug in unbind_samples, add tests

* Fix bugs in datamodule augmentations

* Increase coverage for datamodules

* Fix bugs in logger plotting, properly test

* Fix tests

* Increase coverage of trainers

* Use datamodule plot instead of dataset plot

* Skip datamodules without tests

* Plot predictions

* Fix ClassificationTask tests

* Fix SemanticSegmentationTask tests

* EAFP -> LBYL

* Ensure that tensors are on the CPU before plotting
  • Loading branch information
adamjstewart authored Jan 1, 2022
1 parent ac9bcf4 commit 0689a50
Show file tree
Hide file tree
Showing 31 changed files with 327 additions and 650 deletions.
2 changes: 1 addition & 1 deletion conf/task_defaults/eurosat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
learning_rate_schedule_patience: 6
weights: "random"
in_channels: 13
num_classes: 10
num_classes: 2
datamodule:
root_dir: "tests/data/eurosat"
batch_size: 1
Expand Down
2 changes: 1 addition & 1 deletion conf/task_defaults/resisc45.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
learning_rate_schedule_patience: 6
weights: "random"
in_channels: 3
num_classes: 45
num_classes: 3
datamodule:
root_dir: "tests/data/resisc45"
batch_size: 1
Expand Down
2 changes: 1 addition & 1 deletion conf/task_defaults/ucmerced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 3
num_classes: 21
num_classes: 2
datamodule:
root_dir: "tests/data/ucmerced"
batch_size: 1
Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,4 @@ Collation Functions
.. autofunction:: stack_samples
.. autofunction:: concat_samples
.. autofunction:: merge_samples
.. autofunction:: unbind_samples
22 changes: 20 additions & 2 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
merge_samples,
percentile_normalization,
stack_samples,
unbind_samples,
working_dir,
)

Expand Down Expand Up @@ -457,7 +458,7 @@ def samples(self) -> List[Dict[str, Any]]:
},
]

def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = stack_samples(samples)
assert sample["image"].size() == torch.Size( # type: ignore[attr-defined]
[2, 3]
Expand All @@ -468,6 +469,13 @@ def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
)
assert sample["crs"] == [CRS.from_epsg(2000), CRS.from_epsg(2001)]

new_samples = unbind_samples(sample)
for i in range(2):
assert torch.allclose( # type: ignore[attr-defined]
samples[i]["image"], new_samples[i]["image"]
)
assert samples[i]["crs"] == new_samples[i]["crs"]

def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = concat_samples(samples)
assert sample["image"].size() == torch.Size([6]) # type: ignore[attr-defined]
Expand Down Expand Up @@ -500,7 +508,7 @@ def samples(self) -> List[Dict[str, Any]]:
},
]

def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
def test_stack_unbind_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = stack_samples(samples)
assert sample["image"].size() == torch.Size( # type: ignore[attr-defined]
[1, 3]
Expand All @@ -515,6 +523,16 @@ def test_stack_samples(self, samples: List[Dict[str, Any]]) -> None:
assert sample["crs1"] == [CRS.from_epsg(2000)]
assert sample["crs2"] == [CRS.from_epsg(2001)]

new_samples = unbind_samples(sample)
assert torch.allclose( # type: ignore[attr-defined]
samples[0]["image"], new_samples[0]["image"]
)
assert samples[0]["crs1"] == new_samples[0]["crs1"]
assert torch.allclose( # type: ignore[attr-defined]
samples[1]["mask"], new_samples[0]["mask"]
)
assert samples[1]["crs2"] == new_samples[0]["crs2"]

def test_concat_samples(self, samples: List[Dict[str, Any]]) -> None:
sample = concat_samples(samples)
assert sample["image"].size() == torch.Size([3]) # type: ignore[attr-defined]
Expand Down
64 changes: 0 additions & 64 deletions tests/trainers/test_chesapeake.py

This file was deleted.

36 changes: 36 additions & 0 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,23 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule)

def test_no_logger(self) -> None:
    """Fit a ClassificationTask with logging disabled (``logger=None``)."""
    config_path = os.path.join("conf", "task_defaults", "ucmerced.yaml")
    experiment = OmegaConf.to_object(OmegaConf.load(config_path).experiment)
    experiment = cast(Dict[Any, Dict[Any, Any]], experiment)

    # Build the datamodule and the task from the experiment config
    datamodule = UCMercedDataModule(**experiment["datamodule"])
    model = ClassificationTask(**experiment["module"])

    # Training should succeed even when no logger is attached
    trainer = Trainer(logger=None, fast_dev_run=True, log_every_n_steps=1)
    trainer.fit(model=model, datamodule=datamodule)

@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
return {
Expand Down Expand Up @@ -120,6 +137,25 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule)

def test_no_logger(self) -> None:
    """Fit a MultiLabelClassificationTask with logging disabled (``logger=None``)."""
    config_path = os.path.join("conf", "task_defaults", "bigearthnet_s1.yaml")
    experiment = OmegaConf.to_object(OmegaConf.load(config_path).experiment)
    experiment = cast(Dict[Any, Dict[Any, Any]], experiment)

    # Build the datamodule and the task from the experiment config
    datamodule = BigEarthNetDataModule(**experiment["datamodule"])
    model = MultiLabelClassificationTask(**experiment["module"])

    # Training should succeed even when no logger is attached
    trainer = Trainer(logger=None, fast_dev_run=True, log_every_n_steps=1)
    trainer.fit(model=model, datamodule=datamodule)

@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
return {
Expand Down
67 changes: 0 additions & 67 deletions tests/trainers/test_landcoverai.py

This file was deleted.

55 changes: 0 additions & 55 deletions tests/trainers/test_naipchesapeake.py

This file was deleted.

17 changes: 17 additions & 0 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None:
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule)

def test_no_logger(self) -> None:
    """Fit a RegressionTask with logging disabled (``logger=None``)."""
    config_path = os.path.join("conf", "task_defaults", "cyclone.yaml")
    experiment = OmegaConf.to_object(OmegaConf.load(config_path).experiment)
    experiment = cast(Dict[Any, Dict[Any, Any]], experiment)

    # Build the datamodule and the task from the experiment config
    datamodule = CycloneDataModule(**experiment["datamodule"])
    model = RegressionTask(**experiment["module"])

    # Training should succeed even when no logger is attached
    trainer = Trainer(logger=None, fast_dev_run=True, log_every_n_steps=1)
    trainer.fit(model=model, datamodule=datamodule)

def test_invalid_model(self) -> None:
match = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=match):
Expand Down
Loading

0 comments on commit 0689a50

Please sign in to comment.