Skip to content

Commit

Permalink
Actually invoke whisper
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusEverling committed Feb 23, 2024
1 parent 9691113 commit 579ba41
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 12 deletions.
3 changes: 1 addition & 2 deletions project_W_runner/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import aiohttp
import click

from project_W_runner.config import loadConfig
Expand All @@ -9,7 +8,7 @@
@click.option("--customConfigPath", type=str, required=False)
def main(customconfigpath: str = None):
config = loadConfig([customconfigpath]) if customconfigpath else loadConfig()
runner = Runner(backend_url=config["backendURL"], token=config["runnerToken"])
runner = Runner(backend_url=config["backendURL"], token=config["runnerToken"], torch_device=config.get("torchDevice"))
asyncio.run(runner.run())


Expand Down
3 changes: 3 additions & 0 deletions project_W_runner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class prettyValidationError(ValidationError):
"type": "string",
"pattern": r"^[a-zA-Z0-9_-]+$",
},
"torchDevice": {
"type": "string"
},
"disableOptionValidation": {
"type": "boolean",
"default": False
Expand Down
26 changes: 16 additions & 10 deletions project_W_runner/runner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import base64
from dataclasses import dataclass
from hmac import new
from os import error
from threading import Condition, Thread
import time
from typing import Optional

from project_W_runner.utils import prepare_audio, transcribe
from .logger import get_logger
import aiohttp

Expand All @@ -27,9 +27,10 @@ class JobData:
audio: bytes
model: Optional[str]
language: Optional[str]
progress: Optional[float] = None
transcript: Optional[str] = None
error: Optional[str] = None
progress: Optional[float] = None
current_step: Optional[str] = None


class ShutdownSignal(Exception):
Expand All @@ -47,16 +48,18 @@ def __init__(self, reason: str):
class Runner:
backend_url: str
token: str
torch_device: Optional[str]
current_job_data: Optional[JobData]
session: aiohttp.ClientSession
# We use this condition variable to signal to the processing thread
# that a new job has been assigned and it should start processing it.
cond_var: Condition
new_job: bool = False

def __init__(self, backend_url: str, token: str):
def __init__(self, backend_url: str, token: str, torch_device: Optional[str]):
self.backend_url = backend_url
self.token = token
self.torch_device = torch_device
self.current_job_data = None
self.cond_var = Condition()
Thread(target=self.run_processing_thread, daemon=True).start()
Expand All @@ -77,12 +80,15 @@ def process_current_job(self):
"""
Processes the current job, using the Whisper python package.
"""
self.current_job_data.progress = 0.0
for _ in range(10):
time.sleep(2)
self.current_job_data.progress += 0.1
self.current_job_data.transcript = "Lorem ipsum"
logger.info("Job processed, going back to idle")

# For some silly reason python doesn't let you do assignments in a lambda.
def progress_callback(progress: float):
print(f"Progress: {progress * 100:.2}%")
self.current_job_data.progress = progress

audio = prepare_audio(self.current_job_data.audio)
result = transcribe(audio, self.current_job_data.model, self.current_job_data.language, progress_callback, self.torch_device)
self.current_job_data.transcript = result["text"]

async def post(self, route: str, data: dict = None, params: dict = None, append_auth_header: bool = True):
"""
Expand Down
84 changes: 84 additions & 0 deletions project_W_runner/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

from subprocess import CalledProcessError, run
from typing import Callable, Optional
import numpy as np
import tqdm
from whisper import load_model
from whisper.transcribe import transcribe as whisper_transcribe


def prepare_audio(audio: bytes) -> np.ndarray:
    """
    Transform an in-memory audio file into a NumPy array to be passed to Whisper.

    The Whisper package seems to only allow reading files from disk, and not from
    memory, so by doing this step ourselves, we can avoid having to write the audio
    to disk. ffmpeg decodes the input to 16 kHz mono 16-bit PCM, which is then
    reinterpreted as float32 samples normalized to [-1.0, 1.0) — the input format
    Whisper expects.

    Raises RuntimeError if ffmpeg exits with an error or is not installed.
    """
    args = [
        "ffmpeg",
        "-i", "pipe:",           # read the input file from stdin
        "-f", "s16le",           # raw 16-bit little-endian PCM on stdout
        "-ac", "1",              # downmix to a single (mono) channel
        "-acodec", "pcm_s16le",
        "-ar", "16000",          # 16 kHz sample rate, as Whisper expects
        "-stats",
        "pipe:"
    ]
    try:
        out = run(args, input=audio, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to transform audio: {e.stderr.decode()}") from e
    except FileNotFoundError as e:
        raise RuntimeError(f"File not found: {e.filename}\n\nIs ffmpeg installed and in the PATH?") from e

    # Reinterpret the raw PCM bytes as int16 samples and normalize into float32.
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def transcribe(audio: np.ndarray, model: Optional[str], language: Optional[str], progress_callback: Callable[[float], None], device: Optional[str]) -> dict[str, str]:
    """
    Transcribe the given audio using the given model and language. Returns the dictionary of all the
    information returned by the Whisper invocation. The progress_callback is called periodically with
    the progress of the transcription, as a float between 0 and 1. If device is not None, it will be
    used as the device for the model. If language is None, Whisper auto-detects the language.
    """
    # Heavy inspiration from <https://github.com/ssciwr/vink/blob/main/vink.py>
    def monkeypatching_tqdm(progress_cb):
        # Factory for a tqdm.tqdm stand-in that forwards relative progress to
        # progress_cb instead of rendering a progress bar on the terminal.
        def _monkeypatching_tqdm(
            total=None,
            ncols=None,
            unit=None,
            unit_scale=True,
            unit_divisor=None,
            disable=False,
        ):
            class TqdmMonkeypatchContext:
                def __init__(self) -> None:
                    self.progress = 0.0

                def __enter__(self):
                    return self

                def __exit__(self, *args):
                    pass

                def update(self, value):
                    # `value` and `total` are in the same raw units, so any
                    # unit_divisor scaling cancels out of the ratio and must not
                    # be applied. (The previous version also reassigned `total`
                    # here, which made it a local variable of update() and
                    # raised UnboundLocalError on the first progress tick.)
                    self.progress += value
                    if total:  # guard against a missing/zero total
                        progress_cb(self.progress / total)

            return TqdmMonkeypatchContext()

        return _monkeypatching_tqdm

    # TODO: Load the model for the correct language.
    whisper_model = load_model(model or "base", device=device)

    # Patch tqdm only for the duration of this call, and always restore it so
    # the monkeypatch doesn't leak into other users of tqdm.
    original_tqdm = tqdm.tqdm
    tqdm.tqdm = monkeypatching_tqdm(progress_callback)
    try:
        progress_callback(0.0)
        # Forward the requested language (if any) to Whisper; it is passed
        # through to the decoding options. Previously this parameter was
        # accepted but silently ignored.
        if language is not None:
            ret = whisper_transcribe(whisper_model, audio, language=language)
        else:
            ret = whisper_transcribe(whisper_model, audio)
        # Just in case the progress_callback was not called with 1.0, do that now.
        progress_callback(1.0)
    finally:
        tqdm.tqdm = original_tqdm

    return ret
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ dependencies = [
"aiohttp",
"click",
"jsonschema",
"openai-whisper",
"numpy",
"platformdirs",
"pyaml_env",
]
Expand Down

0 comments on commit 579ba41

Please sign in to comment.