Skip to content

Commit

Permalink
Add max_ckpt_keep for trainer (deepmodeling#3441)
Browse files Browse the repository at this point in the history
Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
iProzd and njzjz authored Mar 9, 2024
1 parent a9bcf41 commit fd82f04
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
10 changes: 10 additions & 0 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
self.disp_freq = training_params.get("disp_freq", 1000)
self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
self.save_freq = training_params.get("save_freq", 1000)
self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
self.lcurve_should_print_header = True

def get_opt_param(params):
Expand Down Expand Up @@ -924,6 +925,15 @@ def save_model(self, save_path, lr=0.0, step=0):
{"model": module.state_dict(), "optimizer": self.optimizer.state_dict()},
save_path,
)
checkpoint_dir = save_path.parent
checkpoint_files = [
f
for f in checkpoint_dir.glob("*.pt")
if not f.is_symlink() and f.name.startswith(self.save_ckpt)
]
if len(checkpoint_files) > self.max_ckpt_keep:
checkpoint_files.sort(key=lambda x: x.stat().st_mtime)
checkpoint_files[0].unlink()

def get_data(self, is_train=True, task_key="Default"):
if not self.multi_task:
Expand Down
5 changes: 4 additions & 1 deletion deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def get_lr_and_coef(lr_param):
self.disp_freq = tr_data.get("disp_freq", 1000)
self.save_freq = tr_data.get("save_freq", 1000)
self.save_ckpt = tr_data.get("save_ckpt", "model.ckpt")
self.max_ckpt_keep = tr_data.get("max_ckpt_keep", 5)
self.display_in_training = tr_data.get("disp_training", True)
self.timing_in_training = tr_data.get("time_training", True)
self.profiling = self.run_opt.is_chief and tr_data.get("profiling", False)
Expand Down Expand Up @@ -498,7 +499,9 @@ def _init_session(self):
# Initializes or restore global variables
init_op = tf.global_variables_initializer()
if self.run_opt.is_chief:
self.saver = tf.train.Saver(save_relative_paths=True)
self.saver = tf.train.Saver(
save_relative_paths=True, max_to_keep=self.max_ckpt_keep
)
if self.run_opt.init_mode == "init_from_scratch":
log.info("initialize model from scratch")
run_sess(self.sess, init_op)
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,6 +2134,11 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
doc_disp_freq = "The frequency of printing learning curve."
doc_save_freq = "The frequency of saving check point."
doc_save_ckpt = "The path prefix of saving check point files."
doc_max_ckpt_keep = (
"The maximum number of checkpoints to keep. "
"The oldest checkpoints will be deleted once the number of checkpoints exceeds max_ckpt_keep. "
"Defaults to 5."
)
doc_disp_training = "Displaying verbose information during training."
doc_time_training = "Timing durining training."
doc_profiling = "Profiling during training."
Expand Down Expand Up @@ -2192,6 +2197,7 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
Argument(
"save_ckpt", str, optional=True, default="model.ckpt", doc=doc_save_ckpt
),
Argument("max_ckpt_keep", int, optional=True, default=5, doc=doc_max_ckpt_keep),
Argument(
"disp_training", bool, optional=True, default=True, doc=doc_disp_training
),
Expand Down

0 comments on commit fd82f04

Please sign in to comment.