Closed
Labels: bug (Something isn't working), help wanted (Open to be worked on)
Description
🐛 Bug
I wrote a custom callback for checkpoint saving. It saves a checkpoint every n training steps but keeps only the top_k checkpoints with the lowest train_loss.
I use self.history to track each saved checkpoint path and its corresponding train loss. However, the rank_0 process cannot find the checkpoint file when it calls self._del_model(ckpt_path). It appears that the checkpoint actually saved may not be the one recorded in rank_0's self.history.
How can I find the path of the checkpoint that was actually saved? Any reply is highly appreciated.
Please reproduce using the BoringModel
To Reproduce
Use following BoringModel and post here
import logging
import os

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only

log = logging.getLogger(__name__)


class SaveCallback(Callback):
    def __init__(self, save_path, save_steps=1000, save_top_k=0):
        super().__init__()
        self.save_path = save_path
        self.save_steps = save_steps
        self.save_top_k = save_top_k
        # Each entry is [epoch, step, loss, ckpt_path].
        self.history = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if pl_module.global_step != 0 and pl_module.global_step % self.save_steps == 0:
            epoch = pl_module.current_epoch
            step = pl_module.global_step
            loss = trainer.callback_metrics["train_loss"].detach().item()
            if self.save_top_k and len(self.history) == self.save_top_k:
                # History is full: replace the worst checkpoint if the new loss is lower.
                self.history.sort(key=lambda x: x[2])
                last_max_loss = self.history[-1][2]
                if last_max_loss > loss:
                    # Pop the entry so the history never exceeds save_top_k.
                    path_to_remove = self.history.pop()[-1]
                    self._del_model(path_to_remove)
                    ckpt_name = f'epoch-{epoch}--step{step}--train_loss-{loss:.2f}' + '.ckpt'
                    trainer.save_checkpoint(self.save_path + ckpt_name)
                    self.history.append([epoch, step, loss, self.save_path + ckpt_name])
            else:
                ckpt_name = f'epoch-{epoch}--step{step}--train_loss-{loss:.2f}' + '.ckpt'
                trainer.save_checkpoint(self.save_path + ckpt_name)
                self.history.append([epoch, step, loss, self.save_path + ckpt_name])

    @rank_zero_only
    def _del_model(self, path):
        # Only the rank 0 process deletes files.
        if os.path.exists(path):
            os.remove(path)
            log.debug(f'removed checkpoint: {path}.')
My lightning version is 1.2.6.
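One possible explanation for the mismatch, stated here as an assumption rather than a confirmed cause: if train_loss is logged without cross-process synchronization, each process may read a different value from trainer.callback_metrics, so each process builds a different file name for the same step. The sketch below illustrates this without Lightning. The helper names (ckpt_name, names_per_rank, names_with_shared_loss) and the loss values are hypothetical.

```python
# Hypothetical illustration only: no Lightning involved, all names are made up.

def ckpt_name(epoch: int, step: int, loss: float) -> str:
    """Build a checkpoint file name the same way the callback does."""
    return f"epoch-{epoch}--step{step}--train_loss-{loss:.2f}.ckpt"


def names_per_rank(epoch: int, step: int, local_losses: list) -> list:
    """Each process formats the name from its own local loss."""
    return [ckpt_name(epoch, step, loss) for loss in local_losses]


def names_with_shared_loss(epoch: int, step: int, local_losses: list) -> list:
    """Each process formats the name from one shared value (here: the mean)."""
    shared = sum(local_losses) / len(local_losses)
    return [ckpt_name(epoch, step, shared) for _ in local_losses]


# Assumed per-process losses for the same training step (rank 0, rank 1).
local_losses = [0.52, 0.44]

per_rank = names_per_rank(3, 1000, local_losses)
shared = names_with_shared_loss(3, 1000, local_losses)

print(per_rank)  # the two names differ because the losses differ
print(shared)    # the two names are identical
```

If this assumption holds, making the metric identical on every process (for example by logging it with sync_dist=True, or by building the path on rank 0 only) might keep the recorded path and the written file in agreement. This has not been verified against version 1.2.6.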