Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RESISC45 Trainer #179

Merged
merged 9 commits into from
Oct 11, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions conf/resisc45.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
trainer:
gpus: 1 # single GPU training
min_epochs: 10
max_epochs: 40
benchmark: True

experiment:
task: "resisc45"
module:
loss: "ce"
classification_model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
datamodule:
batch_size: 128
num_workers: 6
val_split_pct: 0.2
test_split_pct: 0.2
14 changes: 14 additions & 0 deletions conf/task_defaults/resisc45.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
experiment:
task: "resisc45"
module:
loss: "ce"
classification_model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
datamodule:
batch_size: 128
num_workers: 6
weights: ${experiment.module.weights}
val_split_pct: 0.2
test_split_pct: 0.2
23 changes: 8 additions & 15 deletions experiments/run_landcoverai_seed_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@

# Hyperparameter options
model_options = ["unet"]
encoder_options = ["resnet50"]
lr_options = [1e-4]
loss_options = ["ce"]
weight_init_options = ["imagenet"]
seeds = list(range(15))
encoder_options = ["resnet18", "resnet50"]
lr_options = [1e-2, 1e-3, 1e-4]
loss_options = ["ce", "jaccard"]
weight_init_options = ["null", "imagenet"]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
Expand All @@ -36,18 +35,13 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
if __name__ == "__main__":
work: "Queue[str]" = Queue()

for (model, encoder, lr, loss, weight_init, seed) in itertools.product(
model_options,
encoder_options,
lr_options,
loss_options,
weight_init_options,
seeds,
for (model, encoder, lr, loss, weight_init) in itertools.product(
model_options, encoder_options, lr_options, loss_options, weight_init_options
):

experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}_{seed}"
experiment_name = f"{model}_{encoder}_{lr}_{loss}_{weight_init}"

output_dir = os.path.join("output", "landcoverai_seed_experiments")
output_dir = os.path.join("output", "landcoverai_experiments")
log_dir = os.path.join(output_dir, "logs")
config_file = os.path.join("conf", "landcoverai.yaml")

Expand All @@ -63,7 +57,6 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
+ f" experiment.module.encoder_name={encoder}"
+ f" experiment.module.encoder_weights={weight_init}"
+ f" program.output_dir={output_dir}"
+ f" program.seed={seed}"
+ f" program.log_dir={log_dir}"
+ f" program.data_dir={DATA_DIR}"
+ " trainer.gpus=[GPU]"
Expand Down
82 changes: 82 additions & 0 deletions experiments/run_resisc45_experiments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue

# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0]
DRY_RUN = False # if False then print out the commands to be run, if True then run
DATA_DIR = "" # path to the RESISC45 data directory

# Hyperparameter options
model_options = ["resnet18", "resnet50"]
lr_options = [1e-2, 1e-3, 1e-4]
loss_options = ["ce"]
weight_options = ["imagenet_only", "random"]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
"""Process for each ID in GPUS."""
while not work.empty():
experiment = work.get()
experiment = experiment.replace("GPU", str(gpu_idx))
print(experiment)
if not DRY_RUN:
subprocess.call(experiment.split(" "))
return True


if __name__ == "__main__":
work: "Queue[str]" = Queue()

for (model, lr, loss, weights) in itertools.product(
model_options,
lr_options,
loss_options,
weight_options,
):

experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_','-')}"

output_dir = os.path.join("output", "resisc45_experiments")
log_dir = os.path.join(output_dir, "logs")
config_file = os.path.join("conf", "resisc45.yaml")

if not os.path.exists(os.path.join(output_dir, experiment_name)):

command = (
"python train.py"
+ f" config_file={config_file}"
+ f" experiment.name={experiment_name}"
+ f" experiment.module.classification_model={model}"
+ f" experiment.module.learning_rate={lr}"
+ f" experiment.module.loss={loss}"
+ f" experiment.module.weights={weights}"
+ f" experiment.datamodule.weights={weights}"
+ f" program.output_dir={output_dir}"
+ f" program.log_dir={log_dir}"
+ f" program.data_dir={DATA_DIR}"
+ " trainer.gpus=[GPU]"
)
command = command.strip()

work.put(command)

processes = []
for gpu_idx in GPUS:
p = Process(
target=do_work,
args=(
work,
gpu_idx,
),
)
processes.append(p)
p.start()
for p in processes:
p.join()
20 changes: 20 additions & 0 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import torch
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
from torch.utils.data import TensorDataset

