Skip to content

Commit

Permalink
Merge pull request #7 from SimonMossmyr/patch-1
Browse files Browse the repository at this point in the history
Add path argument to EarlyStopping init
  • Loading branch information
Bjarten authored Jun 4, 2020
2 parents 7d8a086 + 321aa0d commit 7ec86aa
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pytorchtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, patience=7, verbose=False, delta=0):
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
Expand All @@ -12,6 +12,8 @@ def __init__(self, patience=7, verbose=False, delta=0):
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
path (str): Path for the checkpoint to be saved to.
Default: 'checkpoint.pt'
"""
self.patience = patience
self.verbose = verbose
Expand All @@ -20,6 +22,7 @@ def __init__(self, patience=7, verbose=False, delta=0):
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
self.path = path

def __call__(self, val_loss, model):

Expand All @@ -42,5 +45,5 @@ def save_checkpoint(self, val_loss, model):
'''Saves model when validation loss decrease.'''
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), 'checkpoint.pt')
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss

0 comments on commit 7ec86aa

Please sign in to comment.