Model_choice.py: refactor model, loss and optimizer instantiation and loading #292
remtav added a commit to remtav/geo-deep-learning that referenced this issue on Mar 21, 2022:
- adapt train_segmentation.py and inference_segmentation.py to new usage
- move state_dict_path param to default_training.yaml
- implement unit tests for model_choice.py functions
- read_checkpoint(): add robustness (covers external checkpoints with only model weights, and complies with torch's save key naming standard 'model_state_dict' and 'optimizer_state_dict' rather than gdl's 'model' and 'optimizer' keys)
- create high level define_model() function using all low level model definition/loading functions from model_choice.py
- test_losses.py: implement class weights test
- softcode strict loading boolean for loading provided state_dict at train_segmentation.py
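The read_checkpoint() robustness item can be illustrated with a minimal sketch of the key normalization step. It assumes the checkpoint has already been loaded into a plain mapping (e.g. with `torch.load(path, map_location="cpu")`). The helper name `normalize_checkpoint` is hypothetical and this is not gdl's actual read_checkpoint(); it only shows how the three layouts described in the commit message could be mapped onto torch's 'model_state_dict' / 'optimizer_state_dict' naming.

```python
from typing import Any, Dict, Mapping, Optional


def normalize_checkpoint(
    checkpoint: Mapping[str, Any],
) -> Dict[str, Optional[Mapping[str, Any]]]:
    """Return a dict with 'model_state_dict' and 'optimizer_state_dict' keys.

    Hypothetical helper handling three checkpoint layouts:
    1. torch-style keys: 'model_state_dict' / 'optimizer_state_dict'
    2. legacy gdl keys: 'model' / 'optimizer'
    3. a bare state_dict holding only model weights (e.g. an external checkpoint)
    """
    if "model_state_dict" in checkpoint:
        # Already follows torch's save naming convention.
        return {
            "model_state_dict": checkpoint["model_state_dict"],
            "optimizer_state_dict": checkpoint.get("optimizer_state_dict"),
        }
    if "model" in checkpoint:
        # Legacy gdl layout: rename keys to the torch convention.
        return {
            "model_state_dict": checkpoint["model"],
            "optimizer_state_dict": checkpoint.get("optimizer"),
        }
    # No known wrapper key: assume the whole mapping is the model's weights.
    # (Parameter names in a state_dict are dotted, e.g. 'encoder.weight',
    # so a bare top-level 'model' key is not expected here.)
    return {"model_state_dict": dict(checkpoint), "optimizer_state_dict": None}
```

Keeping file I/O out of the helper means the key handling can be unit tested with plain dicts, without writing checkpoint files.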
remtav added a commit that referenced this issue on Mar 22, 2022:
* refactor model_choice.py using solution in issue #292
  - adapt train_segmentation.py and inference_segmentation.py to new usage
  - move state_dict_path param to default_training.yaml
  - implement unit tests for model_choice.py functions
  - read_checkpoint(): add robustness (covers external checkpoints with only model weights, and complies with torch's save key naming standard 'model_state_dict' and 'optimizer_state_dict' rather than gdl's 'model' and 'optimizer' keys)
  - create high level define_model() function using all low level model definition/loading functions from model_choice.py
  - test_losses.py: implement class weights test
  - softcode strict loading boolean for loading provided state_dict at train_segmentation.py
* bugfix for GitHub Actions
* more bugfixes for GitHub Actions
* train_segmentation.py: bugfix --> read weights under 'model_state_dict' key
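The high level define_model() and the softcoded strict-loading flag could fit together roughly as follows. This is a sketch under stated assumptions: the registry, `define_model_architecture()` and the keyword arguments are hypothetical, and the model object is only assumed to expose a `load_state_dict(state_dict, strict=...)` method, as `torch.nn.Module` does.

```python
from typing import Any, Callable, Mapping, Optional

# Hypothetical type for the low-level model constructors a registry maps to.
ModelFactory = Callable[..., Any]


def define_model_architecture(
    name: str, registry: Mapping[str, ModelFactory], **kwargs: Any
) -> Any:
    """Hypothetical low-level helper: instantiate a model by name."""
    try:
        factory = registry[name]
    except KeyError:
        raise ValueError(f"Unknown model name: {name!r}") from None
    return factory(**kwargs)


def define_model(
    name: str,
    registry: Mapping[str, ModelFactory],
    state_dict: Optional[Mapping[str, Any]] = None,
    strict: bool = True,
    **kwargs: Any,
) -> Any:
    """High-level entry point: build the model, then optionally load weights.

    `strict` is forwarded to the model's load_state_dict(), mirroring the
    `strict` argument of torch.nn.Module.load_state_dict, so it can come
    from config instead of being hard-coded.
    """
    model = define_model_architecture(name, registry, **kwargs)
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=strict)
    return model
```

In train_segmentation.py, `state_dict` would typically come from read_checkpoint() and `strict` from the training config, so that a checkpoint with missing or extra keys can be accepted (with `strict=False`) without code changes.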
remtav added a commit to remtav/geo-deep-learning that referenced this issue on Jul 5, 2022:
…RCan#294): same commit message as the Mar 22, 2022 commit above.
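The class weights test in test_losses.py, listed in the commit messages, could for example compare a loss against a hand-computed reference. The sketch below is a pure-Python reference for weighted cross-entropy with mean reduction, following the convention documented for `torch.nn.CrossEntropyLoss(weight=...)`: each sample's loss is scaled by the weight of its target class, and the sum is divided by the sum of those weights. The function name is hypothetical and this is not gdl's actual test.

```python
import math
from typing import Sequence


def weighted_cross_entropy(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    weights: Sequence[float],
) -> float:
    """Hypothetical reference: sum_n w[y_n] * nll_n / sum_n w[y_n]."""
    num = 0.0
    den = 0.0
    for row, y in zip(logits, targets):
        # Numerically stable log-sum-exp for the log-softmax denominator.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_sum_exp - row[y]  # negative log-likelihood of the target class
        num += weights[y] * nll
        den += weights[y]
    return num / den
```

Two properties such a test can check: uniform weights of any magnitude reproduce the unweighted mean, and increasing the weight of a class with high per-sample loss increases the overall loss.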
model_choice.py has become messy and cluttered over the years and needs some refactoring, which should be done before addressing #246 and #152.
Current state of things
Suggested solution (high level)
All these steps could be better separated into small, dedicated functions of their own:
These functions would be called only when necessary in 3 main places: