Skip to content

Commit 47ee14b

Browse files
authored
Update train.py
1 parent c1946ed commit 47ee14b

File tree

1 file changed

+26
-27
lines changed

1 file changed

+26
-27
lines changed

vits_extend/train.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -274,34 +274,33 @@ def train(rank, args, chkpt_path, hp, hp_str):
274274
}, save_path)
275275
logger.info("Saved checkpoint to: %s" % save_path)
276276

277-
278-
def clean_checkpoints(path_to_models=f'{pth_dir}', n_ckpts_to_keep=hp.log.keep_ckpts, sort_by_time=True):
279-
"""Freeing up space by deleting saved ckpts
280-
Arguments:
281-
path_to_models -- Path to the model directory
282-
n_ckpts_to_keep -- Number of ckpts to keep, excluding sovits5.0_0.pth
283-
If n_ckpts_to_keep == 0, do not delete any ckpts
284-
sort_by_time -- True -> chronologically delete ckpts
285-
False -> lexicographically delete ckpts
286-
"""
287-
assert isinstance(n_ckpts_to_keep, int) and n_ckpts_to_keep >= 0
288-
ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
289-
name_key = (lambda _f: int(re.compile(f'{args.name}_(\d+)\.pt').match(_f).group(1)))
290-
time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
291-
sort_key = time_key if sort_by_time else name_key
292-
x_sorted = lambda _x: sorted(
293-
[f for f in ckpts_files if f.startswith(_x) and not f.endswith('sovits5.0_0.pth')], key=sort_key)
294-
if n_ckpts_to_keep == 0:
295-
to_del = []
296-
else:
297-
to_del = [os.path.join(path_to_models, fn) for fn in x_sorted(f'{args.name}')[:-n_ckpts_to_keep]]
298-
del_info = lambda fn: logger.info(f"Free up space by deleting ckpt {fn}")
299-
del_routine = lambda x: [os.remove(x), del_info(x)]
300-
rs = [del_routine(fn) for fn in to_del]
301-
302-
clean_checkpoints()
303-
304277
if rank == 0:
278+
def clean_checkpoints(path_to_models=f'{pth_dir}', n_ckpts_to_keep=hp.log.keep_ckpts, sort_by_time=True):
279+
"""Freeing up space by deleting saved ckpts
280+
Arguments:
281+
path_to_models -- Path to the model directory
282+
n_ckpts_to_keep -- Number of ckpts to keep, excluding sovits5.0_0.pth
283+
If n_ckpts_to_keep == 0, do not delete any ckpts
284+
sort_by_time -- True -> chronologically delete ckpts
285+
False -> lexicographically delete ckpts
286+
"""
287+
assert isinstance(n_ckpts_to_keep, int) and n_ckpts_to_keep >= 0
288+
ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
289+
name_key = (lambda _f: int(re.compile(f'{args.name}_(\d+)\.pt').match(_f).group(1)))
290+
time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
291+
sort_key = time_key if sort_by_time else name_key
292+
x_sorted = lambda _x: sorted(
293+
[f for f in ckpts_files if f.startswith(_x) and not f.endswith('sovits5.0_0.pth')], key=sort_key)
294+
if n_ckpts_to_keep == 0:
295+
to_del = []
296+
else:
297+
to_del = [os.path.join(path_to_models, fn) for fn in x_sorted(f'{args.name}')[:-n_ckpts_to_keep]]
298+
del_info = lambda fn: logger.info(f"Free up space by deleting ckpt {fn}")
299+
del_routine = lambda x: [os.remove(x), del_info(x)]
300+
rs = [del_routine(fn) for fn in to_del]
301+
302+
clean_checkpoints()
303+
305304
os.makedirs(f'{pth_dir}', exist_ok=True)
306305
keep_ckpts = getattr(hp.log, 'keep_ckpts', 0)
307306
if keep_ckpts > 0:

0 commit comments

Comments
 (0)