Skip to content

Commit

Permalink
Support longer audio files by reducing memory usage with chunking
Browse files: browse the repository at this point in the history
  • Loading branch information
Gustavo Garcia committed Jul 2, 2024
1 parent ba3f3cd commit 20e3238
Show file tree
Hide file tree
Showing 3 changed files with 375 additions and 355 deletions.
6 changes: 3 additions & 3 deletions tests/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

def test_audio():
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
audio = load_audio(audio_path)
audio = next(load_audio(audio_path))
assert audio.ndim == 1
assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
assert 0 < audio.std() < 1

mel_from_audio = log_mel_spectrogram(audio)
mel_from_file = log_mel_spectrogram(audio_path)
mel_from_audio = next(log_mel_spectrogram(audio))
mel_from_file = next(log_mel_spectrogram(audio_path))

assert np.allclose(mel_from_audio, mel_from_file)
assert mel_from_audio.max() - mel_from_audio.min() <= 2.0
45 changes: 32 additions & 13 deletions whisper/audio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import subprocess
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union
from typing import Generator, Optional, Union

import numpy as np
import torch
Expand All @@ -21,6 +22,7 @@
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token

MAX_CHUNK_DURATION = 2 * 60 * 60 # 2 hour maximum chunk duration

def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Expand Down Expand Up @@ -55,12 +57,16 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

while True:
out = process.stdout.read(MAX_CHUNK_DURATION * sr * 2)
if not out:
break
yield np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
except Exception as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Expand Down Expand Up @@ -108,7 +114,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:


def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
audio: Union[str, np.ndarray, torch.Tensor, Generator[np.ndarray, None, None]],
n_mels: int = 80,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
Expand All @@ -135,13 +141,26 @@ def log_mel_spectrogram(
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)

if device is not None:
audio = audio.to(device)
if isinstance(audio, str):
audio = load_audio(audio)
elif isinstance(audio, np.ndarray):
audio = [audio]
elif isinstance(audio, torch.Tensor):
audio = [audio]

for chunk in audio:
if not isinstance(chunk, torch.Tensor):
chunk = torch.from_numpy(chunk)
if device is not None:
chunk = chunk.to(device)
yield _log_mel_spectrogram(chunk, n_mels, padding)


def _log_mel_spectrogram(
audio: torch.Tensor,
n_mels: int = 80,
padding: int = 0,
):
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
Expand Down
Loading

0 comments on commit 20e3238

Please sign in to comment.