Skip to content

Commit

Permalink
Merge pull request #4 from argmaxinc/atila/openai_api_earnings22
Browse files Browse the repository at this point in the history
OpenAI API earnings22 Eval
  • Loading branch information
atiorh authored Mar 5, 2024
2 parents 4f247ee + 808fa17 commit 95d178c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 20 deletions.
2 changes: 1 addition & 1 deletion tests/test_text_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
TEST_WHISPER_VERSION = (
os.getenv("TEST_WHISPER_VERSION", None) or "openai/whisper-tiny"
) # tiny"
TEST_DEV = get_fastest_device()
TEST_DEV = os.getenv("TEST_DEV", None) or get_fastest_device()
TEST_TORCH_DTYPE = torch.float32
TEST_PSNR_THR = 35
TEST_CACHE_DIR = os.getenv("TEST_CACHE_DIR", None) or "/tmp"
Expand Down
6 changes: 4 additions & 2 deletions whisperkit/_constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2023 Argmax, Inc. All Rights Reserved.
# Copyright (C) 2024 Argmax, Inc. All Rights Reserved.
#
import os

Expand All @@ -22,5 +22,7 @@
EVAL_DATASETS.append(CUSTOM_EVAL_DATASET)

# Tests
OPENAI_API_MODEL_VERSION = "openai/whisper-large-v2"
OPENAI_API_MODEL_VERSION = "openai_whisper-large-v2"
OPENAI_API_MAX_FILE_SIZE = 25e6 # bytes
OPENAI_API_COMPRESSED_UPLOAD_BIT_RATE = "12k" # kbps
TEST_DATA_REPO = "argmaxinc/whisperkit-test-data"
88 changes: 72 additions & 16 deletions whisperkit/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,20 +394,19 @@ class WhisperOpenAIAPI:
See https://platform.openai.com/docs/guides/speech-to-text
"""
def __init__(self,
whisper_version: str = "openai/whisper-large-v2",
whisper_version: str = _constants.OPENAI_API_MODEL_VERSION,
out_dir: Optional[str] = ".",
**kwargs) -> None:

if whisper_version != "openai/whisper-large-v2":
raise ValueError("OpenAI API only supports 'openai/whisper-large-v2' as of 02/28/2024")
if whisper_version != _constants.OPENAI_API_MODEL_VERSION:
raise ValueError(f"OpenAI API only supports '{_constants.OPENAI_API_MODEL_VERSION}'")
self.whisper_version = whisper_version

self.client = None

if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")

api_key = os.getenv("OPENAI_API_KEY", None)
assert api_key is not None
self.client = openai.Client(api_key=api_key)
self.out_dir = out_dir
self.results_dir = os.path.join(out_dir, "OpenAI-API")
os.makedirs(self.results_dir, exist_ok=True)
Expand All @@ -422,6 +421,54 @@ def __init__(self,
=======================================================
""")

def _maybe_init_client(self):
    """Lazily create the OpenAI API client on first use.

    Reads the API key from the ``OPENAI_API_KEY`` environment variable and
    caches the resulting client on ``self.client`` so repeated calls are
    no-ops.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is not set in the environment.
    """
    if self.client is None:
        api_key = os.getenv("OPENAI_API_KEY", None)
        # Note: `assert` is stripped under `python -O`, which would let a
        # missing key slip through to `openai.Client(api_key=None)`.
        # Raise explicitly with an actionable message instead.
        if api_key is None:
            raise RuntimeError(
                "OPENAI_API_KEY environment variable must be set to use the OpenAI API"
            )
        self.client = openai.Client(api_key=api_key)

