diff --git a/examples/generate.py b/examples/generate.py index 7757aeb..5cbeb28 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -12,8 +12,7 @@ from vocos_mlx import Vocos -import torch -import torchaudio +import soundfile as sf SAMPLE_RATE = 24_000 HOP_LENGTH = 256 @@ -37,9 +36,9 @@ def generate( f5tts = CFM.from_pretrained(model_name, vocab) # load reference audio - audio, sr = torchaudio.load(Path(ref_audio_path)) - audio = mx.array(audio.numpy()) - ref_audio_duration = audio.shape[1] / SAMPLE_RATE + audio, sr = sf.read(ref_audio_path) + audio = mx.array(audio) + ref_audio_duration = audio.shape[0] / SAMPLE_RATE rms = mx.sqrt(mx.mean(mx.square(audio))) if rms < TARGET_RMS: @@ -55,7 +54,7 @@ def generate( vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz") wave, _ = f5tts.sample( - audio, + mx.expand_dims(audio, axis=0), text=text, duration=frame_duration, steps=32, @@ -66,13 +65,13 @@ def generate( ) # trim the reference audio - wave = wave[audio.shape[1] :] - generated_duration = len(wave) / SAMPLE_RATE + wave = wave[audio.shape[0]:] + generated_duration = wave.shape[0] / SAMPLE_RATE elapsed_time = datetime.datetime.now() - start_date print(f"Generated {generated_duration:.2f} seconds of audio in {elapsed_time}.") - torchaudio.save(output_path, torch.Tensor(np.array(wave)).unsqueeze(0), SAMPLE_RATE) + sf.write(output_path, np.array(wave), SAMPLE_RATE) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 368b48b..9d4eda3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,15 +7,14 @@ requires = [ "mlx", "numpy", "setuptools", - "torch", - "torchaudio", + "soundfile", "vocos-mlx" ] build-backend = "setuptools.build_meta" [project] name = "f5-tts-mlx" -version = "0.0.2" +version = "0.0.3" authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}] license = {text = "MIT"} description = "F5-TTS - MLX"