-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(whisper-large-v3): add basic usage (#3)
* feat(whisper-v3): start from scratch with basic of structure to exists in hugging website * refactor: update benchmark workflow and enhance transcription output with JSON result * chore: update requirement.txt * chore: update benchmark workflow to install ffmpeg and upgrade pip, and modify dependencies in requirements.txt * chore: add accelerate dependency to requirements.txt * chore: restart to base * feat(whisper-model-v3): start whisper model in external directory * feat(whisper-model-v3): adding path of wav file and update github workflow * fix(whisper-model-v3): change pip version * chore(whisper-model-v3): update dependencies * fix(whisper-model-v3): add audio trimming functionality and use trimmed audio in pipeline for getting just 28 second of wav file * chore(whisper-model-v3): update dependencies * feat(whisper-model-v3): add JSON output functionality for pipeline results * fix(whisper-model-v3): use trimmed audio file in pipeline for output * fix(benchmark): add missing package installation in workflow * fix(benchmark): simplify ffmpeg installation in workflow * fix(devcontainer): remove unnecessary package installation in Dockerfile and return to base * fix(main): add json import for enhanced data handling * fix(whisper-v3): update JSON output handling to improve readability * fix(whisper-v3): enhance audio processing with error handling and path management * refactor(benchmark): add workflow for ASR model benchmarking with whisper-v3 * fix(whisper-v3): enhance error handling in audio processing and model loading * chore(benchmark): remove obsolete benchmark workflow for whisper-v3 * lint: make happy * refactor(whisper-large-v3): enhance usage * fix(workflow): indent * chore: review * chore(whisper): deps --------- Co-authored-by: S. Amir Mohammad Najafi <njfamirm@gmail.com>
- Loading branch information
Showing
8 changed files
with
211 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import sys | ||
import json | ||
from io import StringIO | ||
from predict import load_model, predict | ||
from metrics import calculate_wer | ||
|
||
# Benchmark driver: transcribe one audio sample, score it against the
# reference transcript with WER, then persist and print the result as JSON.

processor, model = load_model()

# Paths are relative to the current working directory.
audio_path = '../assets/audio-01.wav'
ground_truth_path = '../assets/audio-01.txt'

# Read the ground truth text. The encoding is pinned to UTF-8 so that the
# reference transcript decodes the same way regardless of the platform's
# locale default (the output file below is UTF-8 as well).
with open(ground_truth_path, 'r', encoding='utf-8') as file:
    ground_truth = file.read()

transcription = predict(processor, model, audio_path)

# Calculate WER before normalizing
wer_score_before = calculate_wer(ground_truth, transcription)

# No normalization step is applied in this script yet, so the "normalized"
# fields intentionally mirror the raw transcription and its WER.
output_data = {
    'transcription': transcription,
    'werBeforeNormalization': wer_score_before,
    'normalizedTranscription': transcription,
    'werAfterNormalization': wer_score_before
}

# Serialize once and reuse the same text for both the file and stdout,
# instead of writing result.json and reading it straight back.
json_output = json.dumps(output_data, ensure_ascii=False, indent=4)

# Write JSON output to a file
with open('result.json', 'w', encoding='utf-8') as f:
    f.write(json_output)

print(json_output)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from jiwer import wer | ||
|
||
def calculate_wer(ground_truth, transcription):
    """Return the word error rate of `transcription` against `ground_truth`.

    Both arguments are forwarded unchanged to `jiwer.wer`, reference first
    and hypothesis second; its result is returned as-is.
    """
    return wer(ground_truth, transcription)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
import torchaudio | ||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline | ||
|
||
# Pick the compute device and matching tensor precision together:
# half precision on a CUDA GPU when one is available, full precision on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch_dtype = torch.float16
else:
    device = torch.device('cpu')
    torch_dtype = torch.float32
|
||
def load_model(model_name="openai/whisper-large-v3"):
    """Load a speech seq2seq checkpoint together with its processor.

    Args:
        model_name: Identifier passed to both `from_pretrained` calls.
            Defaults to "openai/whisper-large-v3", so existing callers
            that pass no argument get the same checkpoint as before.

    Returns:
        A `(processor, model)` tuple. The model is loaded with the
        module-level `torch_dtype` and moved to the module-level `device`.
    """
    model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name, torch_dtype=torch_dtype)
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_name)
    return processor, model