def _maybe_compress_audio_file(self, audio_file_path: str) -> str:
    """Compress the audio file with ffmpeg if it exceeds the OpenAI API size limit.

    Files at or under ``_constants.OPENAI_API_MAX_FILE_SIZE`` bytes are
    returned unchanged. Larger files are transcoded to mono Opus (.ogg) at
    ``_constants.OPENAI_API_COMPRESSED_UPLOAD_BIT_RATE``.

    Args:
        audio_file_path: Path to the input audio file.

    Returns:
        Path to a file that fits the API limit: either the original path or
        the path of the newly written ``.ogg`` file next to it.

    Raises:
        RuntimeError: If ffmpeg is not installed (not found on PATH).
        subprocess.CalledProcessError: If ffmpeg exits with a nonzero status.
        ValueError: If the compressed file still exceeds the API limit.
    """
    audio_file_size = os.path.getsize(audio_file_path)
    if audio_file_size <= _constants.OPENAI_API_MAX_FILE_SIZE:
        # Small enough to upload as-is
        return audio_file_path

    logger.info(
        f"Compressing {os.path.basename(audio_file_path)} with size {audio_file_size / 1e6:.1f} MB > "
        f"{_constants.OPENAI_API_MAX_FILE_SIZE / 1e6:.1f} MB (OpenAI API max file size)")

    compressed_audio_file_path = os.path.splitext(audio_file_path)[0] + ".ogg"
    ffmpeg_cmd = [
        "ffmpeg",
        "-i", audio_file_path,
        "-vn",                    # drop any video stream
        "-map_metadata", "-1",    # strip metadata
        "-ac", "1",               # downmix to mono
        "-c:a", "libopus",
        "-b:a", _constants.OPENAI_API_COMPRESSED_UPLOAD_BIT_RATE,
        "-application", "voip",   # speech-optimized Opus mode
        "-y",                     # overwrite existing output
        compressed_audio_file_path,
    ]
    try:
        # List-form argv with shell=False is robust to spaces and shell
        # metacharacters in the path. check=True raises CalledProcessError
        # on nonzero exit (the previous `if check_call(...)` branch was
        # unreachable because check_call already raises on failure).
        subprocess.run(
            ffmpeg_cmd,
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except FileNotFoundError as e:
        raise RuntimeError(
            "Failed to compress audio file. Make sure ffmpeg is installed."
        ) from e

    audio_file_path = compressed_audio_file_path
    compressed_size = os.path.getsize(audio_file_path)

    if compressed_size > _constants.OPENAI_API_MAX_FILE_SIZE:
        raise ValueError(
            f"Compressed file size {compressed_size / 1e6:.1f} MB exceeds OpenAI API max file size "
            f"({_constants.OPENAI_API_MAX_FILE_SIZE / 1e6:.1f}) MB. Either (a) override "
            "whisperkit._constants.OPENAI_API_COMPRESSED_UPLOAD_BIT_RATE with a lower value or (b) "
            "follow https://platform.openai.com/docs/guides/speech-to-text/longer-inputs"
        )

    logger.info(
        f"Compressed {os.path.basename(audio_file_path)} to {compressed_size / 1e6:.1f} MB < "
        f"{_constants.OPENAI_API_MAX_FILE_SIZE / 1e6:.1f} MB"
    )

    return audio_file_path

def __call__(self, audio_file_path: str) -> str:
if not os.path.exists(audio_file_path):
raise FileNotFoundError(audio_file_path)
Expand All @@ -432,17 +479,26 @@ def __call__(self, audio_file_path: str) -> str:
-------------------------------------------------------
=======================================================
""")
with open(audio_file_path, "rb") as file_handle:
api_result = json.loads(self.client.audio.transcriptions.create(
model="whisper-1",
timestamp_granularities=["word", "segment"],
response_format="verbose_json",
file=file_handle,
).json())

result_fname = f"{audio_file_path.rsplit('/')[-1].rsplit('.')[0]}.json"
with open(os.path.join(self.results_dir, result_fname), "w") as f:
json.dump(api_result, f, indent=4)

if not os.path.exists(os.path.join(self.results_dir, result_fname)):
audio_file_path = self._maybe_compress_audio_file(audio_file_path)
self._maybe_init_client()

with open(audio_file_path, "rb") as file_handle:
api_result = json.loads(self.client.audio.transcriptions.create(
model="whisper-1",
timestamp_granularities=["word", "segment"],
response_format="verbose_json",
file=file_handle,
).json())

# result_fname = f"{audio_file_path.rsplit('/')[-1].rsplit('.')[0]}.json"
with open(os.path.join(self.results_dir, result_fname), "w") as f:
json.dump(api_result, f, indent=4)
else:
with open(os.path.join(self.results_dir, result_fname), "r") as f:
api_result = json.load(f)

logger.info(f"""\n
=======================================================
Expand Down
3 changes: 2 additions & 1 deletion whisperkit/text_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#
from copy import deepcopy
from itertools import product
import os
from typing import Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -334,7 +335,7 @@ def _fill_lut(
using forced decoder context prefix tokens and save results to nn.Embedding look-up tables
"""
str2int = self.tokenizer.vocab
dev = argmaxtools_utils.get_fastest_device()
dev = os.getenv("TEST_DEV", None) or argmaxtools_utils.get_fastest_device()

# Note: The cache technically shouldn't be pre-computable because, even though the (forced)
# decoder `input_ids` are known ahead of time, the cache is also a function of the runtime-
Expand Down

0 comments on commit 95d178c

Please sign in to comment.