Add speaker-aware transcription #147

Draft · juanmc2005 wants to merge 26 commits into `develop`

Commits
- `bca2873` New feature: streaming voice activity detection. Pipeline name changes (juanmc2005, Apr 19, 2023)
- `5e44ad4` Merge branch 'develop' of github.com:juanmc2005/OnlineDiarization int… (juanmc2005, Apr 19, 2023)
- `7447061` Update link in setup.cfg (juanmc2005, Apr 19, 2023)
- `4985394` Update code snippets in README (juanmc2005, Apr 19, 2023)
- `540ad0a` Add minor README modifications (juanmc2005, Apr 19, 2023)
- `8cc9925` Initial ASR implementation. Broken stuff (juanmc2005, Apr 21, 2023)
- `1ae4934` First working transcription pipeline. Using diarization is possible b… (juanmc2005, Apr 21, 2023)
- `d8d7342` Reduce Whisper VRAM footprint (around 400Mb). Add fp16 option (juanmc2005, Apr 21, 2023)
- `2cfc35d` Change whisper input type based on fp16 parameter (juanmc2005, Apr 21, 2023)
- `a40112c` Implement batched inference for whisper. Re-implement decoding. (juanmc2005, Apr 22, 2023)
- `e8196a7` Minor changes in transcription arguments (juanmc2005, Apr 22, 2023)
- `07dd9ae` Greatly improve transcription pipeline by adding optional VAD (juanmc2005, Apr 23, 2023)
- `0bf2522` Move pipelines to diart.pipelines. Add torchmetrics as a dependency (juanmc2005, Apr 23, 2023)
- `42fe5f7` Add websocket compatibility to transcription pipeline (juanmc2005, Apr 23, 2023)
- `49616e5` Transcription pipeline is now fully compatible with diart.stream (juanmc2005, Apr 23, 2023)
- `babf49d` Make transcription pipeline compatible with diart.benchmark and diart… (juanmc2005, Apr 24, 2023)
- `6609e3c` Rename base pipeline and config objects (juanmc2005, Apr 24, 2023)
- `4c1aeba` Merge changes from branch feat/vad (juanmc2005, Apr 24, 2023)
- `d19b044` New feature: streaming voice activity detection. Pipeline name changes (juanmc2005, Apr 19, 2023)
- `6caa4a4` Update link in setup.cfg (juanmc2005, Apr 19, 2023)
- `0993fe8` Update code snippets in README (juanmc2005, Apr 19, 2023)
- `95d4fae` Add minor README modifications (juanmc2005, Apr 19, 2023)
- `569c68f` Rename base pipeline and config objects (juanmc2005, Apr 24, 2023)
- `eed864f` Update branch with develop (juanmc2005, Apr 24, 2023)
- `a16bb5c` Add initial implementation of SpeakerAwareTranscription (juanmc2005, Apr 24, 2023)
- `c7bbcc4` Refactor SpeakerAwareTranscription (juanmc2005, Apr 26, 2023)
46 changes: 23 additions & 23 deletions README.md
@@ -24,7 +24,7 @@
</a>
<span> | </span>
<a href="#-custom-models">
-🤖 Custom models
+🤖 Add your model
</a>
<span> | </span>
<a href="#-tune-hyper-parameters">
@@ -110,32 +110,32 @@ See `diart.stream -h` for more options.

### From python

-Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk:
+Use `StreamingInference` to run a pipeline on an audio source and write the results to disk:

```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
from diart.sinks import RTTMWriter

-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
mic = MicrophoneAudioSource(pipeline.config.sample_rate)
-inference = RealTimeInference(pipeline, mic, do_plot=True)
+inference = StreamingInference(pipeline, mic, do_plot=True)
inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
prediction = inference()
```

For inference and evaluation on a dataset, we recommend using `Benchmark` (see notes on [reproducibility](#reproducibility)).

-## 🤖 Custom models
+## 🤖 Add your model

-Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):
+Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`):

```python
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import EmbeddingModel, SegmentationModel
from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference


def model_loader():
@@ -168,19 +168,19 @@ class MyEmbeddingModel(EmbeddingModel):
        return self.model(waveform, weights)


-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
    segmentation=MySegmentationModel(),
    embedding=MyEmbeddingModel()
)
-pipeline = OnlineSpeakerDiarization(config)
+pipeline = SpeakerDiarization(config)
mic = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference(pipeline, mic)
+inference = StreamingInference(pipeline, mic)
prediction = inference()
```

## 📈 Tune hyper-parameters

-Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
+Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs.

### From the command line

@@ -281,7 +281,7 @@ diart.serve --host 0.0.0.0 --port 7007
diart.client microphone --host <server-address> --port 7007
```

-**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`.
+**Note:** make sure that the client uses the same `step` and `sample_rate` as the server with `--step` and `-sr`.

See `-h` for more options.

@@ -290,13 +290,13 @@ See `-h` for more options.
For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:

```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
from diart.sources import WebSocketAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference

-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
-inference = RealTimeInference(pipeline, source)
+inference = StreamingInference(pipeline, source)
inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
prediction = inference()
```
@@ -354,14 +354,14 @@ or using the inference API:

```python
from diart.inference import Benchmark, Parallelize
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import SegmentationModel

benchmark = Benchmark("/wav/dir", "/rttm/dir")

name = "pyannote/segmentation@Interspeech2021"
segmentation = SegmentationModel.from_pyannote(name)
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
    # Set the model used in the paper
    segmentation=segmentation,
    step=0.5,
@@ -370,12 +370,12 @@ config = PipelineConfig(
    rho_update=0.422,
    delta_new=1.517
)
-benchmark(OnlineSpeakerDiarization, config)
+benchmark(SpeakerDiarization, config)

# Run the same benchmark in parallel
p_benchmark = Parallelize(benchmark, num_workers=4)
if __name__ == "__main__": # Needed for multiprocessing
-    p_benchmark(OnlineSpeakerDiarization, config)
+    p_benchmark(SpeakerDiarization, config)
```

This pre-calculates model outputs in batches, so it runs a lot faster.
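For readers tracking the API renames in this README diff, here is a minimal before/after import sketch; it only restates the name changes shown above and is not itself part of the PR:

```python
# Old names, as used in the README before this PR:
# from diart import OnlineSpeakerDiarization, PipelineConfig
# from diart.inference import RealTimeInference

# Renamed equivalents after this PR:
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import StreamingInference  # formerly RealTimeInference
```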
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,6 +9,7 @@ pandas>=1.4.2
torch>=1.12.1
torchvision>=0.14.0
torchaudio>=0.12.1,<1.0
+torchmetrics>=0.11.1
pyannote.audio>=2.1.1
pyannote.core>=4.5
pyannote.database>=4.1.1
5 changes: 3 additions & 2 deletions setup.cfg
@@ -2,11 +2,11 @@
name=diart
version=0.7.0
author=Juan Manuel Coria
-description=Speaker diarization in real time
+description=Streaming speaker diarization in real-time
long_description=file: README.md
long_description_content_type=text/markdown
keywords=speaker diarization, streaming, online, real time, rxpy
-url=https://github.com/juanmc2005/StreamingSpeakerDiarization
+url=https://github.com/juanmc2005/diart
license=MIT
classifiers=
Development Status :: 4 - Beta
@@ -31,6 +31,7 @@ install_requires=
torch>=1.12.1
torchvision>=0.14.0
torchaudio>=0.12.1,<1.0
+torchmetrics>=0.11.1
pyannote.audio>=2.1.1
pyannote.core>=4.5
pyannote.database>=4.1.1
12 changes: 8 additions & 4 deletions src/diart/__init__.py
@@ -1,6 +1,10 @@
-from .blocks import (
-    OnlineSpeakerDiarization,
-    BasePipeline,
+from .pipelines import (
+    Pipeline,
    PipelineConfig,
-    BasePipelineConfig,
+    SpeakerDiarization,
+    SpeakerDiarizationConfig,
+    VoiceActivityDetection,
+    VoiceActivityDetectionConfig,
+    Transcription,
+    TranscriptionConfig,
)
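Judging from these new exports, the added `Transcription` pipeline appears to follow the same pattern as `SpeakerDiarization`. A hypothetical sketch, assuming `Transcription` can be constructed with a default config like the other pipelines (its constructor signature is not shown in this diff):

```python
# Hypothetical sketch based only on the exports above; Transcription's
# default constructor and config attributes are assumptions.
from diart import Transcription
from diart.sources import MicrophoneAudioSource
from diart.inference import StreamingInference

pipeline = Transcription()
mic = MicrophoneAudioSource(pipeline.config.sample_rate)
inference = StreamingInference(pipeline, mic)
prediction = inference()  # yields transcriptions instead of diarization annotations
```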
5 changes: 2 additions & 3 deletions src/diart/blocks/__init__.py
@@ -5,14 +5,13 @@
    FirstOnlyStrategy,
    DelayedAggregation,
)
-from .clustering import OnlineSpeakerClustering
+from .clustering import IncrementalSpeakerClustering
from .embedding import (
    SpeakerEmbedding,
    OverlappedSpeechPenalty,
    EmbeddingNormalization,
    OverlapAwareSpeakerEmbedding,
)
from .segmentation import SpeakerSegmentation
-from .diarization import OnlineSpeakerDiarization, BasePipeline
-from .config import BasePipelineConfig, PipelineConfig
from .utils import Binarize, Resample, AdjustVolume
+from .asr import SpeechRecognition
66 changes: 66 additions & 0 deletions src/diart/blocks/asr.py
@@ -0,0 +1,66 @@
from pathlib import Path
from typing import Optional, Union, List, Text

import torch
from einops import rearrange

from .. import models as m
from ..features import TemporalFeatureFormatter, TemporalFeatures


class SpeechRecognition:
    def __init__(self, model: m.SpeechRecognitionModel, device: Optional[torch.device] = None):
        self.model = model
        self.model.eval()
        self.device = device
        if self.device is None:
            self.device = torch.device("cpu")
        self.model.to(self.device)
        self.formatter = TemporalFeatureFormatter()

    @staticmethod
    def from_whisper(
        name: Text,
        download_path: Optional[Union[Text, Path]] = None,
        in_memory: bool = False,
        fp16: bool = False,
        no_speech_threshold: float = 0.6,
        compression_ratio_threshold: Optional[float] = 2.4,
        logprob_threshold: Optional[float] = -1,
        decode_with_fallback: bool = False,
        device: Optional[Union[Text, torch.device]] = None,
    ) -> 'SpeechRecognition':
        asr_model = m.SpeechRecognitionModel.from_whisper(
            name,
            download_path,
            in_memory,
            fp16,
            no_speech_threshold,
            compression_ratio_threshold,
            logprob_threshold,
            decode_with_fallback,
        )
        return SpeechRecognition(asr_model, device)

    def __call__(self, waveform: TemporalFeatures) -> List[m.TranscriptionResult]:
        """
        Compute the transcription of input audio.

        Parameters
        ----------
        waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
            Audio to transcribe

        Returns
        -------
        transcriptions: List[m.TranscriptionResult]
            A list of timestamped transcriptions
        """
        with torch.no_grad():
            wave = rearrange(
                self.formatter.cast(waveform),
                "batch sample channel -> batch channel sample"
            )
            # output = self.model(wave.to(self.device)).cpu()
            output = self.model(wave.to(self.device))
        return output
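For context, a minimal usage sketch of the block above, assuming a local WAV file at Whisper's 16 kHz input rate and the `"small"` checkpoint (both illustrative assumptions, not part of the diff):

```python
# Sketch: transcribe a file with the new SpeechRecognition block.
# "audio.wav" and the "small" Whisper checkpoint are assumptions.
import torchaudio
from diart.blocks import SpeechRecognition

asr = SpeechRecognition.from_whisper("small", fp16=False)
waveform, sample_rate = torchaudio.load("audio.wav")  # shape: (channels, samples)
results = asr(waveform.transpose(0, 1))  # __call__ expects (samples, channels)
print(results)
```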
2 changes: 1 addition & 1 deletion src/diart/blocks/clustering.py
@@ -7,7 +7,7 @@
from ..mapping import SpeakerMap, SpeakerMapBuilder


-class OnlineSpeakerClustering:
+class IncrementalSpeakerClustering:
"""Implements constrained incremental online clustering of speakers and manages cluster centers.

Parameters