Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(wav2vec2-xls-r-300m): add initial implementation with model load… #6

Open
wants to merge 1 commit into
base: next
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/benchmark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
- path: wav2vec2-large-xlsr-persian-v3
- path: nemo
- path: whisper-large-v3
- path: facebook-wav2vec2-xls-r-300m

steps:
- name: Checkout repository
Expand Down
1 change: 1 addition & 0 deletions facebook-wav2vec2-xls-r-300m/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m)
34 changes: 34 additions & 0 deletions facebook-wav2vec2-xls-r-300m/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import json

from predict import load_model, predict
from metrics import calculate_wer

# Benchmark fixture: one audio clip and its reference transcript.
audio_path = '../assets/audio-01.wav'
ground_truth_path = '../assets/audio-01.txt'

processor, model = load_model()

# Read the ground truth text. Explicit UTF-8 so the (Persian) reference
# decodes the same way on every platform.
with open(ground_truth_path, 'r', encoding='utf-8') as file:
    ground_truth = file.read()

transcription = predict(processor, model, audio_path)

# Calculate WER before normalizing
wer_score_before = calculate_wer(ground_truth, transcription)

output_data = {
    'transcription': transcription,
    'werBeforeNormalization': wer_score_before,
}

# Serialize once, persist to result.json, and echo the same payload to
# stdout — no need to re-read the file we just wrote.
json_output = json.dumps(output_data, ensure_ascii=False, indent=4)
with open('result.json', 'w', encoding='utf-8') as f:
    f.write(json_output)

print(json_output)
5 changes: 5 additions & 0 deletions facebook-wav2vec2-xls-r-300m/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from jiwer import wer

def calculate_wer(ground_truth, transcription):
    """Return the word error rate of *transcription* against *ground_truth*."""
    return wer(ground_truth, transcription)
36 changes: 36 additions & 0 deletions facebook-wav2vec2-xls-r-300m/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Set the device to GPU if available, otherwise use CPU.
# Module-level so load_model() and predict() place the model and the
# input tensors on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_model():
    """Instantiate the XLS-R processor and CTC model, moving the model to `device`.

    Returns:
        A ``(processor, model)`` tuple.
    """
    # NOTE(review): the base (not fine-tuned) 300m checkpoint ships without a
    # CTC vocabulary/tokenizer — confirm Wav2Vec2Processor.from_pretrained
    # actually succeeds for it, or point at a fine-tuned checkpoint.
    checkpoint = 'facebook/wav2vec2-xls-r-300m'
    processor = Wav2Vec2Processor.from_pretrained(checkpoint)
    model = Wav2Vec2ForCTC.from_pretrained(checkpoint).to(device)
    return processor, model

def predict(processor, model, audio_path):
    """Transcribe an audio file with a wav2vec2 CTC model.

    Args:
        processor: Wav2Vec2Processor used for feature extraction and decoding.
        model: Wav2Vec2ForCTC already placed on `device`.
        audio_path: Path to the audio file to transcribe.

    Returns:
        The decoded transcription string.
    """
    # Load the audio; torchaudio returns a (channels, samples) tensor.
    speech_array, sampling_rate = torchaudio.load(audio_path)

    # Down-mix multi-channel audio to mono. The previous `.squeeze()` alone
    # only handled mono files; stereo input reached the processor as 2-D.
    if speech_array.size(0) > 1:
        speech_array = speech_array.mean(dim=0, keepdim=True)

    # The XLS-R feature extractor expects a fixed sampling rate (16 kHz);
    # resample when the file uses a different rate instead of letting the
    # processor reject it.
    target_rate = processor.feature_extractor.sampling_rate
    if sampling_rate != target_rate:
        speech_array = torchaudio.functional.resample(
            speech_array, sampling_rate, target_rate
        )
        sampling_rate = target_rate

    speech_array = speech_array.squeeze().numpy()

    # Process the audio into padded model inputs (torch tensors).
    features = processor(
        speech_array,
        sampling_rate=sampling_rate,
        return_tensors='pt',
        padding=True
    )

    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    # Greedy CTC decode: argmax over the vocabulary, then collapse via the
    # processor's tokenizer.
    pred_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(pred_ids)[0]

    return transcription
6 changes: 6 additions & 0 deletions facebook-wav2vec2-xls-r-300m/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
jiwer==3.0.5
num2fawords==1.1
parsivar==0.2.3.1
torch==2.5.1
torchaudio==2.5.1
transformers==4.47.0
Loading