Skip to content

Commit

Permalink
Fix old checkpoint deletion by sorting the models properly
Browse files Browse the repository at this point in the history
Should fix voicepaw#65 this time around.
The reason why this didn't work before is because it sorted them based on their name or time alone, but not also based on their prefix, so the resulting array looked like "G_", "D_", "G_", "D_", ...

This addresses that by also sorting based on their prefix, resulting in a properly sorted array "G_", "G_", "G_", "D_", "D_", "D_"

This has to be done this way due to how the itertools.groupby method works: https://note.nkmk.me/en/python-itertools-groupby/
  • Loading branch information
Lordmau5 committed Mar 23, 2023
1 parent 9ffb621 commit c9c1563
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions src/so_vits_svc_fork/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,24 +451,26 @@ def clean_checkpoints(
sort_by_time -- True -> chronologically delete ckpts
False -> lexicographically delete ckpts
"""
LOG.warning("Cleaning old checkpoints...")
path_to_models = Path(path_to_models)
name_key = lambda p: int(re.match(r"._(\d+)", p.stem).group(1))

# Define sort key functions
name_key = lambda p: int(re.match(r"[GD]_(\d+)\.pth", p.name).group(1))
time_key = lambda p: p.stat().st_mtime
models_sorted = sorted(
filter(
lambda p: (p.is_file() and re.match(r"._\d+", p.stem)),
path_to_models.glob("*.pth"),
),
key=time_key if sort_by_time else name_key,
)
models_sorted_grouped = groupby(models_sorted, lambda p: p.stem[0])
for k, g in models_sorted_grouped:
to_dels = list(g)[:-n_ckpts_to_keep]
for to_del in to_dels:
if to_del.stem.endswith("_0"):
continue
LOG.warning(f"Removing {to_del}")
to_del.unlink()
path_key = lambda p: (p.name[0],) + (time_key(p),) if sort_by_time else (p.name[0],) + (name_key(p),)

models = list(filter(lambda p: (p.is_file() and re.match(r"[GD]_\d+\.pth", p.name) and not p.name.endswith("_0.pth")), path_to_models.glob("*.pth")))

models_sorted = sorted(models, key=path_key)

models_sorted_grouped = groupby(models_sorted, lambda p: p.name[0])

for group_name, group_items in models_sorted_grouped:
to_delete_list = list(group_items)[:-n_ckpts_to_keep]

for to_delete in to_delete_list:
LOG.warning(f"Removing {to_delete}")
to_delete.unlink()


def summarize(
Expand Down

0 comments on commit c9c1563

Please sign in to comment.