Getting time offsets of beginning and end of each word in Wav2Vec2 #11307

Open

theainerd opened this issue Apr 19, 2021 · 25 comments
Labels
Good First Issue · Good Second Issue (issues that are more difficult than "Good First" issues - give it a try if you want!)

Comments

@theainerd
Contributor

🚀 Feature request

Hello, I was thinking it would be of great help if I could get the time offsets of the start and end of each word.

Motivation

I was going through the Google Speech-to-Text documentation and found this feature, and thought it would be really amazing to have something similar here.

Your contribution

I could really use some help with this task and would love to implement something similar.


@patrickvonplaten added the Good First Issue and Good Second Issue labels on Apr 23, 2021
@patrickvonplaten
Contributor

This sounds like a nice feature, but I sadly won't have time to work on it - let's see if someone in the community is interested :-)

@theainerd
Contributor Author

There is something like this which may help: https://github.com/lumaku/espnet/blob/espnet2_ctc_segmentation/espnet2/bin/asr_align.py

I need some help integrating it into wav2vec2 in Hugging Face.

@Muktan
Contributor

Muktan commented Apr 28, 2021

@theainerd are you working on this feature?

@MerryOscar

I would also really like to see this feature.

@theainerd I'd be happy to help in any way I can, although I'm not too familiar with the Wav2Vec transformer.

@patrickvonplaten do you think you could write out a brief outline of the steps you think would be required?

@yushao2

yushao2 commented May 3, 2021

Hi there!

I'm very new to collaborating on open-source projects, as well as to using huggingface/transformers in general, so I'm not confident I can come up with a solution for this issue -- however, I did some poking around in tutorials on Wav2Vec2 and thought about how this might be implemented:

It seems like the Wav2Vec2FeatureExtractor does most of the heavy lifting of converting the raw audio array to suitable input values

-> These input values are then fed into the model to obtain the logits (the output dimension is observed to drop considerably here)

-> After applying argmax to obtain the IDs, these IDs are fed back into the Wav2Vec2CTCTokenizer decode/batch_decode function to obtain the transcription.

Perhaps the sampling rate should be stored within the Tokenizer class, so that during decode it can use this information to determine the timestamps? Or it might be possible to store it within the Wav2Vec2Processor class and have some wrapper functions take care of determining the timestamps and including them in the decode step.

A relation between the input_values dimensions and the output logits' dimensions would be needed for this, which I don't have the expertise to figure out at the moment.

CC:
@theainerd
@MerryOscar
@patrickvonplaten

sources I've been referring to --
https://www.kdnuggets.com/2021/03/speech-text-wav2vec.html (I realise this is outdated and uses the old tokenizer class, which seems to perform feature extraction as well)

https://huggingface.co/blog/fine-tune-wav2vec2-english
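
For concreteness, here is a minimal sketch of the pipeline described above, assuming `speech` is a 16 kHz mono numpy array; the checkpoint name is only an example:

```python
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# raw audio -> input_values -> logits -> argmax ids -> transcription
input_values = processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
with torch.no_grad():
    logits = model(input_values).logits          # (batch, frames, vocab)

predicted_ids = torch.argmax(logits, dim=-1)     # one token id per logit frame
transcription = processor.batch_decode(predicted_ids)[0]
```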

@krrishdholakia

+1 on this, I'd really appreciate timestamped words as well. Datasets like TIMIT seem to have this info, but I guess that's part of their test set, not an output from the model itself.

@krrishdholakia

Here's what I've found so far:
if the speech length is 480,000
input_values length is 480,000
logits length is 1499

This was for a 30 s audio file.

# model is a Wav2Vec2ForCTC instance, processor is a Wav2Vec2Processor instance
input_values = processor(speech, return_tensors="pt").input_values
logits = model(input_values).logits

@yushao2

yushao2 commented May 5, 2021

> if the speech length is 480,000, the input_values length is 480,000 and the logits length is 1499 ... this was for a 30 s audio file.

Thanks for investigating this -- while I think it may be possible to just use the ratio and the sampling rate to derive the timestamps, my worry is that this ratio might be a "magic number" that differs across variations of the Wav2Vec2 model configuration.

The current ratio of input_values size to logits size seems to be around 320.

For example:
Does the ratio change if the hyperparameters of the model are changed?

Is the ratio constant for varying sizes of audio? (Experiment with different-sized WAV clips and check the ratio.)

Maybe @patrickvonplaten could shed some light on whether we are going in the right direction with this (if it's not too much trouble) 😓 🙏

@krrishdholakia

> Current ratio from input_values size to logits seems to be around 320 [...]

Hey @yushao2, what ratio are you referring to here? Sorry, not too familiar with audio processing.

@krrishdholakia

@patrickvonplaten @yushao2 following up on this

@yushao2

yushao2 commented May 14, 2021

> @patrickvonplaten @yushao2 following up on this

Hi there! Sorry for not being responsive here.

The ratio here refers to the number you get when you divide the size of input_values by the size of logits.

In this case, you mentioned:

input_values length - 480,000
logits length - 1499

so the ratio would be 480000/1499, which is approximately 320.
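
For what it's worth, this factor does not have to be treated as a magic number: it should equal the product of the strides of the convolutional feature encoder, which can be read from the model config (320 for the standard published configs, i.e. roughly 20 ms of audio per logit frame at 16 kHz). A minimal sketch, with the checkpoint name only as an example:

```python
import numpy as np
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Total downsampling of the convolutional feature encoder:
# the product of the per-layer strides.
downsampling = int(np.prod(model.config.conv_stride))   # 320 for the standard configs

# Duration of one logit frame, assuming 16 kHz input audio.
seconds_per_frame = downsampling / 16_000               # 0.02 s
print(downsampling, seconds_per_frame)
```

Note that the exact frame count can still be off by a frame or so relative to `len(input_values) / 320` because of the convolution kernel sizes (480,000 / 320 = 1500 vs. the observed 1499).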

@theainerd
Contributor Author

Hello all,

I have found something that may serve as a good starting point. It returns the time offsets together with the textual data:

https://github.com/lumaku/ctc-segmentation

import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import re

from ctc_segmentation import ctc_segmentation
from ctc_segmentation import CtcSegmentationParameters
from ctc_segmentation import determine_utterance_segments
from ctc_segmentation import prepare_text

# Get the Wav2Vec2 model and the predicted text
test_dataset = load_dataset("common_voice", "hi", split="test")
wer = load_metric("wer")

processor = Wav2Vec2Processor.from_pretrained("theainerd/Wav2Vec2-large-xlsr-hindi")
model = Wav2Vec2ForCTC.from_pretrained("theainerd/Wav2Vec2-large-xlsr-hindi")
model.to("cuda")

chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\“]'

resampler = torchaudio.transforms.Resample(48_000, 16_000)

# Preprocessing the datasets.
# We need to read the audio files as arrays
def speech_file_to_array_fn(batch):
  batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
  speech_array, sampling_rate = torchaudio.load(batch["path"])
  batch["speech"] = resampler(speech_array).squeeze().numpy()
  return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)

input_values = processor(test_dataset["speech"][0], return_tensors="pt").input_values  # Batch size 1
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])

softmax = torch.nn.Softmax(dim = -1)

# apply configuration
config = CtcSegmentationParameters()

with torch.no_grad():
    # Apply ctc layer to obtain log character probabilities
    lpz = softmax(logits)[0].cpu().numpy()

char_dict = {"न": 0, "च": 1, "थ": 2, "ी": 3, "ऐ": 4, "ृ": 5, "ध": 6, "य": 7, "ह": 8, "ऊ": 9, "म": 10, "ण": 11, "ै": 13, "ौ": 14, "ा": 15, "ल": 16, "त": 17, "इ": 18, "ढ़": 19, "ष": 20, "भ": 21, "ग़": 22, "ख": 23, "ड़": 24, "ए": 25, "व": 26, "ु": 27, "ओ": 28, "र": 29, "श": 30, "औ": 31, "ट": 32, "आ": 33, "ो": 34, "ढ": 35, "झ": 36, "ग": 37, "ज़": 38, "अ": 39, "े": 40, "प": 41, "घ": 42, "द": 43, "ई": 44, "फ़": 45, "ब": 46, "ड": 47, "ँ": 48, "छ": 49, "ू": 50, "फ": 51, "ि": 52, "स": 53, "्": 54, "क": 55, "उ": 56, "ठ": 57, "ं": 58, "़": 59, "ज": 60, "क़": 61, "|": 12, "[UNK]": 62, "[PAD]": 63}
char_list = list(char_dict.keys())

# Prepare the text for aligning
ground_truth_mat, utt_begin_indices = prepare_text(config, transcription, char_list)
# Align using CTC segmentation
timings, char_probs, state_list = ctc_segmentation(config, lpz, ground_truth_mat)

# Obtain list of utterances with time intervals and confidence score
segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, transcription)
# Sample output: 0.26 1.73 -0.0154 THE SALE OF THE HOTELS  (an example picked up from the ctc_segmentation repo)

Now I have the time offsets, but how do I get them for each and every word rather than per segment? Please don't take this as an absolute solution, as I am not sure this is a good direction to go, but something is better than nothing. Please share your thoughts.
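
One possible way to get word-level rather than segment-level offsets with the same library, if I understand its API correctly, is to pass the transcription split into words as the text to align, so that one segment comes back per word. A hedged sketch reusing the variables from the snippet above (config, lpz, char_list, transcription); depending on the ctc_segmentation version, config.index_duration may also need to be set to the duration of one CTC frame in seconds:

```python
# Hypothetical continuation of the snippet above: treat each word as its own
# "utterance" so that one (start, end, score) triple is returned per word.
words = transcription.split()

ground_truth_mat, utt_begin_indices = prepare_text(config, words, char_list)
timings, char_probs, state_list = ctc_segmentation(config, lpz, ground_truth_mat)
segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, words)

for word, (start, end, score) in zip(words, segments):
    print(f"{start:.2f}s  {end:.2f}s  {score:.4f}  {word}")
```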

@huggingface deleted three comments on Jun 7, 2021
@KB-g

KB-g commented Jun 24, 2021

Hi everyone, here is some sample code I created to get the word-level start and end timestamps.
It's surely a bit hacky, and I can imagine special cases where it might break, but it has worked well for the cases I have tried.

from itertools import groupby
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf

##############
# load model & audio and run audio through model
##############
model_name = 'facebook/wav2vec2-large-960h-lv60-self'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).cuda()


audio_filepath = ''
speech, sample_rate = sf.read(audio_filepath)
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.cuda()

logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()

##############
# this is where the logic starts to get the start and end timestamp for each word
##############
words = [w for w in transcription.split(' ') if len(w) > 0]
predicted_ids = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate


ids_w_time = [(i / len(predicted_ids) * duration_sec, _id) for i, _id in enumerate(predicted_ids)]
# remove entries which are just "padding" (i.e. no characters are recognized)
ids_w_time = [i for i in ids_w_time if i[1] != processor.tokenizer.pad_token_id]
# now split the ids into groups of ids where each group represents a word
split_ids_w_time = [list(group) for k, group
                    in groupby(ids_w_time, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
                    if not k]

assert len(split_ids_w_time) == len(words)  # make sure that there are the same number of id-groups as words. Otherwise something is wrong

word_start_times = []
word_end_times = []
for cur_ids_w_time, cur_word in zip(split_ids_w_time, words):
    _times = [_time for _time, _id in cur_ids_w_time]
    word_start_times.append(min(_times))
    word_end_times.append(max(_times))
    
words, word_start_times, word_end_times
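
A small usage note: the last line just evaluates the three lists (handy in a notebook); to print them side by side in a script, something like this works:

```python
for word, start, end in zip(words, word_start_times, word_end_times):
    print(f"{start:6.2f}s  {end:6.2f}s  {word}")
```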

@doublex

doublex commented Jun 24, 2021

@KB-g
Congrats!
Is there a chance to also extract the "per word probability"?

@doublex

doublex commented Jun 27, 2021

@KB-g
The `assert len(split_ids_w_time) == len(words)` fails with this audio: assert.zip
Test case:

from itertools import groupby
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf

model_name = 'DewiBrynJones/wav2vec2-large-xlsr-welsh'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

audio_filepath = '/tmp/assert.wav'
speech, sample_rate = sf.read(audio_filepath)
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()

##############
# this is where the logic starts to get the start and end timestamp for each word
##############
words = [w for w in transcription.split(' ') if len(w) > 0]
predicted_ids = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate
ids_w_time = [(i / len(predicted_ids) * duration_sec, _id) for i, _id in enumerate(predicted_ids)]
ids_w_time = [i for i in ids_w_time if i[1] != processor.tokenizer.pad_token_id]
split_ids_w_time = [list(group) for k, group
                    in groupby(ids_w_time, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
                    if not k]
# make sure that there are the same number of id-groups as words. Otherwise something is wrong
assert len(split_ids_w_time) == len(words), (len(split_ids_w_time), len(words))

@abhirooptalasila

> @KB-g Congrats! Is there a chance to also extract the "per word probability"?

Hey @KB-g
Any success on this?

@KB-g

KB-g commented Jan 17, 2022

Hi @doublex , @abhirooptalasila,
I haven't tried to get the per-word probability. If you come up with a solution, it would be great if you could let me know. I'd also be interested in a solution :)

@jcsilva

jcsilva commented Jan 25, 2022

Hi @KB-g, @doublex and @abhirooptalasila,

maybe this tutorial helps to find a way to calculate a "per-word probability". In the function merge_words, the author calculates a score for each word based on the token probabilities and their durations.
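
Building on that, here is a rough, hedged sketch of one per-word confidence heuristic (the mean softmax probability of the predicted token over the frames that make up the word, not the tutorial's exact scoring). It assumes logits, predicted_ids (the flat id list) and processor from KB-g's snippet above are in scope:

```python
from itertools import groupby
import torch

# Frame-level confidence of the argmax token (a rough heuristic).
probs = torch.softmax(logits[0], dim=-1)           # (num_frames, vocab_size)
frame_conf = probs.max(dim=-1).values.tolist()     # probability of the predicted id per frame

# Group frames into words exactly as in the timestamp snippet.
frames = [(i, _id) for i, _id in enumerate(predicted_ids)
          if _id != processor.tokenizer.pad_token_id]
word_groups = [list(g) for k, g
               in groupby(frames, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
               if not k]

# Mean frame confidence per word.
word_probs = [sum(frame_conf[i] for i, _ in g) / len(g) for g in word_groups]
```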

@patrickvonplaten
Contributor

We need to document the timestamp retrieval a bit better here, I think.

@Ap1075

Ap1075 commented Apr 24, 2022

@KB-g Thanks for the code snippet, really useful. I made a small addition (torch.no_grad) for inference, which should help people facing OOM errors:

from itertools import groupby
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import soundfile as sf

##############
# load model & audio and run audio through model
##############
model_name = 'facebook/wav2vec2-large-960h-lv60-self'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).cuda()


audio_filepath = ''
speech, sample_rate = sf.read(audio_filepath)
input_values = processor(speech, sampling_rate=sample_rate, return_tensors="pt").input_values.cuda()

with torch.no_grad():
    logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0]).lower()

##############
# this is where the logic starts to get the start and end timestamp for each word
##############
words = [w for w in transcription.split(' ') if len(w) > 0]
predicted_ids = predicted_ids[0].tolist()
duration_sec = input_values.shape[1] / sample_rate


ids_w_time = [(i / len(predicted_ids) * duration_sec, _id) for i, _id in enumerate(predicted_ids)]
# remove entries which are just "padding" (i.e. no characters are recognized)
ids_w_time = [i for i in ids_w_time if i[1] != processor.tokenizer.pad_token_id]
# now split the ids into groups of ids where each group represents a word
split_ids_w_time = [list(group) for k, group
                    in groupby(ids_w_time, lambda x: x[1] == processor.tokenizer.word_delimiter_token_id)
                    if not k]

assert len(split_ids_w_time) == len(words)  # make sure that there are the same number of id-groups as words. Otherwise something is wrong

word_start_times = []
word_end_times = []
for cur_ids_w_time, cur_word in zip(split_ids_w_time, words):
    _times = [_time for _time, _id in cur_ids_w_time]
    word_start_times.append(min(_times))
    word_end_times.append(max(_times))
    
words, word_start_times, word_end_times

@samuelbradshaw

@Ap1075, thank you for the example you provided above. I'm having a hard time figuring out where/how to pass in transcribed text so it can be aligned with the audio. Is passing in pre-transcribed text possible, or am I misunderstanding how it works?

@jmealo

jmealo commented Jun 19, 2023

I'm trying to get word timings for karaoke. I have the lyrics... Would this be possible? 🤔

@hegdeadithyak
Contributor

Hi there @patrickvonplaten ,

I'd like to take a look at this issue and see if I can help fix it. Please let me know if it's already assigned to someone or if there's anything specific I should keep in mind while working on it.

Thanks,
Adithya Hegde Kota
