Getting time offsets of beginning and end of each word in Wav2Vec2 #11307
Comments
This sounds like a nice feature, but I sadly won't have time to work on it - let's see if someone in the community could be interested :-)
There is something like this which may help: https://github.com/lumaku/espnet/blob/espnet2_ctc_segmentation/espnet2/bin/asr_align.py. I need some help integrating it into Wav2Vec2 in Hugging Face.
@theainerd are you working on this feature?
I would also really like to see this feature. @theainerd I'd be happy to help in any way I can, although I'm not too familiar with the Wav2Vec transformer. @patrickvonplaten do you think you could write out a brief outline of what you think the required steps would be?
Hi there! I'm very new to collaborating on open-source projects as well as to using huggingface/transformers in general, so I'm not confident I can come up with a solution for this issue. However, I did some poking around with tutorials surrounding Wav2Vec2 and was thinking about how this might be implemented: it seems like the Wav2Vec2FeatureExtractor does most of the heavy lifting of converting the raw audio array to suitable input values -> these input values are then fed into the model to obtain the logits (the output dimension is observed to drop by a considerable amount here) -> after applying argmax to obtain the IDs, these IDs are then fed back into the Wav2Vec2CTCTokenizer decode/batch_decode function to obtain the transcription. Perhaps the sampling rate should be stored within the Tokenizer class so that during decode it can use this information to determine the timestamp? Or it might be possible to store it within the Wav2Vec2Processor class and have some wrapper functions take care of determining the timestamp and including it during the decode step. A relation between the input values' dimensions and the output logits' dimensions would be needed for this, which I don't have the expertise to figure out at the moment. CC: sources I've been referring to --
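For readers who want to follow along, here is a minimal sketch of the pipeline described above (feature extractor -> logits -> argmax -> decode), with the input-to-output relation tacked on at the end. The model name and audio file are placeholder assumptions for illustration, not something prescribed in the thread:

```python
import torch
import soundfile as sf
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Placeholder model and audio file (assumptions, for illustration only).
model_name = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

speech, sampling_rate = sf.read("example.wav")  # expected to be 16 kHz mono
inputs = processor(speech, sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits  # shape: (1, num_frames, vocab_size)

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)[0]

# Each logit frame covers input_samples / num_frames audio samples;
# this is the input-to-output relation mentioned in the comment above.
seconds_per_frame = inputs.input_values.shape[1] / logits.shape[1] / sampling_rate
print(transcription, seconds_per_frame)
```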
+1 on this, I'd really appreciate timestamped words as well. Datasets like TIMIT, etc. seem to have this info, but I guess that's part of their test set, not an output from the model itself.
Here's what I've found so far; this was for a 30s audio file.
Thanks for investigating this. While I think it may be possible to just use the ratio and sampling rate to derive the timestamp, what I'm afraid of is that this ratio might just be a "magic number" and might differ if there are variations in the configuration of the Wav2Vec2 model. The current ratio from the input_values size to the logits size seems to be around 320, e.g.: Is this ratio constant for varying sizes of audio? (Experiment with different-sized WAV clips and check the ratio.)
Maybe @patrickvonplaten could shed some light on whether we are going in the right direction with this (if it's not too much trouble) 😓 🙏
Hey @yushao2, what ratio are you referring to here? Sorry, not too familiar with audio processing.
@patrickvonplaten @yushao2 following up on this
Hi there! Sorry for not being responsive here. The ratio here refers to the number you get when you divide the size of the input values by the size of the logits. In this case, you mentioned an input of 480000 samples producing 1499 logit frames, so the ratio would be 480000/1499, which is approximately 320.
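To make the arithmetic concrete, here is a short worked sketch; the 480000-sample and 1499-frame figures are the ones quoted above, and the 16 kHz sampling rate is the usual Wav2Vec2 assumption:

```python
input_samples = 480_000   # ~30 s of audio at 16 kHz
logit_frames = 1_499      # number of output frames reported above
sampling_rate = 16_000

ratio = input_samples / logit_frames        # ~320 input samples per logit frame
seconds_per_frame = ratio / sampling_rate   # ~0.02 s, i.e. roughly 20 ms per frame

def frame_to_seconds(frame_index: int) -> float:
    """Convert a logit frame index to an approximate timestamp in seconds."""
    return frame_index * seconds_per_frame
```

The ratio of roughly 320 corresponds to the cumulative stride of the convolutional feature encoder in the standard Wav2Vec2 configuration, so it should stay (approximately) constant across clip lengths; it is still worth re-checking for models with a modified feature extractor.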
Hello all, there is something I have found which may serve as a good starting point. Basically, this returns the time offsets as well as the textual data: https://github.com/lumaku/ctc-segmentation

```python
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import re
from ctc_segmentation import ctc_segmentation
from ctc_segmentation import CtcSegmentationParameters
from ctc_segmentation import determine_utterance_segments
from ctc_segmentation import prepare_text
# Get the Wav2Vec2 model and the predicted text
test_dataset = load_dataset("common_voice", "hi", split="test")
wer = load_metric("wer")
processor = Wav2Vec2Processor.from_pretrained("theainerd/Wav2Vec2-large-xlsr-hindi")
model = Wav2Vec2ForCTC.from_pretrained("theainerd/Wav2Vec2-large-xlsr-hindi")
model.to("cuda")
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“]'
resampler = torchaudio.transforms.Resample(48_000, 16_000)
# Preprocessing the datasets.
# We need to read the audio files as arrays
def speech_file_to_array_fn(batch):
batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
input_values = processor(test_dataset["speech"][0], return_tensors="pt").input_values # Batch size 1
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])
softmax = torch.nn.Softmax(dim=-1)
# apply configuration
config = CtcSegmentationParameters()
with torch.no_grad():
    # Apply softmax to the logits to obtain per-character probabilities
    lpz = softmax(logits)[0].cpu().numpy()
char_dict = {"न": 0, "च": 1, "थ": 2, "ी": 3, "ऐ": 4, "ृ": 5, "ध": 6, "य": 7, "ह": 8, "ऊ": 9, "म": 10, "ण": 11, "ै": 13, "ौ": 14, "ा": 15, "ल": 16, "त": 17, "इ": 18, "ढ़": 19, "ष": 20, "भ": 21, "ग़": 22, "ख": 23, "ड़": 24, "ए": 25, "व": 26, "ु": 27, "ओ": 28, "र": 29, "श": 30, "औ": 31, "ट": 32, "आ": 33, "ो": 34, "ढ": 35, "झ": 36, "ग": 37, "ज़": 38, "अ": 39, "े": 40, "प": 41, "घ": 42, "द": 43, "ई": 44, "फ़": 45, "ब": 46, "ड": 47, "ँ": 48, "छ": 49, "ू": 50, "फ": 51, "ि": 52, "स": 53, "्": 54, "क": 55, "उ": 56, "ठ": 57, "ं": 58, "़": 59, "ज": 60, "क़": 61, "|": 12, "[UNK]": 62, "[PAD]": 63}
char_list = list(char_dict.keys())
# Prepare the text for aligning
ground_truth_mat, utt_begin_indices = prepare_text(config, transcription, char_list)
# Align using CTC segmentation
timings, char_probs, state_list = ctc_segmentation(config, lpz, ground_truth_mat)
# Obtain list of utterances with time intervals and confidence score
segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, transcription)
# Sample output: 0.26 1.73 -0.0154 THE SALE OF THE HOTELS (an example picked up from ctc_segmentation)
```

Now I have the time offsets, but how do I get them for each and every word rather than for the segments? Please don't take this as an absolute solution, as I am not sure this is a good direction to go, but something is better than nothing. Please share your thoughts.
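One possible way to go from segments to words with the snippet above is to treat every word of the transcription as its own utterance. This is only a rough sketch under that assumption; it reuses `config`, `lpz`, `transcription` and `char_list` from above, and assumes `prepare_text` accepts a list of utterance strings:

```python
# Rough sketch: align each word of the transcription as a separate "utterance"
# so that determine_utterance_segments returns one (start, end, score) per word.
words = transcription.split()

ground_truth_mat, utt_begin_indices = prepare_text(config, words, char_list)
timings, char_probs, state_list = ctc_segmentation(config, lpz, ground_truth_mat)
word_segments = determine_utterance_segments(
    config, utt_begin_indices, char_probs, timings, words
)

# Depending on the ctc_segmentation version, the frame duration in `config`
# may need to be set so the timings come out in seconds rather than frames.
for word, (start, end, score) in zip(words, word_segments):
    print(f"{start:.2f}\t{end:.2f}\t{score:.4f}\t{word}")
```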
Hi everyone, here is some sample code which I have created to get the word-level start and end timestamps.

```python
from itertools import groupby
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf
##############
# load model & audio and run audio through model
##############
model_name = 'facebook/wav2vec2-large-960h-lv60-self'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).cuda()
audio_filepath = ''
speech, sample_rate = sf.read(audio_filepath)
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.cuda()
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()
##############
# this is where the logic starts to get the start and end timestamp for each word
##############
words = [w for w in transcription.split(' ') if len(w) > 0]
predicted_ids = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate
ids_w_time = [(i / len(predicted_ids) * duration_sec, _id) for i, _id in enumerate(predicted_ids)]
# remove entries which are just "padding" (i.e. no characters are recognized)
ids_w_time = [i for i in ids_w_time if i[1] != processor.tokenizer.pad_token_id]
# now split the ids into groups of ids where each group represents a word
split_ids_w_time = [list(group) for k, group
                    in groupby(ids_w_time, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
                    if not k]
assert len(split_ids_w_time) == len(words) # make sure that there are the same number of id-groups as words. Otherwise something is wrong
word_start_times = []
word_end_times = []
for cur_ids_w_time, cur_word in zip(split_ids_w_time, words):
    _times = [_time for _time, _id in cur_ids_w_time]
    word_start_times.append(min(_times))
    word_end_times.append(max(_times))
words, word_start_times, word_end_times
```
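A small follow-up, not part of the original snippet, to print the result in a readable form:

```python
# One line per word with its approximate start/end time in seconds.
for word, start, end in zip(words, word_start_times, word_end_times):
    print(f"{start:6.2f}s  {end:6.2f}s  {word}")
```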
@KB-g
@KB-g

```python
from itertools import groupby
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf
model_name = 'DewiBrynJones/wav2vec2-large-xlsr-welsh'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
audio_filepath = '/tmp/assert.wav'
speech, sample_rate = sf.read(audio_filepath)
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()
##############
# this is where the logic starts to get the start and end timestamp for each word
##############
words = [w for w in transcription.split(' ') if len(w) > 0]
predicted_ids = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate
ids_w_time = [(i / len(predicted_ids) * duration_sec, _id) for i, _id in enumerate(predicted_ids)]
ids_w_time = [i for i in ids_w_time if i[1] != processor.tokenizer.pad_token_id]
split_ids_w_time = [list(group) for k, group
                    in groupby(ids_w_time, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
                    if not k]
# make sure that there are the same number of id-groups as words. Otherwise something is wrong
assert len(split_ids_w_time) == len(words), (len(split_ids_w_time), len(words))
```
Hi @doublex, @abhirooptalasila,
Hi @KB-g, @doublex and @abhirooptalasila, maybe this tutorial helps to find out a way to calculate a "per-word probability". In the function
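For anyone looking for a rough per-word probability without going through ctc_segmentation, one crude option (an assumption of mine, not something taken from the tutorial above) is to average the softmax probability of the predicted token over the frames that belong to each word, reusing `logits` and `processor` from the earlier snippets:

```python
import torch

# Crude sketch: per-word "confidence" = mean softmax probability of the
# argmax token over the non-padding frames between word delimiters.
probs = torch.softmax(logits, dim=-1)[0]       # (num_frames, vocab_size)
frame_ids = probs.argmax(dim=-1).tolist()      # predicted token id per frame
frame_conf = probs.max(dim=-1).values.tolist() # probability of that token per frame

pad_id = processor.tokenizer.pad_token_id
delim_id = processor.tokenizer.word_delimiter_token_id

word_probs, current = [], []
for token_id, conf in zip(frame_ids, frame_conf):
    if token_id == delim_id:
        if current:  # a word has just ended
            word_probs.append(sum(current) / len(current))
            current = []
    elif token_id != pad_id:
        current.append(conf)
if current:  # last word
    word_probs.append(sum(current) / len(current))
```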
We need to document the timestamp retrieval a bit better here, I think.
@KB-g Thanks for the code snippet, really useful. Made a small addition (no_grad) for inference; it should help people facing OOM error(s):

```python
from itertools import groupby
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf
##############
# load model & audio and run audio through model
##############
model_name = 'facebook/wav2vec2-large-960h-lv60-self'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).cuda()
audio_filepath = ''
speech, sample_rate = sf.read(audio_filepath)
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.cuda()
with torch.no_grad():
    logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()
##############
# this is where the logic starts to get the start and end timestamp for each word
##############
words = [w for w in transcription.split(' ') if len(w) > 0]
predicted_ids = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate
ids_w_time = [(i / len(predicted_ids) * duration_sec, _id) for i, _id in enumerate(predicted_ids)]
# remove entries which are just "padding" (i.e. no characters are recognized)
ids_w_time = [i for i in ids_w_time if i[1] != processor.tokenizer.pad_token_id]
# now split the ids into groups of ids where each group represents a word
split_ids_w_time = [list(group) for k, group
                    in groupby(ids_w_time, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
                    if not k]
assert len(split_ids_w_time) == len(words) # make sure that there are the same number of id-groups as words. Otherwise something is wrong
word_start_times = []
word_end_times = []
for cur_ids_w_time, cur_word in zip(split_ids_w_time, words):
    _times = [_time for _time, _id in cur_ids_w_time]
    word_start_times.append(min(_times))
    word_end_times.append(max(_times))
words, word_start_times, word_end_times
```
@Ap1075, thank you for the example you provided above. I'm having a hard time figuring out where/how to pass in transcribed text so it can be aligned with the audio. Is passing in pre-transcribed text possible, or am I misunderstanding how it works?
I'm trying to get word timings for karaoke and I have the lyrics... Would this be possible? 🤔
Hi there @patrickvonplaten, I'd like to take a look at this issue and see if I can help fix it. Please let me know if it's already assigned to someone or if there's anything specific I should keep in mind while working on it. Thanks,
🚀 Feature request
Hello, I was thinking it would be of great help if I could get the time offsets of the start and end of each word.
Motivation
I was going through the Google Speech-to-Text documentation, found this feature, and thought it would be really amazing if I could have something similar here.
Your contribution
I could really use some help with this task and would love to implement something similar.