
use hydra to instantiate loss #256

Merged: 8 commits into NRCan:develop, Feb 7, 2022
Conversation

@remtav remtav commented Feb 1, 2022

No description provided.

@remtav remtav marked this pull request as draft February 2, 2022 13:42
@@ -0,0 +1,4 @@
# @package _global_
remtav (Collaborator Author):
Should we name this file ohemce.yaml or keep the full name?

Collaborator:

I agree, the shorter name is better.

@@ -8,7 +8,6 @@ training:
max_epochs: ${general.max_epochs}
min_epochs: ${general.min_epochs}
num_workers:
loss_fn: CrossEntropy # name of the metric you want to manage the training
remtav (Collaborator Author):

Ideally, the loss parameter should stay in the training config, but I couldn't point this parameter to the one in config/loss/multiclass using hydra's ${...} syntax. I did try quite a bit. I'd suggest moving the loss parameter out of training if we can't find a way to make this work.
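For reference, the way Hydra usually wires this kind of choice is through the defaults list rather than a ${...} interpolation; a minimal sketch (group and option names are assumed, not taken from this PR):

```yaml
# @package _global_
# Hypothetical layout: select a loss config group via the defaults list
# instead of a training-level string. Hydra then composes config/loss/crossentropy.yaml
# into the final config under the `loss` key.
defaults:
  - loss: crossentropy
```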

Comment on lines 141 to 142
if loss_fn['_target_'] == 'torch.nn.CrossEntropyLoss':
criterion = instantiate(loss_fn, weight=class_weights) # FIXME: unable to pass this through hydra
remtav (Collaborator Author), Feb 2, 2022:

I really don't like this. CrossEntropy loss (multiclass) is the only loss that uses class_weights as a parameter. We had adapted all other losses to accept it but not use it. The binary losses that @victorlazio109 added from smp don't accept the "weight" parameter and therefore raise an exception if it is passed.

Logically, class_weights would have been in the crossentropy.yaml config directly (and not in model_choice). It could therefore be dealt with directly by hydra, but one thing makes it a bit tricky:
Cross entropy in torch.nn only accepts a tensor. If it could handle a list, this would be much more convenient for hydra.

The only way hydra can directly input a tensor to the weight argument is by changing the CE yaml to something like this:

# @package _global_
loss:
    _target_: torch.nn.CrossEntropyLoss
    ignore_index: ${dataset.ignore_index}
    weight:
        _target_: torch.tensor
        data: ${dataset.class_weights}

This works fine when class_weights is not None (ex.: [0.5, 0.5] successfully gets converted to a tensor), but when it's None, torch.tensor(None) raises an exception. Handling that would require some kind of if/else logic in hydra for the case where class_weights is None.

Anyways, I'm leaving it to you guys if you have ideas; I've spent enough time on this case. If we don't use CE loss very often, and especially if we don't use class_weights much, we could remove the "class_weights" parameter from the dataset config, and the yaml could look like this:

# @package _global_
loss:
    _target_: torch.nn.CrossEntropyLoss
    ignore_index: ${dataset.ignore_index}
    # uncomment if using class_weights
    # weight:
    #     _target_: torch.tensor
    #     data: [0.5, 0.5]

It's the only alternative I've found, but it is pretty hacky. In the long term, we could open an issue on torch and ask them to accept a list for the weight argument of the CE class.
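As a stopgap outside of Hydra, the None case can be handled in plain Python before instantiation; a minimal sketch, with plain dicts standing in for the DictConfig and the instantiate() call (function name hypothetical):

```python
def build_loss_call(loss_cfg: dict, class_weights):
    """Split a loss config into (_target_, kwargs), adding `weight` only
    when class_weights is set, so torch.tensor(None) is never constructed."""
    kwargs = dict(loss_cfg)
    target = kwargs.pop("_target_")
    # Only CrossEntropyLoss uses class weights; skip the key entirely when None.
    if target == "torch.nn.CrossEntropyLoss" and class_weights is not None:
        # In the real code this would be wrapped as torch.tensor(class_weights).
        kwargs["weight"] = class_weights
    return target, kwargs
```

instantiate(loss_cfg, **kwargs) would then only ever receive a usable weight value.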

Comment on lines 330 to 331
if criterion.mode == 'binary':
outputs = outputs[:, -1, ...]
remtav (Collaborator Author), Feb 2, 2022:

@victorlazio109 did you run into this "dimensionality" problem when using binary losses? When testing, I noticed the binary losses only accept logits without the "background" channel. Therefore, the prediction array must be indexed as shown in this bit of ugly code, from [1, 2, 256, 256] ([batch size, num_classes, width, height]) to [1, 256, 256]. In other words, it becomes necessary to discard the "background" array and squeeze (in the numpy sense) the dimensions to match those of the labels [1, 256, 256].
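The indexing above can be illustrated in pure Python, with nested lists standing in for tensors, just to show the shape change from [batch, num_classes, H, W] to [batch, H, W] (function name hypothetical):

```python
def take_last_channel(outputs):
    """Pure-python stand-in for outputs[:, -1, ...]: keep only the last
    class channel of each sample, dropping the background channel."""
    return [sample[-1] for sample in outputs]

# A [1, 2, 2, 2] batch of logits...
batch = [[[[0.1, 0.2], [0.3, 0.4]],    # background channel
          [[0.9, 0.8], [0.7, 0.6]]]]   # foreground channel
# ...becomes [1, 2, 2]: only the foreground channel survives,
# matching the [batch, H, W] shape of the labels.
```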

Collaborator:

resolved with my last PR, confirm if you can

# Conflicts:
#	config/training/default_training.yaml
#	losses/__init__.py
#	models/model_choice.py
#	train_segmentation.py
@remtav remtav marked this pull request as ready for review February 3, 2022 21:55
loss:
_target_: segmentation_models_pytorch.losses.DiceLoss
mode: binary
ignore_index: ${dataset.ignore_index}
remtav (Collaborator Author):

I don't think these missing newlines will cause trouble. I can add them if anybody's nervous :)

Collaborator:

Where is it passed if removed?

remtav (Collaborator Author), Feb 7, 2022:

The only script that is removed is losses/__init__.py. It is replaced by a collection of yamls in the config/loss directory, one for each loss. Each yaml points to the class that needs to be instantiated by Hydra (ex.: segmentation_models_pytorch.losses.DiceLoss). It's the "hydranic" way of things that will also appear once we tackle issue #246.
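Roughly, what hydra.utils.instantiate does with such a yaml can be sketched with importlib (a simplified stand-in, not Hydra's actual implementation, which also handles recursion and positional args):

```python
import importlib

def instantiate(cfg: dict, **overrides):
    """Simplified sketch of hydra.utils.instantiate: import the dotted
    _target_ path and call it with the remaining keys as kwargs."""
    cfg = {**cfg, **overrides}
    module_path, attr = cfg.pop("_target_").rsplit(".", 1)
    target = getattr(importlib.import_module(module_path), attr)
    return target(**cfg)
```

With the DiceLoss yaml above, this amounts to calling segmentation_models_pytorch.losses.DiceLoss(mode="binary", ignore_index=...).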

@@ -519,7 +519,7 @@ def train(cfg: DictConfig) -> None:

# MODEL PARAMETERS
class_weights = get_key_def('class_weights', cfg['dataset'], default=None)
loss_fn = get_key_def('loss_fn', cfg['training'], default='CrossEntropy')
loss_fn = cfg.loss
remtav (Collaborator Author):

get_key_def() can't be used as is, since the loss will later be instantiated with hydra; the loss parameter is no longer just a string that can default to "CrossEntropy".

CharlesAuthier
CharlesAuthier previously approved these changes Feb 4, 2022
@@ -136,12 +135,11 @@ def set_hyperparameters(params,
gamma = get_key_def('gamma', params['scheduler']['params'], 0.9)
class_weights = torch.tensor(class_weights) if class_weights else None
# Loss function
if num_classes == 1:
criterion = SingleClassCriterion(loss_type=loss_fn, ignore_index=dontcare_val)
if loss_fn['_target_'] in ['torch.nn.CrossEntropyLoss', 'losses.focal_loss.FocalLoss',
Collaborator:

We lose the differentiator between SingleClass and MultiClass here; if a user were to pass num_classes == 1 while picking CE loss, they will almost certainly run into errors at some point in the code execution.

remtav (Collaborator Author):

Good point. I can add a small check for this.

Comment on lines +518 to +520
if not loss_fn.is_binary == (num_classes == 1):
raise ValueError(f"A binary loss was chosen for a multiclass task")
del loss_fn.is_binary # prevent exception at instantiation
remtav (Collaborator Author):

@victorlazio109 this is the simplest way I've found to make sure binary losses are not called in a multiclass problem and vice versa.
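The guard above can be sketched on a plain dict (in the PR it operates on the Hydra DictConfig, where `del` keeps `is_binary` from reaching the loss constructor; function name hypothetical):

```python
def check_loss_mode(loss_cfg: dict, num_classes: int) -> dict:
    """Validate that a binary loss is only paired with a single-class task,
    then drop the helper flag so instantiate() never sees it as a kwarg."""
    is_binary = loss_cfg.pop("is_binary")
    if is_binary != (num_classes == 1):
        raise ValueError("A binary loss was chosen for a multiclass task")
    return loss_cfg
```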

@remtav remtav merged commit 31d429c into NRCan:develop Feb 7, 2022
@remtav remtav deleted the 255-loss-hydra branch February 7, 2022 18:25
remtav added a commit to remtav/geo-deep-learning that referenced this pull request Feb 9, 2022
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR NRCan#256
This was referenced Feb 9, 2022
remtav added a commit that referenced this pull request Feb 15, 2022
* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR #256

* remove data_analysis.py from GDL and move to internal gitlab.
remtav added a commit that referenced this pull request Feb 15, 2022
* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR #256

* - create set_device function: from dictionary of available devices, sets the device to be used
- check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)

* remove try/except statement for old HPC bug (device ordinal error)
remtav added a commit that referenced this pull request Feb 15, 2022
* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR #256

* - create set_device function: from dictionary of available devices, sets the device to be used
- check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)

* manage tracker initialization with set_tracker() function
in utils.py, adapt get_key_def() to recursively check for parameter value in dictionary of dictionary
remtav added a commit that referenced this pull request Feb 16, 2022
…#274)

* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR #256

* - create set_device function: from dictionary of available devices, sets the device to be used
- check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)

* manage tracker initialization with set_tracker() function
in utils.py, adapt get_key_def() to recursively check for parameter value in dictionary of dictionary

* - use get_key_def() to validate path existence and to convert to a pathlib.Path object
- remove error-handling with try2read_csv and in_case_of_path
- use hydra's to_absolute_path utils (remove most calls to ${hydra:runtime.cwd} in yamls)
- revert usage of paths to before PR #208 (remove error-handling, remove find_first_file(), set unique model directory at train)
- replace warnings with logging.warning
- replace assert with raise
remtav added a commit that referenced this pull request Feb 18, 2022
* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR #256

* - create set_device function: from dictionary of available devices, sets the device to be used
- check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)

* manage tracker initialization with set_tracker() function
in utils.py, adapt get_key_def() to recursively check for parameter value in dictionary of dictionary

* - use get_key_def() to validate path existence and to convert to a pathlib.Path object
- remove error-handling with try2read_csv and in_case_of_path
- use hydra's to_absolute_path utils (remove most calls to ${hydra:runtime.cwd} in yamls)
- revert usage of paths to before PR #208 (remove error-handling, remove find_first_file(), set unique model directory at train)
- replace warnings with logging.warning
- replace assert with raise

* - verifications.py: validate_raster() -> add extended check; move input_band_count == num_bands assertion to separate function
- refactor segmentation() function
- refactor gen_img_sample() function
- use itertools.product in evaluate_segmentation
- inference: refactor num_devices,default_max_ram_used
- default_inference.yaml: update parameters with current usage

* softcode max_pix_per_mb_gpu
and default to 25 in default_inference.yaml
remtav added a commit to remtav/geo-deep-learning that referenced this pull request Jul 5, 2022
* use hydra to instantiate loss

* let hydra manage dontcare value in most places
remove weight parameter in losses that don't use it.

* add check between binary/multiclass loss and num_classes
in train_segmentation.py: one liner for loss = criterion
remtav added a commit to remtav/geo-deep-learning that referenced this pull request Jul 5, 2022
* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR NRCan#256

* remove data_analysis.py from GDL and move to internal gitlab.
remtav added a commit to remtav/geo-deep-learning that referenced this pull request Jul 5, 2022
* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR NRCan#256

* - create set_device function: from dictionary of available devices, sets the device to be used
- check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)

* remove try/except statement for old HPC bug (device ordinal error)
remtav added a commit to remtav/geo-deep-learning that referenced this pull request Jul 5, 2022
* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR NRCan#256

* - create set_device function: from dictionary of available devices, sets the device to be used
- check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)

* manage tracker initialization with set_tracker() function
in utils.py, adapt get_key_def() to recursively check for parameter value in dictionary of dictionary
remtav added a commit to remtav/geo-deep-learning that referenced this pull request Jul 5, 2022
…n#208 (NRCan#274)

* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR NRCan#256

* - create set_device function: from dictionary of available devices, sets the device to be used
- check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)

* manage tracker initialization with set_tracker() function
in utils.py, adapt get_key_def() to recursively check for parameter value in dictionary of dictionary

* - use get_key_def() to validate path existence and to convert to a pathlib.Path object
- remove error-handling with try2read_csv and in_case_of_path
- use hydra's to_absolute_path utils (remove most calls to ${hydra:runtime.cwd} in yamls)
- revert usage of paths to before PR NRCan#208 (remove error-handling, remove find_first_file(), set unique model directory at train)
- replace warnings with logging.warning
- replace assert with raise
remtav added a commit to remtav/geo-deep-learning that referenced this pull request Jul 5, 2022
…an#276)

* - remove unused functions
- remove ruamel_yaml import from active scripts
- fix dontcare2background related to PR NRCan#256

* - create set_device function: from dictionary of available devices, sets the device to be used
- check if model can be pushed to device, else catch exception and try with cuda, not cuda:0 (HPC bug)

* manage tracker initialization with set_tracker() function
in utils.py, adapt get_key_def() to recursively check for parameter value in dictionary of dictionary

* - use get_key_def() to validate path existence and to convert to a pathlib.Path object
- remove error-handling with try2read_csv and in_case_of_path
- use hydra's to_absolute_path utils (remove most calls to ${hydra:runtime.cwd} in yamls)
- revert usage of paths to before PR NRCan#208 (remove error-handling, remove find_first_file(), set unique model directory at train)
- replace warnings with logging.warning
- replace assert with raise

* - verifications.py: validate_raster() -> add extended check; move input_band_count == num_bands assertion to separate function
- refactor segmentation() function
- refactor gen_img_sample() function
- use itertools.product in evaluate_segmentation
- inference: refactor num_devices,default_max_ram_used
- default_inference.yaml: update parameters with current usage

* softcode max_pix_per_mb_gpu
and default to 25 in default_inference.yaml