Change "classification_model" to "model" (microsoft#916)
* name change
* fix failing test
* expose all available timm models
* chmod
* imagenet pretrained flag
* remove extra
* docstring list_models
nilsleh authored Dec 4, 2022
1 parent bc8326a commit afd1cca
Showing 23 changed files with 48 additions and 49 deletions.
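
For anyone updating downstream code against this rename, a minimal before/after sketch (keyword values taken from the config files and test fixtures changed in this commit; treat it as illustrative, not an API reference):

    from torchgeo.trainers import ClassificationTask

    # Before this commit:
    # task = ClassificationTask(classification_model="resnet18", loss="ce", ...)

    # After this commit, the keyword is simply `model`:
    task = ClassificationTask(
        model="resnet18",  # any name accepted by timm.list_models()
        loss="ce",
        weights="random",  # or "imagenet" for ImageNet-pretrained weights
        num_classes=10,
        in_channels=13,
        learning_rate=1e-3,
        learning_rate_schedule_patience=6,
    )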
2 changes: 1 addition & 1 deletion conf/bigearthnet.yaml
@@ -7,7 +7,7 @@ experiment:
   task: "bigearthnet"
   module:
     loss: "bce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion conf/eurosat.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "eurosat"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion conf/resisc45.yaml
@@ -7,7 +7,7 @@ experiment:
   task: "resisc45"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion conf/so2sat.yaml
@@ -7,7 +7,7 @@ experiment:
   task: "so2sat"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion conf/ucmerced.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "ucmerced"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     weights: null
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
4 changes: 2 additions & 2 deletions evaluate.py
@@ -160,15 +160,15 @@ def main(args: argparse.Namespace) -> None:
     if issubclass(TASK, ClassificationTask):
         val_row = {
             "split": "val",
-            "classification_model": hparams["classification_model"],
+            "model": hparams["model"],
             "learning_rate": hparams["learning_rate"],
             "weights": hparams["weights"],
             "loss": hparams["loss"],
         }

         test_row = {
             "split": "test",
-            "classification_model": hparams["classification_model"],
+            "model": hparams["model"],
             "learning_rate": hparams["learning_rate"],
             "weights": hparams["weights"],
             "loss": hparams["loss"],
2 changes: 1 addition & 1 deletion experiments/run_resisc45_experiments.py
@@ -50,7 +50,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
             "python train.py"
             + f" config_file={config_file}"
             + f" experiment.name={experiment_name}"
-            + f" experiment.module.classification_model={model}"
+            + f" experiment.module.model={model}"
             + f" experiment.module.learning_rate={lr}"
             + f" experiment.module.loss={loss}"
             + f" experiment.module.weights={weights}"
2 changes: 1 addition & 1 deletion experiments/run_so2sat_byol_experiments.py
@@ -52,7 +52,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
             "python train.py"
             + f" config_file={config_file}"
             + f" experiment.name={experiment_name}"
-            + f" experiment.module.classification_model={model}"
+            + f" experiment.module.model={model}"
             + f" experiment.module.learning_rate={lr}"
             + f" experiment.module.loss={loss}"
             + f" experiment.module.weights={weights}"
2 changes: 1 addition & 1 deletion experiments/run_so2sat_experiments.py
@@ -50,7 +50,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
             "python train.py"
             + f" config_file={config_file}"
             + f" experiment.name={experiment_name}"
-            + f" experiment.module.classification_model={model}"
+            + f" experiment.module.model={model}"
             + f" experiment.module.learning_rate={lr}"
             + f" experiment.module.loss={loss}"
             + f" experiment.module.weights={weights}"
2 changes: 1 addition & 1 deletion experiments/run_so2sat_seed_experiments.py
@@ -51,7 +51,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
             "python train.py"
             + f" config_file={config_file}"
             + f" experiment.name={experiment_name}"
-            + f" experiment.module.classification_model={model}"
+            + f" experiment.module.model={model}"
             + f" experiment.module.learning_rate={lr}"
             + f" experiment.module.loss={loss}"
             + f" experiment.module.weights={weights}"
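
All four experiment runners now build the train.py invocation with the shortened override key. A sketch of the command construction they share (variable values are hypothetical):

    # Mirrors the string built inside do_work() in the scripts above
    config_file, experiment_name = "conf/resisc45.yaml", "resisc45_resnet18_ce"
    model, lr, loss, weights = "resnet18", 1e-3, "ce", "random"
    command = (
        "python train.py"
        + f" config_file={config_file}"
        + f" experiment.name={experiment_name}"
        + f" experiment.module.model={model}"  # was experiment.module.classification_model
        + f" experiment.module.learning_rate={lr}"
        + f" experiment.module.loss={loss}"
        + f" experiment.module.weights={weights}"
    )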
2 changes: 1 addition & 1 deletion tests/conf/bigearthnet_all.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "bigearthnet"
   module:
     loss: "bce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion tests/conf/bigearthnet_s1.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "bigearthnet"
   module:
     loss: "bce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion tests/conf/bigearthnet_s2.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "bigearthnet"
   module:
     loss: "bce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion tests/conf/eurosat.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "eurosat"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion tests/conf/resisc45.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "resisc45"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion tests/conf/so2sat_supervised.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "so2sat"
   module:
     loss: "focal"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion tests/conf/so2sat_unsupervised.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "so2sat"
   module:
     loss: "jaccard"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
2 changes: 1 addition & 1 deletion tests/conf/ucmerced.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "ucmerced"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     weights: "random"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
4 changes: 2 additions & 2 deletions tests/trainers/conftest.py
@@ -31,11 +31,11 @@ def state_dict(model: Module) -> Dict[str, Tensor]:
     return model.state_dict()


-@pytest.fixture(params=["classification_model", "encoder_name"])
+@pytest.fixture(params=["model", "encoder_name"])
 def checkpoint(
     state_dict: Dict[str, Tensor], request: SubRequest, tmp_path: Path
 ) -> str:
-    if request.param == "classification_model":
+    if request.param == "model":
         state_dict = OrderedDict({"model." + k: v for k, v in state_dict.items()})
     else:
         state_dict = OrderedDict(
8 changes: 4 additions & 4 deletions tests/trainers/test_classification.py
@@ -85,7 +85,7 @@ def test_no_logger(self) -> None:
     @pytest.fixture
     def model_kwargs(self) -> Dict[Any, Any]:
         return {
-            "classification_model": "resnet18",
+            "model": "resnet18",
             "in_channels": 13,
             "loss": "ce",
             "num_classes": 10,
@@ -101,7 +101,7 @@ def test_invalid_pretrained(
         self, model_kwargs: Dict[Any, Any], checkpoint: str
     ) -> None:
         model_kwargs["weights"] = checkpoint
-        model_kwargs["classification_model"] = "resnet50"
+        model_kwargs["model"] = "resnet50"
         match = "Trying to load resnet18 weights into a resnet50"
         with pytest.raises(ValueError, match=match):
             ClassificationTask(**model_kwargs)
@@ -113,7 +113,7 @@ def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None:
             ClassificationTask(**model_kwargs)

     def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None:
-        model_kwargs["classification_model"] = "invalid_model"
+        model_kwargs["model"] = "invalid_model"
         match = "Model type 'invalid_model' is not a valid timm model."
         with pytest.raises(ValueError, match=match):
             ClassificationTask(**model_kwargs)
@@ -189,7 +189,7 @@ def test_no_logger(self) -> None:
     @pytest.fixture
     def model_kwargs(self) -> Dict[Any, Any]:
         return {
-            "classification_model": "resnet18",
+            "model": "resnet18",
             "in_channels": 14,
             "loss": "bce",
             "num_classes": 19,
5 changes: 1 addition & 4 deletions tests/trainers/test_utils.py
@@ -56,10 +56,7 @@ def test_extract_encoder_unsupported_model(tmp_path: Path) -> None:
     checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}}
     path = os.path.join(str(tmp_path), "dummy.ckpt")
     torch.save(checkpoint, path)
-    err = (
-        "Unknown checkpoint task. Only encoder or classification_model"
-        " extraction is supported"
-    )
+    err = "Unknown checkpoint task. Only encoder or model extraction is supported"
     with pytest.raises(ValueError, match=err):
         extract_encoder(path)

32 changes: 18 additions & 14 deletions torchgeo/trainers/classification.py
@@ -38,19 +38,19 @@ class ClassificationTask(pl.LightningModule):
     Supports any available `Timm model
     <https://rwightman.github.io/pytorch-image-models/>`_
-    as an architecture choice. To see a list of available pretrained
+    as an architecture choice. To see a list of available
     models, you can do:

     .. code-block:: python

         import timm
-        print(timm.list_models(pretrained=True))
+        print(timm.list_models())
     """

     def config_model(self) -> None:
         """Configures the model based on kwargs parameters passed to the constructor."""
         in_channels = self.hyperparams["in_channels"]
-        classification_model = self.hyperparams["classification_model"]
+        model = self.hyperparams["model"]

         imagenet_pretrained = False
         custom_pretrained = False
@@ -68,26 +68,24 @@ def config_model(self) -> None:
             custom_pretrained = True

         # Create the model
-        valid_models = timm.list_models(pretrained=True)
-        if classification_model in valid_models:
+        valid_models = timm.list_models(pretrained=imagenet_pretrained)
+        if model in valid_models:
             self.model = timm.create_model(
-                classification_model,
+                model,
                 num_classes=self.hyperparams["num_classes"],
                 in_chans=in_channels,
                 pretrained=imagenet_pretrained,
             )
         else:
-            raise ValueError(
-                f"Model type '{classification_model}' is not a valid timm model."
-            )
+            raise ValueError(f"Model type '{model}' is not a valid timm model.")

         if custom_pretrained:
            name, state_dict = utils.extract_encoder(self.hyperparams["weights"])

-            if self.hyperparams["classification_model"] != name:
+            if self.hyperparams["model"] != name:
                 raise ValueError(
                     f"Trying to load {name} weights into a "
-                    f"{self.hyperparams['classification_model']}"
+                    f"{self.hyperparams['model']}"
                 )
             self.model = utils.load_state_dict(self.model, state_dict)
@@ -108,13 +106,16 @@ def __init__(self, **kwargs: Any) -> None:
         """Initialize the LightningModule with a model and loss function.

         Keyword Args:
-            classification_model: Name of the classification model use
-            loss: Name of the loss function
+            model: Name of the classification model to use
+            loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal'
             weights: Either "random" or "imagenet"
             num_classes: Number of prediction classes
             in_channels: Number of input channels to model
             learning_rate: Learning rate for optimizer
             learning_rate_schedule_patience: Patience for learning rate scheduler
+
+        .. versionchanged:: 0.4
+           The *classification_model* parameter was renamed to *model*.
         """
         super().__init__()
@@ -313,13 +314,16 @@ def __init__(self, **kwargs: Any) -> None:
         """Initialize the LightningModule with a model and loss function.

         Keyword Args:
-            classification_model: Name of the classification model use
+            model: Name of the classification model to use
             loss: Name of the loss function, currently only supports 'bce'
             weights: Either "random" or 'imagenet'
             num_classes: Number of prediction classes
             in_channels: Number of input channels to model
             learning_rate: Learning rate for optimizer
             learning_rate_schedule_patience: Patience for learning rate scheduler
+
+        .. versionchanged:: 0.4
+           The *classification_model* parameter was renamed to *model*.
         """
         super().__init__(**kwargs)
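
With config_model() now validating against timm.list_models(pretrained=imagenet_pretrained), every registered timm architecture is a legal choice when training from scratch, and only the pretrained subset when weights="imagenet". A small sketch of the underlying timm calls (assuming timm is installed; counts vary by timm version):

    import timm

    print(len(timm.list_models()))                 # all registered architectures
    print(len(timm.list_models(pretrained=True)))  # only those with ImageNet checkpoints

    # Roughly what config_model() does for model="resnet18", weights="imagenet":
    backbone = timm.create_model(
        "resnet18", num_classes=10, in_chans=13, pretrained=True
    )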
10 changes: 4 additions & 6 deletions torchgeo/trainers/utils.py
@@ -28,13 +28,12 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]:
         tuple containing model name and state dict

     Raises:
-        ValueError: if 'classification_model' or 'encoder' not in
-            checkpoint['hyper_parameters']
+        ValueError: if 'model' or 'encoder' not in checkpoint['hyper_parameters']
     """
     checkpoint = torch.load(path, map_location=torch.device("cpu"))

-    if "classification_model" in checkpoint["hyper_parameters"]:
-        name = checkpoint["hyper_parameters"]["classification_model"]
+    if "model" in checkpoint["hyper_parameters"]:
+        name = checkpoint["hyper_parameters"]["model"]
         state_dict = checkpoint["state_dict"]
         state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k})
         state_dict = OrderedDict(
@@ -51,8 +50,7 @@
         )
     else:
         raise ValueError(
-            "Unknown checkpoint task. Only encoder or classification_model"
-            " extraction is supported"
+            "Unknown checkpoint task. Only encoder or model extraction is supported"
         )

     return name, state_dict
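
As a usage sketch of the updated helper (hypothetical checkpoint path; assumes the module import path torchgeo.trainers.utils shown in this diff):

    from torchgeo.trainers import utils

    # The checkpoint's hyper_parameters must contain either "model" (classification)
    # or an encoder entry; anything else raises the ValueError above.
    name, state_dict = utils.extract_encoder("output/last.ckpt")  # hypothetical path
    print(name)  # e.g. "resnet18"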
