Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change "classification_model" to "model" #916

Merged
merged 8 commits into from
Dec 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conf/bigearthnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
task: "bigearthnet"
module:
loss: "bce"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion conf/eurosat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "eurosat"
module:
loss: "ce"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion conf/resisc45.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
task: "resisc45"
module:
loss: "ce"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion conf/so2sat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
task: "so2sat"
module:
loss: "ce"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion conf/ucmerced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "ucmerced"
module:
loss: "ce"
classification_model: "resnet18"
model: "resnet18"
weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
Expand Down
4 changes: 2 additions & 2 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,15 @@ def main(args: argparse.Namespace) -> None:
if issubclass(TASK, ClassificationTask):
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
val_row = {
"split": "val",
"classification_model": hparams["classification_model"],
"model": hparams["model"],
"learning_rate": hparams["learning_rate"],
"weights": hparams["weights"],
"loss": hparams["loss"],
}

test_row = {
"split": "test",
"classification_model": hparams["classification_model"],
"model": hparams["model"],
"learning_rate": hparams["learning_rate"],
"weights": hparams["weights"],
"loss": hparams["loss"],
Expand Down
2 changes: 1 addition & 1 deletion experiments/run_resisc45_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
"python train.py"
+ f" config_file={config_file}"
+ f" experiment.name={experiment_name}"
+ f" experiment.module.classification_model={model}"
+ f" experiment.module.model={model}"
+ f" experiment.module.learning_rate={lr}"
+ f" experiment.module.loss={loss}"
+ f" experiment.module.weights={weights}"
Expand Down
2 changes: 1 addition & 1 deletion experiments/run_so2sat_byol_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
"python train.py"
+ f" config_file={config_file}"
+ f" experiment.name={experiment_name}"
+ f" experiment.module.classification_model={model}"
+ f" experiment.module.model={model}"
+ f" experiment.module.learning_rate={lr}"
+ f" experiment.module.loss={loss}"
+ f" experiment.module.weights={weights}"
Expand Down
2 changes: 1 addition & 1 deletion experiments/run_so2sat_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
"python train.py"
+ f" config_file={config_file}"
+ f" experiment.name={experiment_name}"
+ f" experiment.module.classification_model={model}"
+ f" experiment.module.model={model}"
+ f" experiment.module.learning_rate={lr}"
+ f" experiment.module.loss={loss}"
+ f" experiment.module.weights={weights}"
Expand Down
2 changes: 1 addition & 1 deletion experiments/run_so2sat_seed_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
"python train.py"
+ f" config_file={config_file}"
+ f" experiment.name={experiment_name}"
+ f" experiment.module.classification_model={model}"
+ f" experiment.module.model={model}"
+ f" experiment.module.learning_rate={lr}"
+ f" experiment.module.loss={loss}"
+ f" experiment.module.weights={weights}"
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/bigearthnet_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "bigearthnet"
module:
loss: "bce"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/bigearthnet_s1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "bigearthnet"
module:
loss: "bce"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/bigearthnet_s2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "bigearthnet"
module:
loss: "bce"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/eurosat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "eurosat"
module:
loss: "ce"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/resisc45.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "resisc45"
module:
loss: "ce"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/so2sat_supervised.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "so2sat"
module:
loss: "focal"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/so2sat_unsupervised.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "so2sat"
module:
loss: "jaccard"
classification_model: "resnet18"
model: "resnet18"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weights: "random"
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/ucmerced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ experiment:
task: "ucmerced"
module:
loss: "ce"
classification_model: "resnet18"
model: "resnet18"
weights: "random"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
Expand Down
4 changes: 2 additions & 2 deletions tests/trainers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def state_dict(model: Module) -> Dict[str, Tensor]:
return model.state_dict()


@pytest.fixture(params=["classification_model", "encoder_name"])
@pytest.fixture(params=["model", "encoder_name"])
def checkpoint(
state_dict: Dict[str, Tensor], request: SubRequest, tmp_path: Path
) -> str:
if request.param == "classification_model":
if request.param == "model":
state_dict = OrderedDict({"model." + k: v for k, v in state_dict.items()})
else:
state_dict = OrderedDict(
Expand Down
8 changes: 4 additions & 4 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_no_logger(self) -> None:
@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
return {
"classification_model": "resnet18",
"model": "resnet18",
"in_channels": 13,
"loss": "ce",
"num_classes": 10,
Expand All @@ -101,7 +101,7 @@ def test_invalid_pretrained(
self, model_kwargs: Dict[Any, Any], checkpoint: str
) -> None:
model_kwargs["weights"] = checkpoint
model_kwargs["classification_model"] = "resnet50"
model_kwargs["model"] = "resnet50"
match = "Trying to load resnet18 weights into a resnet50"
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)
Expand All @@ -113,7 +113,7 @@ def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
ClassificationTask(**model_kwargs)

def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
model_kwargs["classification_model"] = "invalid_model"
model_kwargs["model"] = "invalid_model"
match = "Model type 'invalid_model' is not a valid timm model."
with pytest.raises(ValueError, match=match):
ClassificationTask(**model_kwargs)
Expand Down Expand Up @@ -189,7 +189,7 @@ def test_no_logger(self) -> None:
@pytest.fixture
def model_kwargs(self) -> Dict[Any, Any]:
return {
"classification_model": "resnet18",
"model": "resnet18",
"in_channels": 14,
"loss": "bce",
"num_classes": 19,
Expand Down
5 changes: 1 addition & 4 deletions tests/trainers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,7 @@ def test_extract_encoder_unsupported_model(tmp_path: Path) -> None:
checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}}
path = os.path.join(str(tmp_path), "dummy.ckpt")
torch.save(checkpoint, path)
err = (
"Unknown checkpoint task. Only encoder or classification_model"
" extraction is supported"
)
err = "Unknown checkpoint task. Only encoder or model extraction is supported"
with pytest.raises(ValueError, match=err):
extract_encoder(path)

Expand Down
32 changes: 18 additions & 14 deletions torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ class ClassificationTask(pl.LightningModule):

Supports any available `Timm model
<https://rwightman.github.io/pytorch-image-models/>`_
as an architecture choice. To see a list of available pretrained
as an architecture choice. To see a list of available
models, you can do:

.. code-block:: python

import timm
print(timm.list_models(pretrained=True))
print(timm.list_models())
"""

def config_model(self) -> None:
"""Configures the model based on kwargs parameters passed to the constructor."""
in_channels = self.hyperparams["in_channels"]
classification_model = self.hyperparams["classification_model"]
model = self.hyperparams["model"]

imagenet_pretrained = False
custom_pretrained = False
Expand All @@ -68,26 +68,24 @@ def config_model(self) -> None:
custom_pretrained = True

# Create the model
valid_models = timm.list_models(pretrained=True)
if classification_model in valid_models:
valid_models = timm.list_models(pretrained=imagenet_pretrained)
if model in valid_models:
self.model = timm.create_model(
classification_model,
model,
num_classes=self.hyperparams["num_classes"],
in_chans=in_channels,
pretrained=imagenet_pretrained,
)
else:
raise ValueError(
f"Model type '{classification_model}' is not a valid timm model."
)
raise ValueError(f"Model type '{model}' is not a valid timm model.")

if custom_pretrained:
name, state_dict = utils.extract_encoder(self.hyperparams["weights"])

if self.hyperparams["classification_model"] != name:
if self.hyperparams["model"] != name:
raise ValueError(
f"Trying to load {name} weights into a "
f"{self.hyperparams['classification_model']}"
f"{self.hyperparams['model']}"
)
self.model = utils.load_state_dict(self.model, state_dict)

Expand All @@ -108,13 +106,16 @@ def __init__(self, **kwargs: Any) -> None:
"""Initialize the LightningModule with a model and loss function.

Keyword Args:
classification_model: Name of the classification model use
loss: Name of the loss function
model: Name of the classification model use
loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal'
weights: Either "random" or "imagenet"
num_classes: Number of prediction classes
in_channels: Number of input channels to model
learning_rate: Learning rate for optimizer
learning_rate_schedule_patience: Patience for learning rate scheduler
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

.. versionchanged:: 0.4
The *classification_model* parameter was renamed to *model*.
"""
super().__init__()

Expand Down Expand Up @@ -313,13 +314,16 @@ def __init__(self, **kwargs: Any) -> None:
"""Initialize the LightningModule with a model and loss function.

Keyword Args:
classification_model: Name of the classification model use
model: Name of the classification model use
loss: Name of the loss function, currently only supports 'bce'
weights: Either "random" or 'imagenet'
num_classes: Number of prediction classes
in_channels: Number of input channels to model
learning_rate: Learning rate for optimizer
learning_rate_schedule_patience: Patience for learning rate scheduler

.. versionchanged:: 0.4
The *classification_model* parameter was renamed to *model*.
"""
super().__init__(**kwargs)

Expand Down
10 changes: 4 additions & 6 deletions torchgeo/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]:
tuple containing model name and state dict

Raises:
ValueError: if 'classification_model' or 'encoder' not in
ValueError: if 'model' or 'encoder' not in
checkpoint['hyper_parameters']
"""
checkpoint = torch.load(path, map_location=torch.device("cpu"))

if "classification_model" in checkpoint["hyper_parameters"]:
name = checkpoint["hyper_parameters"]["classification_model"]
if "model" in checkpoint["hyper_parameters"]:
name = checkpoint["hyper_parameters"]["model"]
state_dict = checkpoint["state_dict"]
state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k})
state_dict = OrderedDict(
Expand All @@ -51,8 +50,7 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]:
)
else:
raise ValueError(
"Unknown checkpoint task. Only encoder or classification_model"
" extraction is supported"
"Unknown checkpoint task. Only encoder or model extraction is supported"
)

return name, state_dict
Expand Down