
Commit

Enable torch compile, which means we need gcc in the env. (#256)
But this further increases transcription speed.
jquagga authored Oct 9, 2024
1 parent 7c84a4d commit 35533b0
Showing 2 changed files with 30 additions and 30 deletions.
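
For context on the new gcc dependency: torch.compile's default inductor backend generates C++ (and Triton) code at runtime and invokes a host C compiler the first time a compiled function runs. Below is a minimal sketch, not code from this repository, using a hypothetical function f:

# Minimal sketch (assumes torch >= 2.x with the default inductor backend).
# The first call to a compiled function triggers code generation and a
# gcc/clang invocation, which is why gcc is added to environment.yml below.
import torch


def f(x):
    return torch.sin(x) + torch.cos(x)


compiled_f = torch.compile(f)      # compilation is deferred
print(compiled_f(torch.randn(8)))  # first call compiles, then runs
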
1 change: 1 addition & 0 deletions environment.yml
@@ -4,6 +4,7 @@ channels:
dependencies:
- python=3.12
- ffmpeg
- gcc
- pip
- pip:
- better_profanity
59 changes: 29 additions & 30 deletions ttt.py
@@ -11,6 +11,7 @@
import requests
import torch
from better_profanity import profanity
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
@@ -24,7 +25,7 @@
# Before we dig in, let's globally set up transformers
# We will load up the model, etc now so we only need to
# use the PIPE constant in the function.

torch.set_float32_matmul_precision("high")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = os.environ.get("TTT_TRANSFORMERS_MODEL_ID", "openai/whisper-large-v3-turbo")
@@ -36,6 +37,12 @@
use_safetensors=True,
)
model.to(device)

# Enable static cache and compile the forward pass
model.generation_config.cache_implementation = "static"
model.generation_config.max_new_tokens = 256
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

processor = AutoProcessor.from_pretrained(model_id)
PIPE = pipeline(
"automatic-speech-recognition",
@@ -52,46 +59,41 @@


def transcribe_transformers(calljson, audiofile):
"""Transcribes audio file using transformers library.
"""
Transcribes audio from the given file using transformers.
Args:
calljson (dict): A dictionary containing the JSON data.
audiofile (str): The path to the audio file.
calljson (dict): JSON data containing call information.
audiofile (str): Path to the audio file.
Returns:
dict: The updated calljson dictionary with the transcript.
Explanation:
This function transcribes the audio file using the transformers library. It loads a pre-trained model
and processor, creates a pipeline for automatic speech recognition, and processes the audio file.
The resulting transcript is added to the calljson dictionary and returned.
dict: Updated calljson with transcribed text.
"""

audiofile = str(audiofile)

# Set the return argument to english
# result = PIPE(audiofile, generate_kwargs={"language": "english"})
result = PIPE(audiofile, generate_kwargs={"language": "english", "return_timestamps": True})
# Set the return argument to english & return timestamps to support
# calls over 30 seconds.
with sdpa_kernel(SDPBackend.MATH):
result = PIPE(
audiofile,
generate_kwargs={"language": "english", "return_timestamps": True},
)
calljson["text"] = result["text"]
return calljson


def send_notifications(calljson, audiofile, destinations):
"""
Sends notifications using the provided calljson, audiofile, and destinations.
Sends notifications based on call information.
Args:
calljson (dict): The JSON object containing call information.
audiofile (str): The path to the audio file.
destinations (dict): A dictionary mapping short names and talkgroups to notification URLs.
Raises:
None
calljson (dict): JSON data containing call information.
audiofile (str): Path to the audio file.
destinations (dict): Dictionary mapping short names to talkgroup URLs.
Returns:
None
Examples:
send_notifications(calljson, audiofile, destinations)
"""

# Run ai text through profanity filter
@@ -124,21 +126,18 @@ def send_notifications(calljson, audiofile, destinations):

def audio_notification(audiofile, apobj, body, title):
"""
Encode audio file to AAC format and send a notification with the audio attachment.
Notifies with audio attachment if possible, else with text only.
Args:
audiofile (str): Path to the input audio file.
apobj: Object used to send notifications.
audiofile (str): Path to the audio file.
apobj: Apprise object for notifications.
body (str): Body of the notification.
title (str): Title of the notification.
Returns:
None
Raises:
subprocess.CalledProcessError: If ffmpeg encoding fails.
subprocess.TimeoutExpired: If ffmpeg encoding exceeds 30 seconds.
"""

# Try and except to handle ffmpeg encoding failures
# If it fails, just upload the text and skip the audio attachment
try:

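Taken together, the ttt.py changes follow the pattern sketched below. This is an approximate, self-contained reconstruction for reference, not the script itself: the model id is hard-coded instead of read from TTT_TRANSFORMERS_MODEL_ID, and "call.wav" stands in for the script's real call-handling code.

# Approximate reconstruction of the pattern this commit applies in ttt.py
# (assumptions: torch >= 2.3 for torch.nn.attention.sdpa_kernel, a recent
# transformers release; "call.wav" is a placeholder audio path).
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

torch.set_float32_matmul_precision("high")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3-turbo"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model.to(device)

# Static KV cache plus a compiled forward pass is what the commit message
# credits for the additional transcription speed.
model.generation_config.cache_implementation = "static"
model.generation_config.max_new_tokens = 256
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# The diff wraps generation in the math SDP attention backend; timestamps
# are requested so calls longer than 30 seconds are transcribed in full.
with sdpa_kernel(SDPBackend.MATH):
    result = pipe(
        "call.wav",
        generate_kwargs={"language": "english", "return_timestamps": True},
    )
print(result["text"])

The static cache keeps tensor shapes stable across generation steps, which is what lets the compiled forward pass be reused instead of recompiled on every call.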