use hydra to instantiate loss #256
Conversation
fix bug in clip_raster_with_gpkg
@@ -0,0 +1,4 @@
# @package _global_
Should we name this file ohemce.yaml or keep full name?
Agreed, the shorter name is better.
@@ -8,7 +8,6 @@ training:
  max_epochs: ${general.max_epochs}
  min_epochs: ${general.min_epochs}
  num_workers:
  loss_fn: CrossEntropy  # name of the metric you want to manage the training
Ideally, the loss parameter should stay in the training config, but I couldn't point it to the one in config/loss/multiclass using hydra's ${...} syntax. I did try quite a bit. I'll move the loss parameter out of training if we can't find a way to make this work.
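For reference, here is a minimal standalone sketch (not the repo's code) of the kind of node interpolation that was attempted; whether it resolves depends on the omegaconf version and on when hydra composes the loss group:

from omegaconf import OmegaConf

# hypothetical composed config: training.loss_fn points at the loss node
cfg = OmegaConf.create({
    "loss": {"_target_": "torch.nn.CrossEntropyLoss"},
    "training": {"loss_fn": "${loss}"},  # node interpolation via ${...}
})
print(OmegaConf.to_yaml(cfg, resolve=True))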
models/model_choice.py
Outdated
if loss_fn['_target_'] == 'torch.nn.CrossEntropyLoss':
    criterion = instantiate(loss_fn, weight=class_weights)  # FIXME: unable to pass this through hydra
I really don't like this. CrossEntropy loss (multiclass) is the only loss that uses class_weights as a parameter. We had adapted all the other losses to accept it but not use it. The binary losses that @victorlazio109 added from smp don't accept the "weight" parameter and therefore raise an exception if it is passed.
Logically, class_weights would have lived in the crossentropy.yaml config directly (and not in model_choice). It could therefore be handled directly by hydra, but one thing makes it a bit tricky:
Cross entropy in torch.nn only accepts a tensor for the weight argument. If it could handle a list, this would be much more convenient for hydra.
The only way hydra can directly feed a tensor to the weight argument is by changing the CE yaml to something like this:
# @package _global_
loss:
  _target_: torch.nn.CrossEntropyLoss
  ignore_index: ${dataset.ignore_index}
  weight:
    _target_: torch.tensor
    data: ${dataset.class_weights}
This works fine when class_weights is not None (e.g. [0.5, 0.5] successfully gets converted to a tensor), but when it is None, torch.tensor(None) raises an exception. Hydra would need some kind of if/else to skip the conversion when class_weights is None.
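One possible workaround (a sketch only; optional_weight_tensor and its module path are hypothetical, not in this PR) is to point the yaml at a small factory that tolerates None:

import torch

def optional_weight_tensor(data=None):
    """Return torch.tensor(data), or None when class_weights is unset."""
    return None if data is None else torch.tensor(data)

# crossentropy.yaml would then read (module path assumed):
#   weight:
#     _target_: utils.optional_weight_tensor
#     data: ${dataset.class_weights}

Since torch.nn.CrossEntropyLoss accepts weight=None, the None case would then instantiate cleanly.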
Anyways, I'm leaving it to you guys if you have ideas; I've spent enough time on this case. If we don't use CE loss very often, and especially if we rarely use class_weights, we could remove the class_weights parameter from the dataset config, and the yaml could look like this:
# @package _global_
loss:
  _target_: torch.nn.CrossEntropyLoss
  ignore_index: ${dataset.ignore_index}
  # uncomment if using class_weights
  # weight:
  #   _target_: torch.tensor
  #   data: [0.5, 0.5]
It's the only alternative I've found, but it is pretty hacky. In the long term, we could open an issue on torch and ask them to accept a plain list for the weight argument of the CE class.
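Another hedged alternative: with omegaconf >= 2.1, a custom resolver can do the None check so the yaml stays declarative (the resolver name as_tensor is made up here):

import torch
from omegaconf import OmegaConf

# registered once at startup, e.g. at the top of train_segmentation.py
OmegaConf.register_new_resolver(
    "as_tensor", lambda x: None if x is None else torch.tensor(list(x))
)
# crossentropy.yaml could then use:
#   weight: ${as_tensor:${dataset.class_weights}}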
train_segmentation.py
Outdated
if criterion.mode == 'binary':
    outputs = outputs[:, -1, ...]
@victorlazio109 did you run into this "dimensionality" problem when using binary losses? When testing, I noticed the binary losses only accept logits without the "background" channel. Therefore, the prediction array must be indexed as shown in this bit of ugly code, from [1, 2, 256, 256] (i.e. [batch size, num_classes, width, height]) to [1, 256, 256]. In other words, it becomes necessary to discard the "background" array and squeeze (in the numpy sense) the dimensions to match those of the labels [1, 256, 256].
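To illustrate the shapes involved (a standalone sketch, not the training code):

import torch

outputs = torch.randn(1, 2, 256, 256)  # [batch size, num_classes, width, height]
outputs = outputs[:, -1, ...]          # drop the "background" channel
assert outputs.shape == (1, 256, 256)  # now matches the labels' shape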
Resolved with my last PR; confirm if you can.
# Conflicts:
#   config/training/default_training.yaml
#   losses/__init__.py
#   models/model_choice.py
#   train_segmentation.py
bugfixes
config/loss/binary/dice.yaml
Outdated
loss:
  _target_: segmentation_models_pytorch.losses.DiceLoss
  mode: binary
  ignore_index: ${dataset.ignore_index}
I don't think these missing newlines will cause trouble. I can add them if anybody's nervous :)
Where is it passed if removed?
The only script removed is losses/__init__.py. It is replaced by a collection of yamls in the config/loss directory, one for each loss. Each yaml points to the class that needs to be instantiated by hydra (e.g. segmentation_models_pytorch.losses.DiceLoss). It's the "hydranic" way of doing things that will also appear once we tackle issue #246.
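For context, a minimal sketch of that flow (the ignore_index value here is a stand-in for ${dataset.ignore_index}):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# roughly what config/loss/binary/dice.yaml composes to
loss_cfg = OmegaConf.create({
    "_target_": "segmentation_models_pytorch.losses.DiceLoss",
    "mode": "binary",
    "ignore_index": 255,  # stand-in for ${dataset.ignore_index}
})
criterion = instantiate(loss_cfg)  # hydra builds the DiceLoss instance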
@@ -519,7 +519,7 @@ def train(cfg: DictConfig) -> None:

# MODEL PARAMETERS
class_weights = get_key_def('class_weights', cfg['dataset'], default=None)
loss_fn = get_key_def('loss_fn', cfg['training'], default='CrossEntropy')
loss_fn = cfg.loss
get_key_def can't be used as is, since the loss will later be instantiated with hydra; the loss parameter is no longer just a string that can default to "CrossEntropy".
remove weight parameter in losses that don't use it.
@@ -136,12 +135,11 @@ def set_hyperparameters(params,
gamma = get_key_def('gamma', params['scheduler']['params'], 0.9)
class_weights = torch.tensor(class_weights) if class_weights else None
# Loss function
if num_classes == 1:
    criterion = SingleClassCriterion(loss_type=loss_fn, ignore_index=dontcare_val)
if loss_fn['_target_'] in ['torch.nn.CrossEntropyLoss', 'losses.focal_loss.FocalLoss',
We lose the differentiator between SingleClass and MultiClass here: if a user passes num_classes == 1 while picking CE loss, they will almost certainly run into errors at some point in the code execution.
Good point. I can add a small check for this.
in train_segmentation.py: one liner for loss = criterion
if not loss_fn.is_binary == (num_classes == 1):
    raise ValueError(f"A binary loss was chosen for a multiclass task")
del loss_fn.is_binary  # prevent exception at instantiation
@victorlazio109 this is the simplest way I've found to make sure binary losses are not called on a multiclass problem, and vice versa.
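For context, a sketch of the convention that check assumes (the yaml layout carrying an extra is_binary key is inferred from the snippet above, not confirmed in this PR):

from hydra.utils import instantiate
from omegaconf import OmegaConf, open_dict

num_classes = 1
loss_fn = OmegaConf.create({
    "_target_": "segmentation_models_pytorch.losses.DiceLoss",
    "mode": "binary",
    "is_binary": True,  # metadata only, not a DiceLoss argument
})
if not loss_fn.is_binary == (num_classes == 1):
    raise ValueError("A binary loss was chosen for a multiclass task")
with open_dict(loss_fn):   # hydra-composed configs are struct; allow deletion
    del loss_fn.is_binary  # prevent exception at instantiation
criterion = instantiate(loss_fn)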
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR NRCan#256
* - remove unused functions
  - remove ruamel_yaml import from active scripts
  - fix dontcare2background related to PR #256
* remove data_analysis.py from GDL and move to internal gitlab.

* - remove unused functions
  - remove ruamel_yaml import from active scripts
  - fix dontcare2background related to PR #256
* - create set_device function: from dictionary of available devices, sets the device to be used
  - check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)
* remove try/except statement for old HPC bug (device ordinal error)

* - remove unused functions
  - remove ruamel_yaml import from active scripts
  - fix dontcare2background related to PR #256
* - create set_device function: from dictionary of available devices, sets the device to be used
  - check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)
* manage tracker initialization with set_tracker() function in utils.py, adapt get_key_def() to recursively check for parameter value in a dictionary of dictionaries

…#274)
* - remove unused functions
  - remove ruamel_yaml import from active scripts
  - fix dontcare2background related to PR #256
* - create set_device function: from dictionary of available devices, sets the device to be used
  - check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)
* manage tracker initialization with set_tracker() function in utils.py, adapt get_key_def() to recursively check for parameter value in a dictionary of dictionaries
* - use get_key_def() to validate path existence and to convert to a pathlib.Path object
  - remove error-handling with try2read_csv and in_case_of_path
  - use hydra's to_absolute_path utils (remove most calls to ${hydra:runtime.cwd} in yamls)
  - revert usage of paths to before PR #208 (remove error-handling, remove find_first_file(), set unique model directory at train)
  - replace warnings with logging.warning
  - replace assert with raise

* - remove unused functions
  - remove ruamel_yaml import from active scripts
  - fix dontcare2background related to PR #256
* - create set_device function: from dictionary of available devices, sets the device to be used
  - check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)
* manage tracker initialization with set_tracker() function in utils.py, adapt get_key_def() to recursively check for parameter value in a dictionary of dictionaries
* - use get_key_def() to validate path existence and to convert to a pathlib.Path object
  - remove error-handling with try2read_csv and in_case_of_path
  - use hydra's to_absolute_path utils (remove most calls to ${hydra:runtime.cwd} in yamls)
  - revert usage of paths to before PR #208 (remove error-handling, remove find_first_file(), set unique model directory at train)
  - replace warnings with logging.warning
  - replace assert with raise
* - verifications.py: validate_raster() -> add extended check; move input_band_count == num_bands assertion to separate function
  - refactor segmentation() function
  - refactor gen_img_sample() function
  - use itertools.product in evaluate_segmentation
  - inference: refactor num_devices, default_max_ram_used
  - default_inference.yaml: update parameters with current usage
* softcode max_pix_per_mb_gpu and default to 25 in default_inference.yaml
* use hydra to instantiate loss
* let hydra manage dontcare value in most places; remove weight parameter in losses that don't use it.
* add check between binary/multiclass loss and num_classes; in train_segmentation.py: one liner for loss = criterion