import torchgeo.datasets.utils
from torchgeo.datasets.utils import (
BoundingBox,
collate_dict,
dataset_split,
disambiguate_timestamp,
download_and_extract_archive,
download_radiant_mlhub_collection,
Expand Down Expand Up @@ -335,3 +337,21 @@ def test_nonexisting_directory(tmp_path: Path) -> None:

with working_dir(str(subdir), create=True):
assert subdir.cwd() == subdir


def test_dataset_split() -> None:
num_samples = 24
x = torch.ones(num_samples, 5) # type: ignore[attr-defined]
y = torch.randint(low=0, high=2, size=(num_samples,)) # type: ignore[attr-defined]
ds = TensorDataset(x, y)

# Test only train/val set split
train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
assert len(train_ds) == num_samples // 2
assert len(val_ds) == num_samples // 2

# Test train/val/test set split
train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
assert len(train_ds) == num_samples // 3
assert len(val_ds) == num_samples // 3
assert len(test_ds) == num_samples // 3
39 changes: 39 additions & 0 deletions tests/trainers/test_resisc45.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from typing import Any, Dict, cast

import pytest
import torch.nn as nn
import torchvision
from omegaconf import OmegaConf

from torchgeo.trainers import RESISC45ClassificationTask


class TestRESISC45Trainer:
@pytest.fixture
def default_config(self) -> Dict[str, Any]:
task_conf = OmegaConf.load("conf/task_defaults/resisc45.yaml")
task_args = OmegaConf.to_object(task_conf.experiment.module)
task_args = cast(Dict[str, Any], task_args)
return task_args

def test_resnet_ce(self, default_config: Dict[str, Any]) -> None:
default_config["classification_model"] = "resnet18"
default_config["loss"] = "ce"
task = RESISC45ClassificationTask(**default_config)
assert isinstance(task.model, torchvision.models.ResNet)
assert isinstance(task.loss, nn.CrossEntropyLoss) # type: ignore[attr-defined]

def test_invalid_model(self, default_config: Dict[str, Any]) -> None:
default_config["classification_model"] = "invalid_model"
error_message = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=error_message):
RESISC45ClassificationTask(**default_config)

def test_invalid_loss(self, default_config: Dict[str, Any]) -> None:
default_config["loss"] = "invalid_loss"
error_message = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=error_message):
RESISC45ClassificationTask(**default_config)
25 changes: 25 additions & 0 deletions torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import rasterio
import torch
from torch import Tensor
from torch.utils.data import Dataset, Subset, random_split
from torchvision.datasets.utils import check_integrity, download_url

__all__ = (
Expand All @@ -30,6 +31,7 @@
"working_dir",
"collate_dict",
"rasterio_loader",
"dataset_split",
)


Expand Down Expand Up @@ -394,3 +396,26 @@ def rasterio_loader(path: str) -> np.ndarray: # type: ignore[type-arg]
# VisionClassificationDataset expects images returned with channels last (HWC)
array = array.transpose(1, 2, 0)
return array


def dataset_split(
isaaccorley marked this conversation as resolved.
Show resolved Hide resolved
dataset: Dataset[Any], val_pct: float, test_pct: Optional[float] = None
) -> List[Subset[Any]]:
"""Split a torch Dataset into train/val/test sets.

Args:
dataset: dataset to be split into train/val or train/val/test subsets
val_pct: percentage of samples to be in validation set
test_pct: (Optional) percentage of samples to be in test set
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved
Returns:
a list of the subset datasets. Either [train, val] or [train, val, test]
"""
if test_pct is None:
val_length = int(len(dataset) * val_pct) # type: ignore[arg-type]
train_length = len(dataset) - val_length # type: ignore[arg-type]
return random_split(dataset, [train_length, val_length])
else:
val_length = int(len(dataset) * val_pct) # type: ignore[arg-type]
test_length = int(len(dataset) * test_pct) # type: ignore[arg-type]
train_length = len(dataset) - (val_length + test_length) # type: ignore[arg-type] # noqa: E501
return random_split(dataset, [train_length, val_length, test_length])
3 changes: 3 additions & 0 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .cyclone import CycloneDataModule, CycloneSimpleRegressionTask
from .landcoverai import LandcoverAIDataModule, LandcoverAISegmentationTask
from .naipchesapeake import NAIPChesapeakeDataModule, NAIPChesapeakeSegmentationTask
from .resisc45 import RESISC45ClassificationTask, RESISC45DataModule
from .sen12ms import SEN12MSDataModule, SEN12MSSegmentationTask
from .so2sat import So2SatClassificationTask, So2SatDataModule

Expand All @@ -21,6 +22,8 @@
"LandcoverAISegmentationTask",
"NAIPChesapeakeDataModule",
"NAIPChesapeakeSegmentationTask",
"RESISC45ClassificationTask",
"RESISC45DataModule",
"SEN12MSDataModule",
"SEN12MSSegmentationTask",
"So2SatDataModule",
Expand Down
Loading