|
||
def predict(processor, model, audio_path, language="fa"):
    """Transcribe an audio file with the given model and processor.

    Args:
        processor: Object exposing `tokenizer` and `feature_extractor`
            attributes (as returned by `load_model`).
        model: The speech seq2seq model to run.
        audio_path: Path of the audio file handed to the pipeline.
        language: Language hint forwarded to generation. Defaults to
            "fa", the value that was previously hard-coded.

    Returns:
        The value stored under the "text" key of the pipeline result.
    """
    # Initialize the pipeline
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
        return_timestamps=True,
    )

    # The decoding language is a generation option, not a pipeline
    # constructor option, so it is passed through `generate_kwargs`.
    result = pipe(audio_path, generate_kwargs={"language": language})
    return result["text"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
accelerate==1.2.1 | ||
asttokens==3.0.0 | ||
attrs==24.2.0 | ||
backcall==0.2.0 | ||
beautifulsoup4==4.12.3 | ||
bleach==6.2.0 | ||
certifi==2024.12.14 | ||
charset-normalizer==3.4.0 | ||
click==8.1.7 | ||
decorator==5.1.1 | ||
defusedxml==0.7.1 | ||
docopt==0.6.2 | ||
executing==2.1.0 | ||
fastjsonschema==2.21.1 | ||
filelock==3.16.1 | ||
fsspec==2024.10.0 | ||
gitdb==4.0.11 | ||
GitPython==3.1.41 | ||
huggingface-hub==0.26.5 | ||
idna==3.10 | ||
ipython==8.12.3 | ||
jedi==0.19.2 | ||
Jinja2==3.1.4 | ||
jiwer==3.0.5 | ||
joblib==1.4.2 | ||
jsonschema==4.23.0 | ||
jsonschema-specifications==2024.10.1 | ||
jupyter_client==8.6.3 | ||
jupyter_core==5.7.2 | ||
jupyterlab_pygments==0.3.0 | ||
MarkupSafe==3.0.2 | ||
matplotlib-inline==0.1.7 | ||
mistune==3.0.2 | ||
mpmath==1.3.0 | ||
nbclient==0.10.1 | ||
nbconvert==7.16.4 | ||
nbformat==5.10.4 | ||
networkx==3.4.2 | ||
nltk==3.9.1 | ||
num2fawords==1.1 | ||
numpy==2.2.0 | ||
nvidia-cublas-cu12==12.4.5.8 | ||
nvidia-cuda-cupti-cu12==12.4.127 | ||
nvidia-cuda-nvrtc-cu12==12.4.127 | ||
nvidia-cuda-runtime-cu12==12.4.127 | ||
nvidia-cudnn-cu12==9.1.0.70 | ||
nvidia-cufft-cu12==11.2.1.3 | ||
nvidia-curand-cu12==10.3.5.147 | ||
nvidia-cusolver-cu12==11.6.1.9 | ||
nvidia-cusparse-cu12==12.3.1.170 | ||
nvidia-nccl-cu12==2.21.5 | ||
nvidia-nvjitlink-cu12==12.4.127 | ||
nvidia-nvtx-cu12==12.4.127 | ||
packaging==24.2 | ||
pandocfilters==1.5.1 | ||
parsivar==0.2.3.1 | ||
parso==0.8.4 | ||
pexpect==4.9.0 | ||
pickleshare==0.7.5 | ||
pipreqs==0.5.0 | ||
platformdirs==4.3.6 | ||
prompt_toolkit==3.0.48 | ||
psutil==6.1.0 | ||
ptyprocess==0.7.0 | ||
pure_eval==0.2.3 | ||
Pygments==2.18.0 | ||
python-dateutil==2.9.0.post0 | ||
PyYAML==6.0.2 | ||
pyzmq==26.2.0 | ||
RapidFuzz==3.10.1 | ||
referencing==0.35.1 | ||
regex==2024.11.6 | ||
requests==2.32.3 | ||
rpds-py==0.22.3 | ||
safetensors==0.4.5 | ||
setuptools==69.0.3 | ||
six==1.17.0 | ||
smmap==5.0.1 | ||
soupsieve==2.6 | ||
stack-data==0.6.3 | ||
sympy==1.13.1 | ||
tinycss2==1.4.0 | ||
tokenizers==0.21.0 | ||
torch==2.5.1 | ||
torchaudio==2.5.1 | ||
tornado==6.4.2 | ||
tqdm==4.67.1 | ||
traitlets==5.14.3 | ||
transformers==4.47.0 | ||
triton==3.1.0 | ||
typing_extensions==4.12.2 | ||
urllib3==2.2.3 | ||
wcwidth==0.2.13 | ||
webencodings==0.5.1 | ||
yarg==0.1.9 |