Skip to content

Commit

Permalink
Update logic get segment from features before encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
trungkienbkhn committed Feb 24, 2024
1 parent 06d32bf commit 4f712b0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
15 changes: 15 additions & 0 deletions faster_whisper/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,18 @@ def _resample_frames(frames, resampler):
# Add None to flush the resampler.
for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame)


def pad_or_trim(array, length: int, *, axis: int = -1):
    """
    Force the size of `array` along `axis` to exactly `length` entries.

    An input that is too long keeps only its first `length` entries along
    `axis`; an input that is too short is extended at the end of that axis
    with `np.pad` in its default mode. An input that already has the
    requested size is returned as-is (the same object, not a copy).
    """
    current_size = array.shape[axis]

    if current_size > length:
        # Keep only the leading `length` entries along the chosen axis.
        return array.take(indices=range(length), axis=axis)

    if current_size < length:
        # Pad only at the end of `axis`; every other axis is left untouched.
        widths = [(0, 0) for _ in range(array.ndim)]
        widths[axis] = (0, length - current_size)
        return np.pad(array, widths)

    return array
3 changes: 2 additions & 1 deletion faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import tokenizers

from faster_whisper.audio import decode_audio
from faster_whisper.audio import decode_audio, pad_or_trim
from faster_whisper.feature_extractor import FeatureExtractor
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
Expand Down Expand Up @@ -492,6 +492,7 @@ def generate_segments(
)
segment = features[:, seek : seek + segment_size]
segment_duration = segment_size * self.feature_extractor.time_per_frame
segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
Expand Down

0 comments on commit 4f712b0

Please sign in to comment.