Skip to content

Commit

Permalink
Use hydra for instantiating models and optimizers (#295)
Browse files Browse the repository at this point in the history
* fixes #293 #246
add tests for optimizer instantiation in test_optimizers.py
adapt our unet models (models/unet.py) to expect same parameter names as smp models

* minor typo fixes

* name model yamls as close as possible to upcoming naming convention

* fix model name
  • Loading branch information
remtav authored Mar 24, 2022
1 parent ef96eb6 commit a1e322d
Show file tree
Hide file tree
Showing 37 changed files with 187 additions and 516 deletions.
2 changes: 1 addition & 1 deletion config/gdl_config_template.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
defaults:
- model: unet
- model: gdl_unet
- training: default_training
- loss: binary/softbce
- optimizer: adamw
Expand Down
5 changes: 0 additions & 5 deletions config/model/checkpoint_unet.yaml

This file was deleted.

5 changes: 0 additions & 5 deletions config/model/deeplabv3+_pretrained.yaml

This file was deleted.

5 changes: 0 additions & 5 deletions config/model/deeplabv3_pretrained.yaml

This file was deleted.

6 changes: 0 additions & 6 deletions config/model/deeplabv3_resnet101_dualhead.yaml

This file was deleted.

5 changes: 5 additions & 0 deletions config/model/gdl_deeplabv3-dualhead.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_
model:
_target_: models.deeplabv3_dualhead.DeepLabV3_dualhead
conc_point: conv1 # Choose concatenation point in resnet encoder: 'conv1', 'maxpool', 'layer2', 'layer3', 'layer4'
encoder_weights: imagenet
5 changes: 5 additions & 0 deletions config/model/gdl_unet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_
model:
_target_: models.unet.UNet
dropout: False # (bool) Use dropout or not
prob: False # (float) Set dropout probability, e.g. 0.5
6 changes: 6 additions & 0 deletions config/model/gdl_unetsmall.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# @package _global_
model:
_target_: models.unet.UNetSmall
dropout: False # (bool) Use dropout or not
prob: False # (float) Set dropout probability, e.g. 0.5

5 changes: 5 additions & 0 deletions config/model/smp_deeplabv3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.DeepLabV3
encoder_name: resnet101
encoder_weights: imagenet
5 changes: 5 additions & 0 deletions config/model/smp_deeplabv3plus.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.DeepLabV3Plus
encoder_name: resnext50_32x4d
encoder_weights: imagenet
5 changes: 5 additions & 0 deletions config/model/smp_manet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.MAnet
encoder_name: resnext101_32x8d
encoder_weights: imagenet
5 changes: 5 additions & 0 deletions config/model/smp_pan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.PAN
encoder_name: se_resnext101_32x4d
encoder_weights: imagenet
7 changes: 7 additions & 0 deletions config/model/smp_unet-resnext101.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.Unet
encoder_name: resnext101_32x8d
encoder_depth: 4
decoder_channels: [ 256, 128, 64, 32 ]
encoder_weights: imagenet
5 changes: 5 additions & 0 deletions config/model/smp_unet-spacenet-baseline.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.Unet
encoder_name: vgg11
encoder_weights: imagenet
6 changes: 6 additions & 0 deletions config/model/smp_unet-spacenet-efficientnetb5.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.Unet
encoder_name: efficientnet-b5
encoder_weights: imagenet

5 changes: 5 additions & 0 deletions config/model/smp_unet-spacenet-senet154.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.Unet
encoder_name: senet154
encoder_weights: imagenet
6 changes: 6 additions & 0 deletions config/model/smp_unet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.Unet
encoder_name: resnext50_32x4d
encoder_depth: 5
encoder_weights: imagenet
8 changes: 8 additions & 0 deletions config/model/smp_unetplusplus.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# @package _global_
model:
_target_: segmentation_models_pytorch.UnetPlusPlus
encoder_name: se_resnext50_32x4d
encoder_depth: 4
decoder_channels: [256, 128, 64, 32]
decoder_attention_type: scse
encoder_weights: imagenet
4 changes: 0 additions & 4 deletions config/model/unet.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions config/model/unet_pretrained.yaml

This file was deleted.

5 changes: 0 additions & 5 deletions config/model/unet_small.yaml

This file was deleted.

121 changes: 0 additions & 121 deletions config/old_development/config_test_4channels_implementation.yaml

This file was deleted.

7 changes: 7 additions & 0 deletions config/optimizer/adabound.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# @package _global_
optimizer:
_target_: utils.adabound.AdaBound
lr: ${training.lr}
gamma: 1e-3
eps: 1e-8
weight_decay: 4e-5
7 changes: 7 additions & 0 deletions config/optimizer/adaboundw.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# @package _global_
optimizer:
_target_: utils.adabound.AdaBoundW
lr: ${training.lr}
gamma: 1e-3
eps: 1e-8
weight_decay: 4e-5
6 changes: 2 additions & 4 deletions config/optimizer/adam.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# @package _global_
optimizer:
optimizer_name: 'adam'
class_name: torch.optim.Adam
params:
lr: ${training.lr}
_target_: torch.optim.Adam
lr: ${training.lr}
8 changes: 3 additions & 5 deletions config/optimizer/adamw.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# @package _global_
optimizer:
optimizer_name: 'adamw'
class_name: torch.optim.AdamW
params:
lr: ${training.lr}
weight_decay: 0.001
_target_: torch.optim.AdamW
lr: ${training.lr}
weight_decay: 4e-5
10 changes: 4 additions & 6 deletions config/optimizer/sgd.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# @package _global_
optimizer:
optimizer_name: 'sgd'
class_name: torch.optim.SGD
params:
lr: ${training.lr}
momentum: 0.9
weight_decay: 0.0005
_target_: torch.optim.SGD
lr: ${training.lr}
momentum: 0.9
weight_decay: 4e-5
8 changes: 3 additions & 5 deletions inference_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ def main(params: dict) -> None:
:param params: (dict) Parameters inputted during execution.
"""
# PARAMETERS
model_name = get_key_def('model_name', params['model'], expected_type=str).lower()
num_classes = len(get_key_def('classes_dict', params['dataset']).keys())
num_classes = num_classes + 1 if num_classes > 1 else num_classes # multiclass account for background
modalities = read_modalities(get_key_def('modalities', params['dataset'], expected_type=str))
Expand Down Expand Up @@ -348,10 +347,9 @@ def main(params: dict) -> None:

# CONFIGURE MODEL
model = define_model(
model_name=model_name,
num_bands=num_bands,
num_classes=num_classes,
conc_point=conc_point,
net_params=params.model,
in_channels=num_bands,
out_classes=num_classes,
main_device=device,
devices=[list(gpu_devices_dict.keys())],
state_dict_path=state_dict,
Expand Down
Loading

0 comments on commit a1e322d

Please sign in to comment.