
Add speaker-aware transcription #147

Draft · juanmc2005 wants to merge 26 commits into develop
Conversation

juanmc2005 (Owner)

Depends on #144

This PR adds a new SpeakerAwareTranscription pipeline that combines streaming diarization and streaming transcription to determine "who says what" in a live conversation. By default, this is shown as colored words in the terminal.

The feature works as expected with diart.stream and diart.serve/diart.client.
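For example (same command as the report further down this thread): `diart.stream output.wav --pipeline SpeakerAwareTranscription`.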
The main thing preventing full compatibility with diart.benchmark and diart.tune is the evaluation metric.
Since the pipeline outputs annotated text in the format `[speaker0]Hello [speaker1]Hi`, the existing `diart.metrics.WordErrorRate` metric counts the speaker labels as insertion errors.

Next steps: implement a SpeakerWordErrorRate that computes the (weighted?) average WER across speakers.
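To illustrate the idea (a minimal sketch only, not code from this PR; `parse_annotated`, `word_error_rate`, and `speaker_wer` are hypothetical names, and alignment of speaker labels between reference and hypothesis is ignored): parse the `[speakerX]` tags out of both texts, compute WER per speaker, then average, optionally weighting by each speaker's reference word count.

```python
import re
from collections import defaultdict


def parse_annotated(text: str) -> dict:
    """Split '[speaker0]Hello [speaker1]Hi' into {'speaker0': ['hello'], 'speaker1': ['hi']}."""
    words = defaultdict(list)
    speaker = None
    for token in text.split():
        match = re.match(r"\[(.+?)\](.*)", token)
        if match:
            speaker, token = match.group(1), match.group(2)
        if speaker is not None and token:
            words[speaker].append(token.lower())
    return dict(words)


def word_error_rate(reference: list, hypothesis: list) -> float:
    """Plain WER: (substitutions + insertions + deletions) / reference length, via edit distance."""
    d = [[0] * (len(hypothesis) + 1) for _ in range(len(reference) + 1)]
    for i in range(len(reference) + 1):
        d[i][0] = i
    for j in range(len(hypothesis) + 1):
        d[0][j] = j
    for i in range(1, len(reference) + 1):
        for j in range(1, len(hypothesis) + 1):
            cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[-1][-1] / max(len(reference), 1)


def speaker_wer(reference: str, hypothesis: str, weighted: bool = True) -> float:
    """Average WER across reference speakers, optionally weighted by reference word counts."""
    ref, hyp = parse_annotated(reference), parse_annotated(hypothesis)
    errors = {spk: word_error_rate(words, hyp.get(spk, [])) for spk, words in ref.items()}
    if weighted:
        total = sum(len(words) for words in ref.values())
        return sum(errors[spk] * len(ref[spk]) for spk in ref) / max(total, 1)
    return sum(errors.values()) / max(len(errors), 1)
```

For instance, `speaker_wer("[speaker0]hello there [speaker1]hi", "[speaker0]hello [speaker1]hi")` gives 1/3: speaker0 has one deletion out of two reference words, speaker1 is perfect, and the weighted average is (0.5·2 + 0·1)/3.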

Changelog

TBD

juanmc2005 added the feature label on Apr 26, 2023
juanmc2005 added this to the Version 0.8 milestone on Apr 26, 2023
juanmc2005 modified the milestones: Version 0.8 → Version 0.9 on Oct 11, 2023
C0RE1312 commented on Apr 6, 2024

Hey, I am unable to use this:
```
(diart) :~/live-transcript$ diart.stream output.wav --pipeline SpeakerAwareTranscription
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/diart/lib/python3.8/site-packages/rx/core/operators/map.py", line 37, in on_next
    result = _mapper(value)
  File "/home/user/anaconda3/envs/diart/lib/python3.8/site-packages/diart/pipelines/speaker_transcription.py", line 325, in __call__
    asr_outputs = self.asr(batch[has_voice])
  File "/home/user/anaconda3/envs/diart/lib/python3.8/site-packages/diart/blocks/asr.py", line 65, in __call__
    output = self.model(wave.to(self.device))
  File "/home/user/anaconda3/envs/diart/lib/python3.8/site-packages/diart/models.py", line 80, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/user/anaconda3/envs/diart/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/anaconda3/envs/diart/lib/python3.8/site-packages/diart/models.py", line 485, in forward
    batch = whisper.log_mel_spectrogram(batch)
  File "/home/user/anaconda3/envs/diart/lib/python3.8/site-packages/whisper/audio.py", line 148, in log_mel_spectrogram
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
  File "/home/user/anaconda3/envs/diart/lib/python3.8/site-packages/torch/functional.py", line 632, in stft
    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
RuntimeError: cuFFT error: CUFFT_INTERNAL_ERROR
```

juanmc2005 (Owner, Author)

@C0RE1312 sounds like a problem with PyTorch not being able to compute the FFT. Have you tried updating the dependencies of both torch and whisper? It's a pretty old PR.
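One way to narrow this down (just a suggestion, not something from this PR) is to run `torch.stft` on the GPU in isolation; if the minimal snippet below also raises `CUFFT_INTERNAL_ERROR`, the problem is in the torch/CUDA installation rather than in diart or whisper.

```python
import torch

# Run an STFT on the GPU with Whisper-like parameters (n_fft=400, hop_length=160)
# to check whether cuFFT works at all in this environment.
audio = torch.randn(16000, device="cuda")
window = torch.hann_window(400, device="cuda")
spec = torch.stft(audio, n_fft=400, hop_length=160, window=window, return_complex=True)
print(spec.shape)
```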
