Skip to content

Commit

Permalink
Updated to deal with Gradio 4.xx versions
Browse files Browse the repository at this point in the history
  • Loading branch information
erew123 authored May 14, 2024
1 parent 3a105ce commit 7c7cb72
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,15 +778,15 @@ def get_available_voices(minimum_size_kb=1200):

def find_best_models(directory):
"""Find files named 'best_model.pth' in the given directory."""
return [file for file in Path(directory).rglob("best_model.pth")]
return [str(file) for file in Path(directory).rglob("best_model.pth")]

def find_models(directory, extension):
"""Find files with a specific extension in the given directory."""
return [file for file in Path(directory).rglob(f"*.{extension}")]
return [str(file) for file in Path(directory).rglob(f"*.{extension}")]

def find_jsons(directory, filename):
"""Find files with a specific filename in the given directory."""
return list(Path(directory).rglob(filename))
return [str(file) for file in Path(directory).rglob(filename)]

# Your main directory
main_directory = Path(this_dir) / "finetune" / "tmp-trn"
Expand All @@ -809,22 +809,21 @@ def find_latest_best_model(folder_path):

def compact_model(xtts_checkpoint_copy):
this_dir = Path(__file__).parent.resolve()
best_model_path_str = xtts_checkpoint_copy
print("THIS DIR:", this_dir)
best_model_path_str = str(xtts_checkpoint_copy) # Convert to string
print("best_model_path_str", best_model_path_str)

# Check if the best model file exists
if best_model_path_str is None:
if not best_model_path_str:
print("[FINETUNE] No trained model was found.")
return "No trained model was found."

print(f"[FINETUNE] Best model path: {best_model_path_str}")

# Convert model_path_str to Path
best_model_path = Path(best_model_path_str)

# Attempt to load the model
try:
checkpoint = torch.load(best_model_path, map_location=torch.device("cpu"))
print(f"[FINETUNE] Checkpoint loaded: {best_model_path}")
checkpoint = torch.load(best_model_path_str, map_location=torch.device("cpu"))
print(f"[FINETUNE] Checkpoint loaded: {best_model_path_str}")
except Exception as e:
print("[FINETUNE] Error loading checkpoint:", e)
raise
Expand All @@ -842,15 +841,15 @@ def compact_model(xtts_checkpoint_copy):
del checkpoint["model"][key]

# Save the modified checkpoint in the target directory
torch.save(checkpoint, target_dir / "model.pth")
torch.save(checkpoint, str(target_dir / "model.pth")) # Convert to string

# Specify the files you want to copy
files_to_copy = ["vocab.json", "config.json", "speakers_xtts.pth", "mel_stats.pth", "dvae.pth",]
files_to_copy = ["vocab.json", "config.json", "speakers_xtts.pth", "mel_stats.pth", "dvae.pth"]

for file_name in files_to_copy:
src_path = this_dir / base_path / base_model_path / file_name
dest_path = target_dir / file_name
shutil.copy(str(src_path), str(dest_path))
shutil.copy(str(src_path), str(dest_path)) # Convert to string

source_wavs_dir = this_dir / "finetune" / "tmp-trn" / "wavs"
target_wavs_dir = target_dir / "wavs"
Expand All @@ -861,7 +860,7 @@ def compact_model(xtts_checkpoint_copy):
# Check if it's a file and larger than 1000 KB
if file_path.is_file() and file_path.stat().st_size > 1000 * 1024:
# Copy the file to the target directory
shutil.copy(str(file_path), str(target_wavs_dir / file_path.name))
shutil.copy(str(file_path), str(target_wavs_dir / file_path.name)) # Convert to string

print("[FINETUNE] Model copied to '/models/trainedmodel/'")
return "Model copied to '/models/trainedmodel/'"
Expand Down Expand Up @@ -1764,7 +1763,7 @@ def train_model(language, train_csv, eval_csv, learning_rates, num_epochs, batch
)
# Create refresh button
refresh_button = create_refresh_button(
[xtts_checkpoint,],
[xtts_checkpoint_copy,],
[
lambda: {"choices": find_best_models(main_directory), "value": ""},
],
Expand Down

0 comments on commit 7c7cb72

Please sign in to comment.