
Tracking checkpoint names that include train_loss in DDP mode #6775

@liyucheng09

Description


🐛 Bug

I wrote a custom callback for checkpoint saving. It saves a checkpoint every n training steps but keeps only the top_k checkpoints with the lowest train_loss.
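The bookkeeping described above (keep only the top_k checkpoints with the lowest loss) can be sketched in plain Python, independent of Lightning. This is a minimal sketch under stated assumptions: `TopKTracker` and its `offer` method are hypothetical names, and lower loss is assumed to always be better. A max-heap keyed on the negated loss keeps the worst kept checkpoint at the top, so each new candidate is compared against it directly.

```python
import heapq
from typing import List, Optional, Tuple


class TopKTracker:
    """Hypothetical helper: keep the k checkpoints with the lowest loss."""

    def __init__(self, k: int) -> None:
        self.k = k
        # Max-heap via negated loss: the worst (highest-loss) kept entry sits at index 0.
        self._heap: List[Tuple[float, str]] = []

    def offer(self, loss: float, path: str) -> Tuple[bool, Optional[str]]:
        """Return (should_save, path_to_delete) for a new candidate checkpoint."""
        if self.k <= 0:
            # Treat k <= 0 as "keep everything", like save_top_k=0 in the callback.
            return True, None
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (-loss, path))
            return True, None
        worst_loss = -self._heap[0][0]
        if loss < worst_loss:
            # Evict the current worst entry and record the new one in a single step.
            _, evicted = heapq.heapreplace(self._heap, (-loss, path))
            return True, evicted
        return False, None

    def kept(self) -> List[Tuple[float, str]]:
        """Kept (loss, path) pairs, best first."""
        return sorted((-neg_loss, p) for neg_loss, p in self._heap)
```

The evicted path is returned rather than deleted, so the caller decides which process touches the filesystem (for example, only rank 0).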

I use self.history to track each saved checkpoint path and its corresponding train loss. However, in DDP mode the rank 0 process cannot find the checkpoint file when it calls self._del_model(ckpt_path): the checkpoint that was actually saved may not be the one recorded in rank 0's self.history.
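One plausible contributor is that, under DDP, each rank reads its own local `train_loss`, so each rank can format a different file name for the same step, and any bookkeeping keyed on that name then differs from rank to rank. The pure-Python sketch below simulates this with made-up per-rank losses (`ckpt_name`, `names_per_rank`, and `synced_names` are hypothetical helpers) and shows that the names agree only when every rank formats them from the same value; the mean stands in for an all-reduce. Logging the loss with `self.log(..., sync_dist=True)` may achieve this in Lightning, but whether it does so for `callback_metrics` in 1.2.6 is an assumption to verify.

```python
from statistics import mean
from typing import List


def ckpt_name(epoch: int, step: int, loss: float) -> str:
    # One format spec for every call, so equal inputs always give equal names.
    return f"epoch-{epoch}--step{step}--train_loss-{loss:.2f}.ckpt"


def names_per_rank(epoch: int, step: int, rank_losses: List[float]) -> List[str]:
    # Each rank formats the name from its own local loss.
    return [ckpt_name(epoch, step, loss) for loss in rank_losses]


def synced_names(epoch: int, step: int, rank_losses: List[float]) -> List[str]:
    # Simulated all-reduce mean: every rank ends up with the same loss value.
    shared = mean(rank_losses)
    return [ckpt_name(epoch, step, shared) for _ in rank_losses]
```

With local losses of 0.40 and 0.60 on two ranks, the first helper yields two different names for step 1000, while the second yields `epoch-0--step1000--train_loss-0.50.ckpt` on both.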

So how can I find the path of the checkpoint that was actually saved? Any reply is highly appreciated.
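Because epoch and step are identical on every rank while the formatted loss may not be, one way to locate the file that was actually written is to match on epoch and step and wildcard the loss part. This is a hypothetical helper (`find_saved_ckpts` is not a Lightning API) that assumes the `epoch-{epoch}--step{step}--train_loss-*.ckpt` naming used in the callback and a directory on the local filesystem.

```python
import glob
import os
from typing import List


def find_saved_ckpts(save_dir: str, epoch: int, step: int) -> List[str]:
    """Hypothetical helper: list checkpoints saved at (epoch, step), whatever loss string they carry."""
    # Escape the directory so characters like '[' in the path are not treated as glob syntax.
    # The '--' after the step keeps step100 from also matching step1000.
    pattern = os.path.join(
        glob.escape(save_dir), f"epoch-{epoch}--step{step}--train_loss-*.ckpt"
    )
    return sorted(glob.glob(pattern))
```

If more than one file matches a single (epoch, step), that would suggest more than one process wrote a checkpoint for that step under a different name.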


To Reproduce


import logging
import os

from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_only

log = logging.getLogger(__name__)


class SaveCallback(Callback):

    def __init__(self, save_path, save_steps=1000, save_top_k=0):
        super().__init__()
        self.save_path = save_path
        self.save_steps = save_steps
        # save_top_k=0 (falsy) keeps every checkpoint
        self.save_top_k = save_top_k
        # Each entry is [epoch, step, loss, ckpt_path]
        self.history = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        step = pl_module.global_step
        if step == 0 or step % self.save_steps != 0:
            return
        epoch = pl_module.current_epoch
        # Without sync_dist=True in self.log, this is the current rank's local loss
        loss = trainer.callback_metrics["train_loss"].detach().item()
        # Use the same format spec in every branch so names are consistent
        ckpt_path = os.path.join(self.save_path, f'epoch-{epoch}--step{step}--train_loss-{loss:.2f}.ckpt')

        if self.save_top_k and len(self.history) == self.save_top_k:
            # Sort by loss so the worst kept checkpoint is last
            self.history.sort(key=lambda x: x[2])
            if self.history[-1][2] <= loss:
                return  # No better than the worst kept checkpoint: skip saving
            # Drop the evicted entry so history never grows past save_top_k
            _, _, _, path_to_remove = self.history.pop()
            self._del_model(path_to_remove)

        trainer.save_checkpoint(ckpt_path)
        self.history.append([epoch, step, loss, ckpt_path])

    @rank_zero_only
    def _del_model(self, path):
        if os.path.exists(path):
            os.remove(path)
            log.debug(f'removed checkpoint: {path}.')

My PyTorch Lightning version is 1.2.6.

Labels: bug (Something isn't working), help wanted (Open to be worked on)
