
checkpoint_dir is a required argument #120

Merged 1 commit on Jan 2, 2025

Conversation

sai8951 (Contributor) commented Dec 19, 2024

In the __init__() of class UNetTrainer, checkpoint_dir is an explicit positional parameter, so it can never appear in **kwargs.

def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, checkpoint_dir,
max_num_epochs, max_num_iterations, validate_after_iters=200, log_after_iters=100, validate_iters=None,
num_iterations=1, num_epoch=0, eval_score_higher_is_better=True, tensorboard_formatter=None,
skip_train_validation=False, resume=None, pre_trained=None, **kwargs):

As a result, the if block at L145 always executes:

if 'checkpoint_dir' not in kwargs:
    self.checkpoint_dir = os.path.split(pre_trained)[0]

This unintentionally overrides the checkpoint_dir passed by the caller, so checkpoints are written into the pre-trained model's directory and can overwrite its files. The condition should be if not self.checkpoint_dir:, or the block should be removed entirely.
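A minimal sketch of why the check always fires (class and paths simplified and hypothetical, not the actual UNetTrainer code): since checkpoint_dir is bound as a named parameter, Python never places it in **kwargs, so the membership test is always True.

```python
import os


class BuggyTrainer:
    def __init__(self, checkpoint_dir, pre_trained=None, **kwargs):
        self.checkpoint_dir = checkpoint_dir
        if pre_trained is not None:
            # 'checkpoint_dir' is an explicit parameter, so it can never
            # appear in kwargs: this branch is taken on every call.
            if 'checkpoint_dir' not in kwargs:
                self.checkpoint_dir = os.path.split(pre_trained)[0]


class FixedTrainer:
    def __init__(self, checkpoint_dir, pre_trained=None, **kwargs):
        self.checkpoint_dir = checkpoint_dir
        if pre_trained is not None:
            # Proposed fix: fall back to the pre-trained model's directory
            # only when no checkpoint directory was supplied.
            if not self.checkpoint_dir:
                self.checkpoint_dir = os.path.split(pre_trained)[0]


buggy = BuggyTrainer('my_checkpoints', pre_trained='pretrained/best.pt')
fixed = FixedTrainer('my_checkpoints', pre_trained='pretrained/best.pt')
print(buggy.checkpoint_dir)  # 'pretrained' — caller's directory silently lost
print(fixed.checkpoint_dir)  # 'my_checkpoints' — caller's directory respected
```

With the fix, passing an empty checkpoint_dir still falls back to the pre-trained model's directory, while an explicit directory is honored.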

@sai8951 sai8951 marked this pull request as ready for review December 19, 2024 08:04
wolny (Owner) commented Jan 2, 2025

Hi @sai8951, thanks for the PR. You're absolutely right!

@wolny wolny merged commit 3d67c4e into wolny:master Jan 2, 2025
1 check failed