diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 1c075a201..58754b40b 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -134,6 +134,21 @@ def transcribe(
     content_frames = mel.shape[-1] - N_FRAMES
     content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
 
+    def maybe_dereference_list(obj):
+        """Return ``obj[0]`` if ``obj`` is a list, otherwise ``obj`` unchanged.
+
+        Applied below to the outputs of ``model.detect_language`` and
+        ``model.decode``, which may arrive wrapped in a list (presumably one
+        entry per input segment -- TODO confirm against the model API).
+        Deliberately unannotated: the values passed in are not only dicts
+        (the decode result is used as a ``DecodingResult``), and a
+        ``list[dict]`` annotation is evaluated at definition time, which
+        fails before Python 3.9.
+        """
+        if isinstance(obj, list):
+            return obj[0]
+        return obj
+
     if decode_options.get("language", None) is None:
         if not model.is_multilingual:
             decode_options["language"] = "en"
@@ -144,7 +159,9 @@ def transcribe(
                 )
             mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
             _, probs = model.detect_language(mel_segment)
-            decode_options["language"] = max(probs, key=probs.get)
+            # Unwrap once so max() iterates and scores the same mapping.
+            language_probs = maybe_dereference_list(probs)
+            decode_options["language"] = max(language_probs, key=language_probs.get)
             if verbose is not None:
                 print(
                     f"Detected language: {LANGUAGES[decode_options['language']].title()}"
@@ -192,7 +209,7 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
                 kwargs.pop("best_of", None)
 
             options = DecodingOptions(**kwargs, temperature=t)
-            decode_result = model.decode(segment, options)
+            decode_result = maybe_dereference_list(model.decode(segment, options))
 
             needs_fallback = False
             if (