Model_choice.py: refactor model, loss and optimizer instantiation and loading #292

Closed
remtav opened this issue Mar 21, 2022 · 0 comments · Fixed by #294

remtav commented Mar 21, 2022

Model_choice.py has gotten messy and cluttered over the years and needs a bit of refactoring. This refactoring should be done before addressing #246 and #152.

Current state of things

  • set_hyperparameters() has a very vague purpose: "set[ting] hyperparameters based on values provided in yaml config file";
  • net() is supposed to "Define the neural net", but in reality it is a vague, all-in-one function that does the following:
      1. Defines the net architecture;
      2. Reads a checkpoint into memory with load_checkpoint() (from a .pth.tar file as created by torch.save);
      3. Returns if net() is called from inference; if it is called in train mode, continues with the following steps:
      4. If more than one GPU is requested, determines which GPUs are available based on user-provided thresholds for each GPU's available RAM and usage %;
      5. Sets the model to DataParallel if more than one GPU is requested and available;
      6. Sets the main device with the set_device() function;
      7. Pushes the model to the main device;
      8. Calls set_hyperparameters() (see above);
      9. Pushes the loss to the device;
      10. Returns 7 (!!) objects: model, model_name, loss, etc.

Suggested solution (high level)

All these steps could be better separated in small, dedicated functions of their own:

  • read_checkpoint(): renamed version of load_checkpoint() (prevents confusion with the load_state_dict function). Although it builds on torch.load(checkpoint), this function really just reads a checkpoint from a .pth.tar file into memory as a Python dict containing weights, optimizer, etc.
  • define_net_architecture(): defines the model architecture from config parameters (i.e. creates a model with randomly initialized weights)
  • adapt_checkpoint_to_dp_model(): for use in the test loop during training only; adapts a generic checkpoint so it can be loaded to a DataParallel model, as is done in load_from_checkpoint (if the model is a DataParallel object)
  • define_loss(): calls verify_weights() and instantiates a loss criterion
  • define_optimizer(): instantiates optimizer with learning rate, weight decay, etc.
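
To make the checkpoint-related helpers more concrete, here is a minimal sketch of the dict handling they could perform. It is an illustration only, not the implementation in #294. Assumptions: normalize_checkpoint() is a hypothetical helper that read_checkpoint() might call after torch.load(); adapt_checkpoint_to_dp_model() takes a plain boolean here (rather than the model object) so the sketch runs without torch; a checkpoint with neither 'model' nor 'model_state_dict' is assumed to be a bare state_dict.

```python
from collections import OrderedDict
from typing import Any, Dict, Mapping


def normalize_checkpoint(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a copy of the checkpoint that uses torch-style key names.

    Three layouts are handled:
      1. already standard: 'model_state_dict' (and maybe 'optimizer_state_dict');
      2. legacy gdl keys: 'model' and 'optimizer';
      3. bare weights only (assumption: the mapping itself is the state_dict).
    """
    if "model_state_dict" in checkpoint:
        return dict(checkpoint)
    if "model" in checkpoint:
        # Keep every other entry (epoch, params, ...) and rename the two legacy keys.
        normalized = {k: v for k, v in checkpoint.items() if k not in ("model", "optimizer")}
        normalized["model_state_dict"] = checkpoint["model"]
        if "optimizer" in checkpoint:
            normalized["optimizer_state_dict"] = checkpoint["optimizer"]
        return normalized
    # Neither key is present: treat the whole mapping as model weights.
    return {"model_state_dict": checkpoint}


def adapt_checkpoint_to_dp_model(checkpoint: Mapping[str, Any], is_data_parallel: bool) -> Dict[str, Any]:
    """Prefix weight names with 'module.' when the target model is wrapped in nn.DataParallel.

    nn.DataParallel stores the wrapped model under the attribute 'module',
    so its state_dict keys all start with 'module.'.
    """
    adapted = dict(checkpoint)
    if not is_data_parallel:
        return adapted
    adapted["model_state_dict"] = OrderedDict(
        (key if key.startswith("module.") else f"module.{key}", value)
        for key, value in checkpoint["model_state_dict"].items()
    )
    return adapted
```

In the test loop, the result of adapt_checkpoint_to_dp_model() would then be passed to the model's load_state_dict() method. Passing a boolean keeps the helper free of any torch dependency; the real function could instead check isinstance(model, torch.nn.DataParallel).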

These functions would be called only when necessary in 3 main places:

  1. Beginning of train_segmentation:
  2. Test loop in train_segmentation:
      • Load the best checkpoint to the model (adapt the checkpoint keys with the dedicated function if the model is an nn.DataParallel instance);
  3. Beginning of inference:
      • Override architecture, input bands and output classes from the checkpoint's params;
      • Define the net architecture;
      • Load weights from the provided checkpoint to the model using PyTorch's [model_object].load_state_dict() method.
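
The first inference step (overriding config values from the checkpoint) could look roughly like the sketch below. The key names 'architecture', 'input_bands', 'output_classes' and the 'params' entry of the checkpoint are assumptions made for illustration; the actual config and checkpoint layout may differ, and override_model_params() is a hypothetical name.

```python
from typing import Any, Dict, Mapping

# Hypothetical key names: only these entries are taken from the checkpoint.
OVERRIDABLE_KEYS = ("architecture", "input_bands", "output_classes")


def override_model_params(config: Mapping[str, Any], checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a copy of config where model-defining values come from the checkpoint.

    Values missing from the checkpoint's params are left as they are in config,
    and the input config is not modified.
    """
    merged = dict(config)
    saved_params = checkpoint.get("params", {})
    for key in OVERRIDABLE_KEYS:
        if key in saved_params:
            merged[key] = saved_params[key]
    return merged
```

The merged config would then be handed to define_net_architecture(), after which the weights are loaded with the model's load_state_dict() method. Overriding first ensures the architecture being built matches the one the weights were trained with.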
remtav added a commit to remtav/geo-deep-learning that referenced this issue Mar 21, 2022
adapt train_segmentation.py and inference_segmentation.py to new usage
move state_dict_path param to default_training.yaml
implement unit tests for model_choice.py functions
read_checkpoint(): add robustness (covers external checkpoints with only model weights, and complies to torch's save key naming standard 'model_state_dict' and 'optimizer_state_dict' rather than gdl's 'model' and 'optimizer' keys
create high level define_model() function using all low level models definition/loading functions from model_choice.py
test_losses.py: implement class weights test
softcode strict loading boolean for loading provided state_dict at train_segmentation.py
remtav added a commit that referenced this issue Mar 22, 2022
* refactor model_choice.py using solution in issue #292
adapt train_segmentation.py and inference_segmentation.py to new usage
move state_dict_path param to default_training.yaml
implement unit tests for model_choice.py functions
read_checkpoint(): add robustness (covers external checkpoints with only model weights, and complies to torch's save key naming standard 'model_state_dict' and 'optimizer_state_dict' rather than gdl's 'model' and 'optimizer' keys
create high level define_model() function using all low level models definition/loading functions from model_choice.py
test_losses.py: implement class weights test
softcode strict loading boolean for loading provided state_dict at train_segmentation.py

* bugfix for github actions

* more bugfixes for github actions

* train_segmentation.py: bugfix --> read weights under 'model_state_dict' key
remtav added a commit to remtav/geo-deep-learning that referenced this issue Jul 5, 2022
…RCan#294)

* refactor model_choice.py using solution in issue NRCan#292
adapt train_segmentation.py and inference_segmentation.py to new usage
move state_dict_path param to default_training.yaml
implement unit tests for model_choice.py functions
read_checkpoint(): add robustness (covers external checkpoints with only model weights, and complies to torch's save key naming standard 'model_state_dict' and 'optimizer_state_dict' rather than gdl's 'model' and 'optimizer' keys
create high level define_model() function using all low level models definition/loading functions from model_choice.py
test_losses.py: implement class weights test
softcode strict loading boolean for loading provided state_dict at train_segmentation.py

* bugfix for github actions

* more bugfixes for github actions

* train_segmentation.py: bugfix --> read weights under 'model_state_dict' key