diff --git a/whisper/timing.py b/whisper/timing.py index 1a73eaaf..207d877d 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -225,28 +225,6 @@ def find_alignment( for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) ] - # hack: truncate long words at the start of a window and the start of a sentence. - # a better segmentation algorithm based on VAD should be able to replace this. - word_durations = end_times - start_times - word_durations = word_durations[word_durations.nonzero()] - if len(word_durations) > 0: - median_duration = np.median(word_durations) - max_duration = median_duration * 2 - sentence_end_marks = ".。!!??" - # ensure words at sentence boundaries are not longer than twice the median word duration. - for i in range(1, len(start_times)): - if end_times[i] - start_times[i] > max_duration: - if words[i] in sentence_end_marks: - end_times[i] = start_times[i] + max_duration - elif words[i - 1] in sentence_end_marks: - start_times[i] = end_times[i] - max_duration - # ensure the first and second word is not longer than twice the median word duration. - if len(start_times) > 0 and end_times[0] - start_times[0] > max_duration: - if len(start_times) > 1 and end_times[1] - start_times[1] > max_duration: - boundary = max(end_times[1] / 2, end_times[1] - max_duration) - end_times[0] = start_times[1] = boundary - start_times[0] = max(0, end_times[0] - max_duration) - return [ WordTiming(word, tokens, start, end, probability) for word, tokens, start, end, probability in zip( @@ -298,6 +276,7 @@ def add_word_timestamps( num_frames: int, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", + last_speech_timestamp: float, **kwargs, ): if len(segments) == 0: @@ -310,6 +289,25 @@ def add_word_timestamps( text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) + word_durations = np.array([t.end - t.start for t in alignment]) + word_durations = word_durations[word_durations.nonzero()] + median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 + max_duration = median_duration * 2 + + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + median_duration = np.median(word_durations) + max_duration = median_duration * 2 + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries are not longer than twice the median word duration. + for i in range(1, len(alignment)): + if alignment[i].end - alignment[i].start > max_duration: + if alignment[i].word in sentence_end_marks: + alignment[i].end = alignment[i].start + max_duration + elif alignment[i - 1].word in sentence_end_marks: + alignment[i].start = alignment[i].end - max_duration + merge_punctuations(alignment, prepend_punctuations, append_punctuations) time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE @@ -335,18 +333,48 @@ def add_word_timestamps( saved_tokens += len(timing.tokens) word_index += 1 + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. if len(words) > 0: - segment["start"] = words[0]["start"] - # hack: prefer the segment-level end timestamp if the last word is too long. - # a better segmentation algorithm based on VAD should be able to replace this. + # ensure the first and second word after a pause is not longer than + # twice the median word duration. + if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( + words[0]["end"] - words[0]["start"] > max_duration + or ( + len(words) > 1 + and words[1]["end"] - words[0]["start"] > max_duration * 2 + ) + ): + if ( + len(words) > 1 + and words[1]["end"] - words[1]["start"] > max_duration + ): + boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration) + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) + + # prefer the segment-level start timestamp if the first word is too long. + if ( + segment["start"] < words[0]["end"] + and segment["start"] - 0.5 > words[0]["start"] + ): + words[0]["start"] = max( + 0, min(words[0]["end"] - median_duration, segment["start"]) + ) + else: + segment["start"] = words[0]["start"] + + # prefer the segment-level end timestamp if the last word is too long. if ( segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"] ): - # adjust the word-level timestamps based on the segment-level timestamps - words[-1]["end"] = segment["end"] + words[-1]["end"] = max( + words[-1]["start"] + median_duration, segment["end"] + ) else: - # adjust the segment-level timestamps based on the word-level timestamps segment["end"] = words[-1]["end"] + last_speech_timestamp = segment["end"] + segment["words"] = words diff --git a/whisper/transcribe.py b/whisper/transcribe.py index ff73a553..6e43a22f 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -222,6 +222,7 @@ def new_segment( with tqdm.tqdm( total=content_frames, unit="frames", disable=verbose is not False ) as pbar: + last_speech_timestamp = 0.0 while seek < content_frames: time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) mel_segment = mel[:, seek : seek + N_FRAMES] @@ -321,10 +322,13 @@ def new_segment( num_frames=segment_size, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, + last_speech_timestamp=last_speech_timestamp, ) word_end_timestamps = [ w["end"] for s in current_segments for w in s["words"] ] + if len(word_end_timestamps) > 0: + last_speech_timestamp = word_end_timestamps[-1] if not single_timestamp_ending and len(word_end_timestamps) > 0: seek_shift = round( (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND