Override inference parameters (model architecture, num bands, num classes) from checkpoint (NRCan#298)

* fixes NRCan#293 NRCan#246
add tests for optimizer instantiation in test_optimizers.py
adapt our unet models (models/unet.py) to expect the same parameter names as smp models

* minor typo fixes

* implement overriding model params from checkpoint with minimal error handling for checkpoints from different gdl versions
fixes NRCan#183

* name model yamls as closely as possible to the upcoming naming convention

* fix model name

* small bugfix for pointing to parameters inside checkpoint

* model_choice.py: add update checkpoint utility

* remove deeplabv3 dualhead warning and add link for deeplabv3_dualhead.py

* update to PR 295

* GDL.py: restore to previous commit based on cauthier's comment
remtav committed Jul 5, 2022
1 parent 096575e commit 5b6cc3c
Showing 7 changed files with 142 additions and 37 deletions.
3 changes: 2 additions & 1 deletion config/model/smp_unet.yaml
@@ -2,5 +2,6 @@
 model:
   _target_: segmentation_models_pytorch.Unet
   encoder_name: resnext50_32x4d
-  encoder_depth: 5
+  encoder_depth: 4
   encoder_weights: imagenet
+  decoder_channels: [ 256, 128, 64, 32 ]
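
For context, this yaml is resolved through hydra's instantiate(), as done in define_model_architecture() further below; a minimal sketch of the equivalent call (the in_channels and classes values are illustrative, not part of the commit):

# Sketch: how the updated smp_unet.yaml resolves at runtime via hydra.
# encoder_depth: 4 must match the four entries in decoder_channels.
from hydra.utils import instantiate
from omegaconf import OmegaConf

net_params = OmegaConf.create({
    '_target_': 'segmentation_models_pytorch.Unet',
    'encoder_name': 'resnext50_32x4d',
    'encoder_depth': 4,
    'encoder_weights': 'imagenet',
    'decoder_channels': [256, 128, 64, 32],
})
model = instantiate(net_params, in_channels=4, classes=2)  # illustrative values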
58 changes: 48 additions & 10 deletions inference_segmentation.py
@@ -1,6 +1,6 @@
 import itertools
 from math import sqrt
-from typing import List
+from typing import List, Union, Sequence

 import torch
 import torch.nn.functional as F
@@ -17,10 +17,11 @@
 from rasterio.windows import Window
 from rasterio.plot import reshape_as_image
 from pathlib import Path
+from omegaconf import OmegaConf, DictConfig, open_dict
 from omegaconf.listconfig import ListConfig

 from utils.logger import get_logger, set_tracker
-from models.model_choice import define_model
+from models.model_choice import define_model, read_checkpoint
 from utils import augmentation
 from utils.utils import get_device_ids, get_key_def, \
     list_input_images, add_metadata_from_raster_to_sample, _window_2D, read_modalities, set_device
@@ -282,7 +283,36 @@ def calc_inference_chunk_size(gpu_devices_dict: dict, max_pix_per_mb_gpu: int =
     return max_chunk_size_rd


-def main(params: dict) -> None:
+def override_model_params_from_checkpoint(
+        params: DictConfig,
+        checkpoint_params):
+    """
+    Overrides model-architecture related parameters from provided checkpoint parameters
+    @param params: Original parameters as inputted through hydra
+    @param checkpoint_params: Checkpoint parameters as saved during checkpoint creation when training
+    @return:
+    """
+    modalities = get_key_def('modalities', params['dataset'], expected_type=Sequence)
+    classes = get_key_def('classes_dict', params['dataset'], expected_type=(dict, DictConfig))
+
+    modalities_ckpt = get_key_def('modalities', checkpoint_params['dataset'], expected_type=Sequence)
+    classes_ckpt = get_key_def('classes_dict', checkpoint_params['dataset'], expected_type=(dict, DictConfig))
+    model_ckpt = get_key_def('model', checkpoint_params, expected_type=(dict, DictConfig))
+
+    if model_ckpt != params.model or classes_ckpt != classes or modalities_ckpt != modalities:
+        logging.warning(f"\nParameters from checkpoint will override inputted parameters."
+                        f"\n\t\t\t Inputted | Overridden"
+                        f"\nModel:\t\t {params.model} | {model_ckpt}"
+                        f"\nInput bands:\t\t{modalities} | {modalities_ckpt}"
+                        f"\nOutput classes:\t\t{classes} | {classes_ckpt}")
+        with open_dict(params):
+            OmegaConf.update(params, 'dataset.modalities', modalities_ckpt)
+            OmegaConf.update(params, 'dataset.classes_dict', classes_ckpt)
+            OmegaConf.update(params, 'model', model_ckpt)
+    return params
+
+
+def main(params: Union[DictConfig, dict]) -> None:
     """
     Function to manage details about the inference on segmentation task.
     1. Read the parameters from the config given.
@@ -291,22 +321,31 @@ def main(params: dict) -> None:
     -------
     :param params: (dict) Parameters inputted during execution.
     """
-    # PARAMETERS
-    num_classes = len(get_key_def('classes_dict', params['dataset']).keys())
+    # SETTING OUTPUT DIRECTORY
+    state_dict = get_key_def('state_dict_path', params['inference'], to_path=True, validate_path_exists=True)
+
+    # Override params from checkpoint
+    checkpoint = read_checkpoint(state_dict)
+    params = override_model_params_from_checkpoint(
+        params=params,
+        checkpoint_params=checkpoint['params']
+    )
+
+    # Dataset params
+    modalities = get_key_def('modalities', params['dataset'], default=("red", "blue", "green"), expected_type=Sequence)
+    classes_dict = get_key_def('classes_dict', params['dataset'], expected_type=DictConfig)
+    num_classes = len(classes_dict)
     num_classes = num_classes + 1 if num_classes > 1 else num_classes  # multiclass account for background
-    modalities = read_modalities(get_key_def('modalities', params['dataset'], expected_type=str))
     num_bands = len(modalities)
+    BGR_to_RGB = get_key_def('BGR_to_RGB', params['dataset'], expected_type=bool)

-    # SETTING OUTPUT DIRECTORY
-    state_dict = get_key_def('state_dict_path', params['inference'], to_path=True, validate_path_exists=True)
     working_folder = state_dict.parent.joinpath(f'inference_{num_bands}bands')
     logging.info("\nThe state dict path directory used '{}'".format(working_folder))
     Path.mkdir(working_folder, parents=True, exist_ok=True)
     logging.info(f'\nInferences will be saved to: {working_folder}\n\n')
     # Default input directory based on default output directory
     img_dir_or_csv = get_key_def('img_dir_or_csv_file', params['inference'], default=working_folder,
                                  expected_type=str, to_path=True, validate_path_exists=True)
-    BGR_to_RGB = get_key_def('BGR_to_RGB', params['dataset'], expected_type=bool)

     # LOGGING PARAMETERS
     exper_name = get_key_def('project_name', params['general'], default='gdl-training')
@@ -345,7 +384,6 @@ def main(params: dict) -> None:
     bucket = None
     bucket_name = get_key_def('bucket_name', params['AWS'], default=None)

-    # CONFIGURE MODEL
     model = define_model(
         net_params=params.model,
         in_channels=num_bands,
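
The checkpoint-override step added to main() above boils down to the following flow; a minimal sketch with illustrative config values and a placeholder checkpoint path:

# Sketch of the override flow: the checkpoint's training-time parameters win
# over whatever the user passed for model, bands and classes.
from omegaconf import OmegaConf
from models.model_choice import read_checkpoint
from inference_segmentation import override_model_params_from_checkpoint

params = OmegaConf.create({  # illustrative inference config
    'model': {'_target_': 'segmentation_models_pytorch.Unet'},
    'dataset': {'modalities': ['red', 'green', 'blue'], 'classes_dict': {'BUIL': 1}},
    'inference': {'state_dict_path': 'my_checkpoint.pth.tar'},  # placeholder path
})
checkpoint = read_checkpoint(params.inference.state_dict_path)
params = override_model_params_from_checkpoint(
    params=params,
    checkpoint_params=checkpoint['params'],
)
# params.model, params.dataset.modalities and params.dataset.classes_dict now
# match the checkpoint; a warning is logged whenever they differed.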
2 changes: 2 additions & 0 deletions models/deeplabv3_dualhead.py
@@ -1,3 +1,5 @@
+# See: https://www.azavea.com/blog/2019/08/30/transfer-learning-from-rgb-to-multi-band-imagery/
+
 import logging
 from typing import Optional

19 changes: 6 additions & 13 deletions models/model_choice.py
@@ -6,6 +6,8 @@
 import torch
 import torch.nn as nn

+from utils.utils import update_gdl_checkpoint
+
 logging.getLogger(__name__)


@@ -23,7 +25,7 @@ def define_model_architecture(
     return instantiate(net_params, in_channels=in_channels, classes=out_classes)


-def read_checkpoint(filename):
+def read_checkpoint(filename, update=True):
     """
     Loads checkpoint from provided path to GDL's expected format,
     ie model's state dictionary should be under "model_state_dict" and
@@ -39,7 +41,7 @@ def read_checkpoint(filename):
         logging.info(f"\n=> loading model '{filename}'")
         # For loading external models with different structure in state dict.
         checkpoint = torch.load(filename, map_location='cpu')
-        if 'model_state_dict' not in checkpoint.keys():
+        if 'model_state_dict' not in checkpoint.keys() and 'model' not in checkpoint.keys():
             val_set = set()
             for val in checkpoint.values():
                 val_set.add(type(val))
@@ -49,19 +51,10 @@ def read_checkpoint(filename):
                 new_checkpoint['model_state_dict'] = OrderedDict({k: v for k, v in checkpoint.items()})
                 del checkpoint
                 checkpoint = new_checkpoint
-            # Covers gdl's checkpoints at version <=2.0.1
-            elif 'model' in checkpoint.keys():
-                checkpoint['model_state_dict'] = checkpoint['model']
-                del checkpoint['model']
             else:
                 raise ValueError(f"GDL cannot find weight in provided checkpoint")
-        if 'optimizer_state_dict' not in checkpoint.keys():
-            try:
-                # Covers gdl's checkpoints at version <=2.0.1
-                checkpoint['optimizer_state_dict'] = checkpoint['optimizer']
-                del checkpoint['optimizer']
-            except KeyError:
-                logging.critical(f"No optimizer state dictionary was found in provided checkpoint")
+        elif update:
+            checkpoint = update_gdl_checkpoint(checkpoint)
         return checkpoint
     except FileNotFoundError:
         raise logging.critical(FileNotFoundError(f"\n=> No model found at '{filename}'"))
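
With this refactor, read_checkpoint() only normalizes arbitrary external state dicts itself and delegates GDL-specific conversions to update_gdl_checkpoint() (added in utils/utils.py below); a minimal sketch of the expected behaviour on a GDL <= 2.0.1 checkpoint, with a placeholder filename:

# Sketch: loading an old-format GDL checkpoint after this change.
# 'old_gdl_model.pth.tar' stands in for a checkpoint saved by GDL <= 2.0.1,
# i.e. with 'model'/'optimizer' keys instead of the '*_state_dict' keys.
from models.model_choice import read_checkpoint

checkpoint = read_checkpoint('old_gdl_model.pth.tar', update=True)
# update_gdl_checkpoint() has renamed the legacy keys...
assert 'model_state_dict' in checkpoint and 'optimizer_state_dict' in checkpoint
# ...and rebuilt 'params' so inference can read model/modalities/classes_dict from it.
assert 'model' in checkpoint['params'] and 'dataset' in checkpoint['params']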
10 changes: 7 additions & 3 deletions tests/model/test_models.py
@@ -9,6 +9,7 @@
 from hydra.utils import to_absolute_path, instantiate
 from torch import nn

+import models.unet
 from models import unet
 from models.model_choice import read_checkpoint, adapt_checkpoint_to_dp_model, define_model, define_model_architecture
 from utils.utils import get_device_ids, set_device
@@ -53,14 +54,17 @@ class TestReadCheckpoint(object):
     """
     Tests reading a checkpoint saved outside GDL into memory
     """
-    dummy_model = torchvision.models.resnet18()
-    dummy_optimizer = optimizer = instantiate({'_target_': 'torch.optim.Adam'}, params=dummy_model.parameters())
+    var = 4
+    dummy_model = models.unet.UNetSmall(classes=var, in_channels=var)
+    dummy_optimizer = instantiate({'_target_': 'torch.optim.Adam'}, params=dummy_model.parameters())
     filename = "test.pth.tar"
     torch.save(dummy_model.state_dict(), filename)
     read_checkpoint(filename)
     # test gdl's checkpoints at version <=2.0.1
     torch.save({'epoch': 999,
-                'params': {'model': 'resnet18'},
+                'params': {
+                    'global': {'num_classes': var, 'model_name': 'unet_small', 'number_of_bands': var}
+                },
                 'model': dummy_model.state_dict(),
                 'best_loss': 0.1,
                 'optimizer': dummy_optimizer.state_dict()}, filename)
9 changes: 0 additions & 9 deletions train_segmentation.py
@@ -780,15 +780,6 @@ def main(cfg: DictConfig) -> None:
     -------
     :param cfg: (dict) Parameters found in the yaml config file.
     """
-    # Limit of the NIR implementation
-    # FIXME: keep this warning?
-    # if 'deeplabv3' not in cfg.model.model_name and 'IR' in read_modalities(cfg.dataset.modalities):
-    # logging.info(
-    # '\nThe NIR modality will be fed at first layer of model alongside other bands,'
-    # '\nthe implementation of concatenation point at an intermediary layer is only available'
-    # '\nfor the deeplabv3 model for now. \nMore will follow on demand.'
-    # )
-
     # Preprocessing
     # HERE the code to do for the preprocessing for the segmentation

78 changes: 77 additions & 1 deletion utils/utils.py
@@ -4,7 +4,7 @@
 import subprocess
 from functools import reduce
 from pathlib import Path
-from typing import Sequence, List
+from typing import Sequence, List, Dict

 from hydra.utils import to_absolute_path
 from pytorch_lightning.utilities import rank_zero_only
@@ -629,3 +629,79 @@ def print_config(

     with open("run_config.config", "w") as fp:
         rich.print(tree, file=fp)
+
+
+def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict:
+    """
+    Utility to update model checkpoints from older versions of GDL to current version
+    @param checkpoint_params:
+        Dictionary containing weights, optimizer state and saved configuration params from training
+    @return:
+    """
+    # covers gdl checkpoints from version <= 2.0.1
+    if 'model' in checkpoint_params.keys():
+        checkpoint_params['model_state_dict'] = checkpoint_params['model']
+        del checkpoint_params['model']
+    if 'optimizer' in checkpoint_params.keys():
+        checkpoint_params['optimizer_state_dict'] = checkpoint_params['optimizer']
+        del checkpoint_params['optimizer']
+
+    # covers gdl checkpoints pre-hydra (<=2.0.0)
+    bands = ['R', 'G', 'B', 'N']
+    old2new = {
+        'manet_pretrained': {
+            '_target_': 'segmentation_models_pytorch.MAnet', 'encoder_name': 'resnext50_32x4d',
+            'encoder_weights': 'imagenet'
+        },
+        'unet_pretrained': {
+            '_target_': 'segmentation_models_pytorch.Unet', 'encoder_name': 'resnext50_32x4d',
+            'encoder_depth': 4, 'encoder_weights': 'imagenet', 'decoder_channels': [256, 128, 64, 32]
+        },
+        'unet': {
+            '_target_': 'models.unet.UNet', 'dropout': False, 'prob': False
+        },
+        'unet_small': {
+            '_target_': 'models.unet.UNetSmall', 'dropout': False, 'prob': False
+        },
+        'deeplabv3_pretrained': {
+            '_target_': 'segmentation_models_pytorch.DeepLabV3', 'encoder_name': 'resnet101',
+            'encoder_weights': 'imagenet'
+        },
+        'deeplabv3_resnet101_dualhead': {
+            '_target_': 'models.deeplabv3_dualhead.DeepLabV3_dualhead', 'conc_point': 'conv1',
+            'encoder_weights': 'imagenet'
+        },
+        'deeplabv3+_pretrained': {
+            '_target_': 'segmentation_models_pytorch.DeepLabV3Plus', 'encoder_name': 'resnext50_32x4d',
+            'encoder_weights': 'imagenet'
+        },
+    }
+    try:
+        # don't update if already a recent checkpoint
+        get_key_def('classes_dict', checkpoint_params['params']['dataset'], expected_type=(dict, DictConfig))
+        get_key_def('modalities', checkpoint_params['params']['dataset'], expected_type=Sequence)
+        get_key_def('model', checkpoint_params['params'], expected_type=(dict, DictConfig))
+        return checkpoint_params
+    except KeyError:
+        num_classes_ckpt = get_key_def('num_classes', checkpoint_params['params']['global'], expected_type=int)
+        num_bands_ckpt = get_key_def('number_of_bands', checkpoint_params['params']['global'], expected_type=int)
+        model_name = get_key_def('model_name', checkpoint_params['params']['global'], expected_type=str)
+        try:
+            model_ckpt = old2new[model_name]
+        except KeyError as e:
+            logging.critical(f"\nCouldn't locate yaml configuration for model architecture {model_name} as found "
+                             f"in provided checkpoint. Name of yaml may have changed."
+                             f"\nError {type(e)}: {e}")
+            raise e
+        # For GDL pre-v2.0.2
+        #bands_ckpt = ''
+        #bands_ckpt = bands_ckpt.join([bands[i] for i in range(num_bands_ckpt)])
+        checkpoint_params['params'].update({
+            'dataset': {
+                'modalities': [bands[i] for i in range(num_bands_ckpt)],  #bands_ckpt,
+                #"classes_dict": {f"BUIL": 1}
+                "classes_dict": {f"class{i + 1}": i + 1 for i in range(num_classes_ckpt)}
+            }
+        })
+        checkpoint_params['params'].update({'model': model_ckpt})
+        return checkpoint_params
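
A minimal sketch of the mapping update_gdl_checkpoint() performs on an illustrative pre-hydra (<= 2.0.0) checkpoint; the input values are made up, and the resulting layout follows the old2new table above:

# Sketch: legacy checkpoint params before/after update_gdl_checkpoint().
from utils.utils import update_gdl_checkpoint

legacy = {
    'model': {'layer.weight': None},   # legacy key for the model state dict
    'optimizer': {},                   # legacy key for the optimizer state dict
    'params': {'global': {'num_classes': 1, 'model_name': 'unet_small', 'number_of_bands': 4}},
}
updated = update_gdl_checkpoint(legacy)
assert 'model_state_dict' in updated and 'optimizer_state_dict' in updated
# updated['params']['dataset'] == {'modalities': ['R', 'G', 'B', 'N'],
#                                  'classes_dict': {'class1': 1}}
# updated['params']['model']['_target_'] == 'models.unet.UNetSmall'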
