Skip to content

Commit

Permalink
Use soundfile instead of torchaudio.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Oct 14, 2024
1 parent 619afe1 commit ca45c85
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
17 changes: 8 additions & 9 deletions examples/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@

from vocos_mlx import Vocos

import torch
import torchaudio
import soundfile as sf

SAMPLE_RATE = 24_000
HOP_LENGTH = 256
Expand All @@ -37,9 +36,9 @@ def generate(
f5tts = CFM.from_pretrained(model_name, vocab)

# load reference audio
audio, sr = torchaudio.load(Path(ref_audio_path))
audio = mx.array(audio.numpy())
ref_audio_duration = audio.shape[1] / SAMPLE_RATE
audio, sr = sf.read(ref_audio_path)
audio = mx.array(audio)
ref_audio_duration = audio.shape[0] / SAMPLE_RATE

rms = mx.sqrt(mx.mean(mx.square(audio)))
if rms < TARGET_RMS:
Expand All @@ -55,7 +54,7 @@ def generate(
vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")

wave, _ = f5tts.sample(
audio,
mx.expand_dims(audio, axis=0),
text=text,
duration=frame_duration,
steps=32,
Expand All @@ -66,13 +65,13 @@ def generate(
)

# trim the reference audio
wave = wave[audio.shape[1] :]
generated_duration = len(wave) / SAMPLE_RATE
wave = wave[audio.shape[0]:]
generated_duration = wave.shape[0] / SAMPLE_RATE
elapsed_time = datetime.datetime.now() - start_date

print(f"Generated {generated_duration:.2f} seconds of audio in {elapsed_time}.")

torchaudio.save(output_path, torch.Tensor(np.array(wave)).unsqueeze(0), SAMPLE_RATE)
sf.write(output_path, np.array(wave), SAMPLE_RATE)


if __name__ == "__main__":
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ requires = [
"mlx",
"numpy",
"setuptools",
"torch",
"torchaudio",
"soundfile",
"vocos-mlx"
]
build-backend = "setuptools.build_meta"

[project]
name = "f5-tts-mlx"
version = "0.0.2"
version = "0.0.3"
authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}]
license = {text = "MIT"}
description = "F5-TTS - MLX"
Expand Down

0 comments on commit ca45c85

Please sign in to comment.