From bca2873e3ebaa901fdcaa37096f7e1904359f2ca Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Apr 2023 17:41:41 +0200 Subject: [PATCH 01/23] New feature: streaming voice activity detection. Pipeline name changes --- src/diart/__init__.py | 10 +- src/diart/blocks/__init__.py | 5 +- src/diart/blocks/base.py | 92 ++++++++++++++ src/diart/blocks/config.py | 153 ----------------------- src/diart/blocks/diarization.py | 145 ++++++++++++++++++---- src/diart/blocks/vad.py | 208 ++++++++++++++++++++++++++++++++ src/diart/console/benchmark.py | 12 +- src/diart/console/client.py | 6 +- src/diart/console/serve.py | 19 +-- src/diart/console/stream.py | 19 +-- src/diart/console/tune.py | 26 +++- src/diart/inference.py | 86 +++++++------ src/diart/optim.py | 56 ++++----- src/diart/sinks.py | 47 +++++--- src/diart/sources.py | 2 +- src/diart/utils.py | 16 ++- 16 files changed, 605 insertions(+), 297 deletions(-) create mode 100644 src/diart/blocks/base.py delete mode 100644 src/diart/blocks/config.py create mode 100644 src/diart/blocks/vad.py diff --git a/src/diart/__init__.py b/src/diart/__init__.py index c9692638..e29287a0 100644 --- a/src/diart/__init__.py +++ b/src/diart/__init__.py @@ -1,6 +1,8 @@ from .blocks import ( - OnlineSpeakerDiarization, - BasePipeline, - PipelineConfig, - BasePipelineConfig, + SpeakerDiarization, + StreamingPipeline, + SpeakerDiarizationConfig, + StreamingConfig, + VoiceActivityDetection, + VoiceActivityDetectionConfig, ) diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index 59a6ef36..e6e8c479 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -13,6 +13,7 @@ OverlapAwareSpeakerEmbedding, ) from .segmentation import SpeakerSegmentation -from .diarization import OnlineSpeakerDiarization, BasePipeline -from .config import BasePipelineConfig, PipelineConfig +from .diarization import SpeakerDiarization, SpeakerDiarizationConfig +from .base import StreamingConfig, StreamingPipeline from .utils import Binarize, Resample, AdjustVolume +from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py new file mode 100644 index 00000000..28f313eb --- /dev/null +++ b/src/diart/blocks/base.py @@ -0,0 +1,92 @@ +from typing import Any, Tuple, Sequence, Text +from dataclasses import dataclass + +import numpy as np +from pyannote.core import SlidingWindowFeature +from pyannote.metrics.base import BaseMetric + +from .. 
import utils +from ..audio import FilePath, AudioLoader + + +@dataclass +class HyperParameter: + name: Text + low: float + high: float + + @staticmethod + def from_name(name: Text) -> 'HyperParameter': + if name == "tau_active": + return TauActive + if name == "rho_update": + return RhoUpdate + if name == "delta_new": + return DeltaNew + raise ValueError(f"Hyper-parameter '{name}' not recognized") + + +TauActive = HyperParameter("tau_active", low=0, high=1) +RhoUpdate = HyperParameter("rho_update", low=0, high=1) +DeltaNew = HyperParameter("delta_new", low=0, high=2) + + +class StreamingConfig: + @property + def duration(self) -> float: + raise NotImplementedError + + @property + def step(self) -> float: + raise NotImplementedError + + @property + def latency(self) -> float: + raise NotImplementedError + + @property + def sample_rate(self) -> int: + raise NotImplementedError + + @staticmethod + def from_dict(data: Any) -> 'StreamingConfig': + raise NotImplementedError + + def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: + file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) + right = utils.get_padding_right(self.latency, self.step) + left = utils.get_padding_left(file_duration + right, self.duration) + return left, right + + def optimal_block_size(self) -> int: + return int(np.rint(self.step * self.sample_rate)) + + +class StreamingPipeline: + @staticmethod + def get_config_class() -> type: + raise NotImplementedError + + @staticmethod + def suggest_metric() -> BaseMetric: + raise NotImplementedError + + @staticmethod + def hyper_parameters() -> Sequence[HyperParameter]: + raise NotImplementedError + + @property + def config(self) -> StreamingConfig: + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def set_timestamp_shift(self, shift: float): + raise NotImplementedError + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature] + ) -> Sequence[Tuple[Any, SlidingWindowFeature]]: + raise NotImplementedError diff --git a/src/diart/blocks/config.py b/src/diart/blocks/config.py deleted file mode 100644 index d8e2a656..00000000 --- a/src/diart/blocks/config.py +++ /dev/null @@ -1,153 +0,0 @@ -from typing import Any, Optional, Union, Tuple - -import numpy as np -import torch -from typing_extensions import Literal - -from .. import models as m -from .. 
import utils -from ..audio import FilePath, AudioLoader - - -class BasePipelineConfig: - @property - def duration(self) -> float: - raise NotImplementedError - - @property - def step(self) -> float: - raise NotImplementedError - - @property - def latency(self) -> float: - raise NotImplementedError - - @property - def sample_rate(self) -> int: - raise NotImplementedError - - @staticmethod - def from_dict(data: Any) -> 'BasePipelineConfig': - raise NotImplementedError - - def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: - file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) - right = utils.get_padding_right(self.latency, self.step) - left = utils.get_padding_left(file_duration + right, self.duration) - return left, right - - def optimal_block_size(self) -> int: - return int(np.rint(self.step * self.sample_rate)) - - -class PipelineConfig(BasePipelineConfig): - def __init__( - self, - segmentation: Optional[m.SegmentationModel] = None, - embedding: Optional[m.EmbeddingModel] = None, - duration: Optional[float] = None, - step: float = 0.5, - latency: Optional[Union[float, Literal["max", "min"]]] = None, - tau_active: float = 0.6, - rho_update: float = 0.3, - delta_new: float = 1, - gamma: float = 3, - beta: float = 10, - max_speakers: int = 20, - device: Optional[torch.device] = None, - **kwargs, - ): - # Default segmentation model is pyannote/segmentation - self.segmentation = segmentation - if self.segmentation is None: - self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") - - # Default duration is the one given by the segmentation model - self._duration = duration - - # Expected sample rate is given by the segmentation model - self._sample_rate: Optional[int] = None - - # Default embedding model is pyannote/embedding - self.embedding = embedding - if self.embedding is None: - self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") - - # Latency defaults to the step duration - self._step = step - self._latency = latency - if self._latency is None or self._latency == "min": - self._latency = self._step - elif self._latency == "max": - self._latency = self._duration - - self.tau_active = tau_active - self.rho_update = rho_update - self.delta_new = delta_new - self.gamma = gamma - self.beta = beta - self.max_speakers = max_speakers - - self.device = device - if self.device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - @staticmethod - def from_dict(data: Any) -> 'PipelineConfig': - # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None - device = utils.get(data, "device", None) - if device is None: - device = torch.device("cpu") if utils.get(data, "cpu", False) else None - - # Instantiate models - hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) - segmentation = utils.get(data, "segmentation", "pyannote/segmentation") - segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) - embedding = utils.get(data, "embedding", "pyannote/embedding") - embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token) - - # Hyper-parameters and their aliases - tau = utils.get(data, "tau_active", None) - if tau is None: - tau = utils.get(data, "tau", 0.6) - rho = utils.get(data, "rho_update", None) - if rho is None: - rho = utils.get(data, "rho", 0.3) - delta = utils.get(data, "delta_new", None) - if delta is None: - delta = utils.get(data, "delta", 1) - - return PipelineConfig( - segmentation=segmentation, - 
embedding=embedding, - duration=utils.get(data, "duration", None), - step=utils.get(data, "step", 0.5), - latency=utils.get(data, "latency", None), - tau_active=tau, - rho_update=rho, - delta_new=delta, - gamma=utils.get(data, "gamma", 3), - beta=utils.get(data, "beta", 10), - max_speakers=utils.get(data, "max_speakers", 20), - device=device, - ) - - @property - def duration(self) -> float: - if self._duration is None: - self._duration = self.segmentation.duration - return self._duration - - @property - def step(self) -> float: - return self._step - - @property - def latency(self) -> float: - return self._latency - - @property - def sample_rate(self) -> int: - if self._sample_rate is None: - self._sample_rate = self.segmentation.sample_rate - return self._sample_rate diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index 7f0e162c..f2a25119 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -1,42 +1,137 @@ -from typing import Optional, Tuple, Sequence +from typing import Optional, Tuple, Sequence, Union, Any import numpy as np import torch from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment +from pyannote.metrics.base import BaseMetric +from pyannote.metrics.diarization import DiarizationErrorRate +from typing_extensions import Literal from .aggregation import DelayedAggregation +from . import base from .clustering import OnlineSpeakerClustering from .embedding import OverlapAwareSpeakerEmbedding from .segmentation import SpeakerSegmentation from .utils import Binarize -from .config import BasePipelineConfig, PipelineConfig +from .. import models as m +from .. import utils -class BasePipeline: +class SpeakerDiarizationConfig(base.StreamingConfig): + def __init__( + self, + segmentation: Optional[m.SegmentationModel] = None, + embedding: Optional[m.EmbeddingModel] = None, + duration: Optional[float] = None, + step: float = 0.5, + latency: Optional[Union[float, Literal["max", "min"]]] = None, + tau_active: float = 0.6, + rho_update: float = 0.3, + delta_new: float = 1, + gamma: float = 3, + beta: float = 10, + max_speakers: int = 20, + device: Optional[torch.device] = None, + **kwargs, + ): + # Default segmentation model is pyannote/segmentation + self.segmentation = segmentation + if self.segmentation is None: + self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + + self._duration = duration + self._sample_rate: Optional[int] = None + + # Default embedding model is pyannote/embedding + self.embedding = embedding + if self.embedding is None: + self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") + + # Latency defaults to the step duration + self._step = step + self._latency = latency + if self._latency is None or self._latency == "min": + self._latency = self._step + elif self._latency == "max": + self._latency = self._duration + + self.tau_active = tau_active + self.rho_update = rho_update + self.delta_new = delta_new + self.gamma = gamma + self.beta = beta + self.max_speakers = max_speakers + + self.device = device + if self.device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + @staticmethod - def get_config_class() -> type: - raise NotImplementedError + def from_dict(data: Any) -> 'SpeakerDiarizationConfig': + # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None + device = utils.get(data, "device", None) + if device is None: + device = torch.device("cpu") if utils.get(data, "cpu", 
False) else None + + # Instantiate models + hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) + segmentation = utils.get(data, "segmentation", "pyannote/segmentation") + segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) + embedding = utils.get(data, "embedding", "pyannote/embedding") + embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token) + + # Hyper-parameters and their aliases + tau = utils.get(data, "tau_active", None) + if tau is None: + tau = utils.get(data, "tau", 0.6) + rho = utils.get(data, "rho_update", None) + if rho is None: + rho = utils.get(data, "rho", 0.3) + delta = utils.get(data, "delta_new", None) + if delta is None: + delta = utils.get(data, "delta", 1) + + return SpeakerDiarizationConfig( + segmentation=segmentation, + embedding=embedding, + duration=utils.get(data, "duration", None), + step=utils.get(data, "step", 0.5), + latency=utils.get(data, "latency", None), + tau_active=tau, + rho_update=rho, + delta_new=delta, + gamma=utils.get(data, "gamma", 3), + beta=utils.get(data, "beta", 10), + max_speakers=utils.get(data, "max_speakers", 20), + device=device, + ) @property - def config(self) -> BasePipelineConfig: - raise NotImplementedError + def duration(self) -> float: + # Default duration is the one given by the segmentation model + if self._duration is None: + self._duration = self.segmentation.duration + return self._duration - def reset(self): - raise NotImplementedError + @property + def step(self) -> float: + return self._step - def set_timestamp_shift(self, shift: float): - raise NotImplementedError + @property + def latency(self) -> float: + return self._latency - def __call__( - self, - waveforms: Sequence[SlidingWindowFeature] - ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: - raise NotImplementedError + @property + def sample_rate(self) -> int: + # Expected sample rate is given by the segmentation model + if self._sample_rate is None: + self._sample_rate = self.segmentation.sample_rate + return self._sample_rate -class OnlineSpeakerDiarization(BasePipeline): - def __init__(self, config: Optional[PipelineConfig] = None): - self._config = PipelineConfig() if config is None else config +class SpeakerDiarization(base.StreamingPipeline): + def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): + self._config = SpeakerDiarizationConfig() if config is None else config msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" assert self._config.step <= self._config.latency <= self._config.duration, msg @@ -67,10 +162,18 @@ def __init__(self, config: Optional[PipelineConfig] = None): @staticmethod def get_config_class() -> type: - return PipelineConfig + return SpeakerDiarizationConfig + + @staticmethod + def suggest_metric() -> BaseMetric: + return DiarizationErrorRate(collar=0, skip_overlap=False) + + @staticmethod + def hyper_parameters() -> Sequence[base.HyperParameter]: + return [base.TauActive, base.RhoUpdate, base.DeltaNew] @property - def config(self) -> PipelineConfig: + def config(self) -> SpeakerDiarizationConfig: return self._config def set_timestamp_shift(self, shift: float): diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py new file mode 100644 index 00000000..def833b6 --- /dev/null +++ b/src/diart/blocks/vad.py @@ -0,0 +1,208 @@ +from typing import Any, Optional, Union, Sequence, Tuple + +import numpy as np +import torch +from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment +from 
pyannote.metrics.base import BaseMetric +from pyannote.metrics.detection import DetectionErrorRate +from typing_extensions import Literal + +from .aggregation import DelayedAggregation +from . import base +from .segmentation import SpeakerSegmentation +from .utils import Binarize +from .. import models as m +from .. import utils + + +class VoiceActivityDetectionConfig(base.StreamingConfig): + def __init__( + self, + segmentation: Optional[m.SegmentationModel] = None, + duration: Optional[float] = None, + step: float = 0.5, + latency: Optional[Union[float, Literal["max", "min"]]] = None, + tau_active: float = 0.6, + device: Optional[torch.device] = None, + **kwargs, + ): + # Default segmentation model is pyannote/segmentation + self.segmentation = segmentation + if self.segmentation is None: + self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + + self._duration = duration + self._step = step + self._sample_rate: Optional[int] = None + + # Latency defaults to the step duration + self._latency = latency + if self._latency is None or self._latency == "min": + self._latency = self._step + elif self._latency == "max": + self._latency = self._duration + + self.tau_active = tau_active + self.device = device + if self.device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + @property + def duration(self) -> float: + # Default duration is the one given by the segmentation model + if self._duration is None: + self._duration = self.segmentation.duration + return self._duration + + @property + def step(self) -> float: + return self._step + + @property + def latency(self) -> float: + return self._latency + + @property + def sample_rate(self) -> int: + # Expected sample rate is given by the segmentation model + if self._sample_rate is None: + self._sample_rate = self.segmentation.sample_rate + return self._sample_rate + + @staticmethod + def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': + # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None + device = utils.get(data, "device", None) + if device is None: + device = torch.device("cpu") if utils.get(data, "cpu", False) else None + + # Instantiate segmentation model + hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) + segmentation = utils.get(data, "segmentation", "pyannote/segmentation") + segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) + + # Tau active and its alias + tau = utils.get(data, "tau_active", None) + if tau is None: + tau = utils.get(data, "tau", 0.6) + + return VoiceActivityDetectionConfig( + segmentation=segmentation, + duration=utils.get(data, "duration", None), + step=utils.get(data, "step", 0.5), + latency=utils.get(data, "latency", None), + tau_active=tau, + device=device, + ) + + +class VoiceActivityDetection(base.StreamingPipeline): + def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): + self._config = VoiceActivityDetectionConfig() if config is None else config + + msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" + assert self._config.step <= self._config.latency <= self._config.duration, msg + + self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device) + self.pred_aggregation = DelayedAggregation( + self._config.step, + self._config.latency, + strategy="hamming", + cropping_mode="loose", + ) + self.audio_aggregation = DelayedAggregation( + self._config.step, + self._config.latency, + 
strategy="first", + cropping_mode="center", + ) + self.binarize = Binarize(self._config.tau_active) + + # Internal state, handle with care + self.timestamp_shift = 0 + self.chunk_buffer, self.pred_buffer = [], [] + + @staticmethod + def get_config_class() -> type: + return VoiceActivityDetectionConfig + + @staticmethod + def suggest_metric() -> BaseMetric: + return DetectionErrorRate(collar=0, skip_overlap=False) + + @staticmethod + def hyper_parameters() -> Sequence[base.HyperParameter]: + return [base.TauActive] + + @property + def config(self) -> base.StreamingConfig: + return self._config + + def reset(self): + self.set_timestamp_shift(0) + self.chunk_buffer, self.pred_buffer = [], [] + + def set_timestamp_shift(self, shift: float): + self.timestamp_shift = shift + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature], + ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: + batch_size = len(waveforms) + msg = "Pipeline expected at least 1 input" + assert batch_size >= 1, msg + + # Create batch from chunk sequence, shape (batch, samples, channels) + batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) + + expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate)) + msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" + assert batch.shape[1] == expected_num_samples, msg + + # Extract segmentation + segmentations = self.segmentation(batch) # shape (batch, frames, speakers) + voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[0] # shape (batch, frames, 1) + + seg_resolution = waveforms[0].extent.duration / segmentations.shape[1] + + outputs = [] + for wav, vad in zip(waveforms, voice_detection): + # Add timestamps to segmentation + sw = SlidingWindow( + start=wav.extent.start, + duration=seg_resolution, + step=seg_resolution, + ) + vad = SlidingWindowFeature(vad.cpu().numpy(), sw) + + # Update sliding buffer + self.chunk_buffer.append(wav) + self.pred_buffer.append(vad) + + # Aggregate buffer outputs for this time step + agg_waveform = self.audio_aggregation(self.chunk_buffer) + agg_prediction = self.pred_aggregation(self.pred_buffer) + agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False) + + # Shift prediction timestamps if required + if self.timestamp_shift != 0: + shifted_agg_prediction = Timeline(uri=agg_prediction.uri) + for segment in agg_prediction: + new_segment = Segment( + segment.start + self.timestamp_shift, + segment.end + self.timestamp_shift, + ) + shifted_agg_prediction.add(new_segment) + agg_prediction = shifted_agg_prediction + + # Convert timeline into annotation with single speaker "speech" + agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech")) + outputs.append((agg_prediction, agg_waveform)) + + # Make place for new chunks in buffer if required + if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows: + self.chunk_buffer = self.chunk_buffer[1:] + self.pred_buffer = self.pred_buffer[1:] + + return outputs diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index b6a3f9ff..27d524c5 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -1,15 +1,17 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc import pandas as pd -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig +from diart import argdoc +from diart import utils from diart.inference import Benchmark, Parallelize def run(): parser = argparse.ArgumentParser() 
parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -34,6 +36,8 @@ def run(): help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() + pipeline_class = utils.get_pipeline_class(args.pipeline) + benchmark = Benchmark( args.root, args.reference, @@ -43,11 +47,11 @@ def run(): batch_size=args.batch_size, ) - config = PipelineConfig.from_dict(vars(args)) + config = pipeline_class.get_config_class().from_dict(vars(args)) if args.num_workers > 0: benchmark = Parallelize(benchmark, args.num_workers) - report = benchmark(OnlineSpeakerDiarization, config) + report = benchmark(pipeline_class, config) if args.output is not None and isinstance(report, pd.DataFrame): report.to_csv(args.output / "benchmark_report.csv") diff --git a/src/diart/console/client.py b/src/diart/console/client.py index 084dbc13..db4915fa 100644 --- a/src/diart/console/client.py +++ b/src/diart/console/client.py @@ -3,11 +3,11 @@ from threading import Thread from typing import Text, Optional -import diart.argdoc as argdoc -import diart.sources as src -import diart.utils as utils import numpy as np import rx.operators as ops +from diart import argdoc +from diart import sources as src +from diart import utils from websocket import WebSocket diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index 2f632d57..46bb9328 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -1,10 +1,10 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc -import diart.sources as src -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig -from diart.inference import RealTimeInference +from diart import argdoc +from diart import sources as src +from diart import utils +from diart.inference import StreamingInference from diart.sinks import RTTMWriter @@ -12,6 +12,8 @@ def run(): parser = argparse.ArgumentParser() parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host") parser.add_argument("--port", default=7007, type=int, help="Server port") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -31,15 +33,16 @@ def run(): help=f"{argdoc.HF_TOKEN}. 
Defaults to 'true' (required by pyannote)") args = parser.parse_args() - # Define online speaker diarization pipeline - config = PipelineConfig.from_dict(vars(args)) - pipeline = OnlineSpeakerDiarization(config) + # Resolve pipeline + pipeline_class = utils.get_pipeline_class(args.pipeline) + config = pipeline_class.get_config_class().from_dict(vars(args)) + pipeline = pipeline_class(config) # Create websocket audio source audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port) # Run online inference - inference = RealTimeInference( + inference = StreamingInference( pipeline, audio_source, batch_size=1, diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index d7218f07..e0c670c5 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -1,16 +1,18 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc -import diart.sources as src -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig -from diart.inference import RealTimeInference +from diart import argdoc +from diart import sources as src +from diart import utils +from diart.inference import StreamingInference from diart.sinks import RTTMWriter def run(): parser = argparse.ArgumentParser() parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -32,9 +34,10 @@ def run(): help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() - # Define online speaker diarization pipeline - config = PipelineConfig.from_dict(vars(args)) - pipeline = OnlineSpeakerDiarization(config) + # Resolve pipeline + pipeline_class = utils.get_pipeline_class(args.pipeline) + config = pipeline_class.get_config_class().from_dict(vars(args)) + pipeline = pipeline_class(config) # Manage audio source block_size = config.optimal_block_size() @@ -51,7 +54,7 @@ def run(): audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device) # Run online inference - inference = RealTimeInference( + inference = StreamingInference( pipeline, audio_source, batch_size=1, diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index 4ad8852a..a1f1b63a 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -1,10 +1,11 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc import optuna -from diart.blocks import PipelineConfig, OnlineSpeakerDiarization -from diart.optim import Optimizer, HyperParameter +from diart import argdoc +from diart import utils +from diart.blocks.base import HyperParameter +from diart.optim import Optimizer from optuna.samplers import TPESampler @@ -13,6 +14,8 @@ def run(): parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") parser.add_argument("--reference", required=True, type=str, help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. 
Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -38,17 +41,28 @@ def run(): help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() + # Retrieve pipeline class + pipeline_class = utils.get_pipeline_class(args.pipeline) + # Create the base configuration for each trial - base_config = PipelineConfig.from_dict(vars(args)) + base_config = pipeline_class.get_config_class().from_dict(vars(args)) # Create hyper-parameters to optimize + possible_hparams = pipeline_class.hyper_parameters() hparams = [HyperParameter.from_name(name) for name in args.hparams] + hparams = [hp for hp in hparams if hp in possible_hparams] + if not hparams: + print( + f"No hyper-parameters to optimize. " + f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}" + ) + exit(1) # Use a custom storage if given if args.output is not None: msg = "Both `output` and `storage` were set, but only one was expected" assert args.storage is None, msg - args.output = Path(args.output) + args.output = Path(args.output).expanduser() args.output.mkdir(parents=True, exist_ok=True) study_or_path = args.output elif args.storage is not None: @@ -60,11 +74,11 @@ def run(): # Run optimization Optimizer( + pipeline_class=pipeline_class, speech_path=args.root, reference_path=args.reference, study_or_path=study_or_path, batch_size=args.batch_size, - pipeline_class=OnlineSpeakerDiarization, hparams=hparams, base_config=base_config, )(num_iter=args.num_iter, show_progress=True) diff --git a/src/diart/inference.py b/src/diart/inference.py index f4b65f5f..6afda89e 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -4,32 +4,33 @@ from traceback import print_exc from typing import Union, Text, Optional, Callable, Tuple, List -import diart.operators as dops -import diart.sources as src import numpy as np import pandas as pd import rx import rx.operators as ops import torch -from diart import utils -from diart.blocks import BasePipeline, Resample, BasePipelineConfig -from diart.progress import ProgressBar, RichProgressBar, TQDMProgressBar -from diart.sinks import DiarizationPredictionAccumulator, RealTimePlot, WindowClosedException from pyannote.core import Annotation, SlidingWindowFeature from pyannote.database.util import load_rttm -from pyannote.metrics.diarization import DiarizationErrorRate +from pyannote.metrics.base import BaseMetric from rx.core import Observer from tqdm import tqdm +from . import blocks +from . import operators as dops +from . import sources as src +from . import utils +from .progress import ProgressBar, RichProgressBar, TQDMProgressBar +from .sinks import PredictionAccumulator, StreamingPlot, WindowClosedException -class RealTimeInference: + +class StreamingInference: """Performs inference in real time given a pipeline and an audio source. Streams an audio source to an online speaker diarization pipeline. It allows users to attach a chain of operations in the form of hooks. Parameters ---------- - pipeline: BasePipeline + pipeline: StreamingPipeline Configured speaker diarization pipeline. source: AudioSource Audio source to be read and streamed. 
@@ -52,7 +53,7 @@ class RealTimeInference: """ def __init__( self, - pipeline: BasePipeline, + pipeline: blocks.StreamingPipeline, source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, @@ -66,7 +67,7 @@ def __init__( self.do_profile = do_profile self.do_plot = do_plot self.show_progress = show_progress - self.accumulator = DiarizationPredictionAccumulator(self.source.uri) + self.accumulator = PredictionAccumulator(self.source.uri) self.unit = "chunk" if self.batch_size == 1 else "batch" self._observers = [] @@ -102,7 +103,7 @@ def __init__( f"but pipeline's is {sample_rate}. Will resample." logging.warning(msg) self.stream = self.stream.pipe( - ops.map(Resample(self.source.sample_rate, sample_rate)) + ops.map(blocks.Resample(self.source.sample_rate, sample_rate)) ) # Add rx operators to manage the inputs and outputs of the pipeline @@ -202,7 +203,7 @@ def __call__(self) -> Annotation: latency=config.latency, sample_rate=config.sample_rate, ), - ops.do(RealTimePlot(config.duration, config.latency)), + ops.do(StreamingPlot(config.duration, config.latency)), ) observable.subscribe( on_error=self._handle_error, @@ -288,7 +289,7 @@ def get_file_paths(self) -> List[Path]: def run_single( self, - pipeline: BasePipeline, + pipeline: blocks.StreamingPipeline, filepath: Path, progress_bar: ProgressBar, ) -> Annotation: @@ -298,7 +299,7 @@ def run_single( Parameters ---------- - pipeline: BasePipeline + pipeline: StreamingPipeline Speaker diarization pipeline to run. filepath: Path Path to the target file. @@ -318,7 +319,7 @@ def run_single( pipeline.config.optimal_block_size(), ) pipeline.set_timestamp_shift(-padding[0]) - inference = RealTimeInference( + inference = StreamingInference( pipeline, source, self.batch_size, @@ -337,7 +338,11 @@ def run_single( return pred - def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[Annotation]]: + def evaluate( + self, + predictions: List[Annotation], + metric: BaseMetric, + ) -> Union[pd.DataFrame, List[Annotation]]: """If a reference path was provided, compute the diarization error rate of a list of predictions. @@ -345,6 +350,8 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An ---------- predictions: List[Annotation] Predictions to evaluate. + metric: BaseMetric + Evaluation metric from pyannote.metrics. Returns ------- @@ -353,8 +360,7 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An reference path was given. Otherwise return the same predictions. """ if self.reference_path is not None: - metric = DiarizationErrorRate(collar=0, skip_overlap=False) - progress_bar = TQDMProgressBar("Computing DER", leave=False) + progress_bar = TQDMProgressBar(f"Computing {metric.name}", leave=False) progress_bar.create(total=len(predictions), unit="file") progress_bar.start() for hyp in predictions: @@ -368,18 +374,22 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An def __call__( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.StreamingConfig, + metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files. - Notice that the internal state of the pipeline is reset before benchmarking. + The internal state of the pipeline is reset before benchmarking. Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. 
A pipeline from this class will be instantiated by each worker. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. + metric: Optional[BaseMetric] + Evaluation metric from pyannote.metrics. + Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- @@ -400,7 +410,8 @@ def __call__( progress = TQDMProgressBar(desc, leave=False, do_close=True) predictions.append(self.run_single(pipeline, filepath, progress)) - return self.evaluate(predictions) + metric = pipeline.suggest_metric() if metric is None else metric + return self.evaluate(predictions, metric) class Parallelize: @@ -426,20 +437,20 @@ def __init__( def run_single_job( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.StreamingConfig, filepath: Path, description: Text, - ): + ) -> Annotation: """Build and run a pipeline on a single file. Configure execution to show progress alongside parallel runs. Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. A pipeline from this class will be instantiated. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. filepath: Path Path to the target file. description: Text @@ -463,7 +474,8 @@ def run_single_job( def __call__( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.StreamingConfig, + metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files in parallel. Each worker will build and run the pipeline on a different file. @@ -471,10 +483,13 @@ def __call__( Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. A pipeline from this class will be instantiated by each worker. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. + metric: Optional[BaseMetric] + Evaluation metric from pyannote.metrics. + Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- @@ -512,4 +527,5 @@ def __call__( predictions = [job.get() for job in jobs] # Evaluate results - return self.benchmark.evaluate(predictions) + metric = pipeline_class.suggest_metric() if metric is None else metric + return self.benchmark.evaluate(predictions, metric) diff --git a/src/diart/optim.py b/src/diart/optim.py index 05800a05..f7a96a6e 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -1,51 +1,32 @@ from collections import OrderedDict -from dataclasses import dataclass from pathlib import Path from typing import Sequence, Text, Optional, Union from optuna import TrialPruned, Study, create_study from optuna.samplers import TPESampler from optuna.trial import Trial, FrozenTrial +from pyannote.metrics.base import BaseMetric from tqdm import trange, tqdm +from typing_extensions import Literal +from . 
import blocks from .audio import FilePath -from .blocks import BasePipelineConfig, PipelineConfig, OnlineSpeakerDiarization from .inference import Benchmark -@dataclass -class HyperParameter: - name: Text - low: float - high: float - - @staticmethod - def from_name(name: Text) -> 'HyperParameter': - if name == "tau_active": - return TauActive - if name == "rho_update": - return RhoUpdate - if name == "delta_new": - return DeltaNew - raise ValueError(f"Hyper-parameter '{name}' not recognized") - - -TauActive = HyperParameter("tau_active", low=0, high=1) -RhoUpdate = HyperParameter("rho_update", low=0, high=1) -DeltaNew = HyperParameter("delta_new", low=0, high=2) - - class Optimizer: def __init__( self, + pipeline_class: type, speech_path: Union[Text, Path], reference_path: Union[Text, Path], study_or_path: Union[FilePath, Study], batch_size: int = 32, - pipeline_class: type = OnlineSpeakerDiarization, - hparams: Optional[Sequence[HyperParameter]] = None, - base_config: Optional[BasePipelineConfig] = None, + hparams: Optional[Sequence[blocks.base.HyperParameter]] = None, + base_config: Optional[blocks.StreamingConfig] = None, do_kickstart_hparams: bool = True, + metric: Optional[BaseMetric] = None, + direction: Literal["minimize", "maximize"] = "minimize", ): self.pipeline_class = pipeline_class # FIXME can we run this benchmark in parallel? @@ -58,15 +39,17 @@ def __init__( batch_size=batch_size, ) + self.metric = metric + self.direction = direction self.base_config = base_config self.do_kickstart_hparams = do_kickstart_hparams if self.base_config is None: - self.base_config = PipelineConfig() + self.base_config = self.pipeline_class.get_config_class()() self.do_kickstart_hparams = False self.hparams = hparams if self.hparams is None: - self.hparams = [TauActive, RhoUpdate, DeltaNew] + self.hparams = self.pipeline_class.hyper_parameters() # Make sure hyper-parameters exist in the configuration class given possible_hparams = vars(self.base_config) @@ -85,7 +68,7 @@ def __init__( storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"), sampler=TPESampler(), study_name=study_or_path.stem, - direction="minimize", + direction=self.direction, load_if_exists=True, ) else: @@ -105,7 +88,7 @@ def _callback(self, study: Study, trial: FrozenTrial): return self._progress.update(1) self._progress.set_description(f"Trial {trial.number + 1}") - values = {"best_der": study.best_value} + values = {"best_perf": study.best_value} for name, value in study.best_params.items(): values[f"best_{name}"] = value self._progress.set_postfix(OrderedDict(values)) @@ -125,11 +108,16 @@ def objective(self, trial: Trial) -> float: # Instantiate the new configuration for the trial config = self.base_config.__class__(**trial_config) + # Determine the evaluation metric + metric = self.metric + if metric is None: + metric = self.pipeline_class.suggest_metric() + # Run pipeline over the dataset - report = self.benchmark(self.pipeline_class, config) + report = self.benchmark(self.pipeline_class, config, metric) - # Extract DER from report - return report.loc["TOTAL", "diarization error rate"]["%"] + # Extract target metric from report + return report.loc["TOTAL", metric.name]["%"] def __call__(self, num_iter: int, show_progress: bool = True): self._progress = None diff --git a/src/diart/sinks.py b/src/diart/sinks.py index cf480bed..63c170d0 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -8,12 +8,14 @@ from rx.core import Observer from typing_extensions import Literal +from . 
import utils + class WindowClosedException(Exception): pass -def _extract_annotation(value: Union[Tuple, Annotation]) -> Annotation: +def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation: if isinstance(value, tuple): return value[0] if isinstance(value, Annotation): @@ -43,10 +45,11 @@ def patch(self): annotation.support(self.patch_collar).write_rttm(file) def on_next(self, value: Union[Tuple, Annotation]): - annotation = _extract_annotation(value) - annotation.uri = self.uri + prediction = _extract_prediction(value) + # Write prediction in RTTM format + prediction.uri = self.uri with open(self.path, 'a') as file: - annotation.write_rttm(file) + prediction.write_rttm(file) def on_error(self, error: Exception): self.patch() @@ -55,30 +58,30 @@ def on_completed(self): self.patch() -class DiarizationPredictionAccumulator(Observer): +class PredictionAccumulator(Observer): def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05): super().__init__() self.uri = uri self.patch_collar = patch_collar - self._annotation = None + self._prediction: Optional[Annotation] = None def patch(self): """Stitch same-speaker turns that are close to each other""" - if self._annotation is not None: - self._annotation = self._annotation.support(self.patch_collar) + if self._prediction is not None: + self._prediction = self._prediction.support(self.patch_collar) def get_prediction(self) -> Annotation: # Patch again in case this is called before on_completed self.patch() - return self._annotation + return self._prediction def on_next(self, value: Union[Tuple, Annotation]): - annotation = _extract_annotation(value) - annotation.uri = self.uri - if self._annotation is None: - self._annotation = annotation + prediction = _extract_prediction(value) + prediction.uri = self.uri + if self._prediction is None: + self._prediction = prediction else: - self._annotation.update(annotation) + self._prediction.update(prediction) def on_error(self, error: Exception): self.patch() @@ -87,7 +90,7 @@ def on_completed(self): self.patch() -class RealTimePlot(Observer): +class StreamingPlot(Observer): def __init__( self, duration: float, @@ -134,11 +137,15 @@ def get_plot_bounds(self, real_time: float) -> Segment: start_time = max(0., end_time - self.window_duration) return Segment(start_time, end_time) - def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): + def on_next( + self, + values: Tuple[Annotation, SlidingWindowFeature, float] + ): if self.window_closed: raise WindowClosedException prediction, waveform, real_time = values + # Initialize figure if first call if self.figure is None: self._init_figure() @@ -147,15 +154,21 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): # Set plot bounds notebook.crop = self.get_plot_bounds(real_time) - # Plot current values + # Align prediction and reference if possible if self.reference is not None: metric = DiarizationErrorRate() mapping = metric.optimal_mapping(self.reference, prediction) prediction.rename_labels(mapping=mapping, copy=False) + + # Plot prediction notebook.plot_annotation(prediction, self.axs[0]) self.axs[0].set_title("Output") + + # Plot waveform notebook.plot_feature(waveform, self.axs[1]) self.axs[1].set_title("Audio") + + # Plot reference if available if self.num_axs == 3: notebook.plot_annotation(self.reference, self.axs[2]) self.axs[2].set_title("Reference") diff --git a/src/diart/sources.py b/src/diart/sources.py index 0f5dedf7..b34d5cf3 100644 --- a/src/diart/sources.py +++ 
b/src/diart/sources.py @@ -5,12 +5,12 @@ import numpy as np import sounddevice as sd import torch -from diart import utils from einops import rearrange from rx.subject import Subject from torchaudio.io import StreamReader from websocket_server import WebsocketServer +from . import utils from .audio import FilePath, AudioLoader diff --git a/src/diart/utils.py b/src/diart/utils.py index e90861c7..e825ef29 100644 --- a/src/diart/utils.py +++ b/src/diart/utils.py @@ -4,9 +4,11 @@ import matplotlib.pyplot as plt import numpy as np -from diart.progress import ProgressBar from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook +from .progress import ProgressBar +from . import blocks + class Chronometer: def __init__(self, unit: Text, progress_bar: Optional[ProgressBar] = None): @@ -74,6 +76,18 @@ def get_padding_left(stream_duration: float, chunk_duration: float) -> float: return 0 +def repeat_label(label: Text): + while True: + yield label + + +def get_pipeline_class(class_name: Text) -> type: + pipeline_class = getattr(blocks, class_name, None) + msg = f"Pipeline '{class_name}' doesn't exist" + assert pipeline_class is not None, msg + return pipeline_class + + def get_padding_right(latency: float, step: float) -> float: return latency - step From 74470617b482d2d546a83bff5a0996d90d0df079 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Apr 2023 17:43:51 +0200 Subject: [PATCH 02/23] Update link in setup.cfg --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 594c876e..e67e4426 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,11 +2,11 @@ name=diart version=0.7.0 author=Juan Manuel Coria -description=Speaker diarization in real time +description=Streaming speaker diarization in real-time long_description=file: README.md long_description_content_type=text/markdown keywords=speaker diarization, streaming, online, real time, rxpy -url=https://github.com/juanmc2005/StreamingSpeakerDiarization +url=https://github.com/juanmc2005/diart license=MIT classifiers= Development Status :: 4 - Beta From 498539438861a4d7472b75387e7b0cb4e6768dc3 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Apr 2023 17:51:41 +0200 Subject: [PATCH 03/23] Update code snippets in README --- README.md | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index ef533946..57ca293a 100644 --- a/README.md +++ b/README.md @@ -110,17 +110,17 @@ See `diart.stream -h` for more options. 
### From python -Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk: +Use `StreamingInference` to run a pipeline on an audio source and write the results to disk: ```python -from diart import OnlineSpeakerDiarization +from diart import SpeakerDiarization from diart.sources import MicrophoneAudioSource -from diart.inference import RealTimeInference +from diart.inference import StreamingInference from diart.sinks import RTTMWriter -pipeline = OnlineSpeakerDiarization() +pipeline = SpeakerDiarization() mic = MicrophoneAudioSource(pipeline.config.sample_rate) -inference = RealTimeInference(pipeline, mic, do_plot=True) +inference = StreamingInference(pipeline, mic, do_plot=True) inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm")) prediction = inference() ``` @@ -129,13 +129,13 @@ For inference and evaluation on a dataset we recommend to use `Benchmark` (see n ## 🤖 Custom models -Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses): +Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`): ```python -from diart import OnlineSpeakerDiarization, PipelineConfig +from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart.models import EmbeddingModel, SegmentationModel from diart.sources import MicrophoneAudioSource -from diart.inference import RealTimeInference +from diart.inference import StreamingInference def model_loader(): @@ -168,19 +168,19 @@ class MyEmbeddingModel(EmbeddingModel): return self.model(waveform, weights) -config = PipelineConfig( +config = SpeakerDiarizationConfig( segmentation=MySegmentationModel(), embedding=MyEmbeddingModel() ) -pipeline = OnlineSpeakerDiarization(config) +pipeline = SpeakerDiarization(config) mic = MicrophoneAudioSource(config.sample_rate) -inference = RealTimeInference(pipeline, mic) +inference = StreamingInference(pipeline, mic) prediction = inference() ``` ## 📈 Tune hyper-parameters -Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset. +Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs. ### From the command line @@ -281,7 +281,7 @@ diart.serve --host 0.0.0.0 --port 7007 diart.client microphone --host --port 7007 ``` -**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`. +**Note:** make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`. See `-h` for more options. @@ -290,13 +290,13 @@ See `-h` for more options. 
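The `diart.stream`, `diart.serve`, `diart.benchmark` and `diart.tune` entry points all gain a `--pipeline` flag in this patch series, resolved by class name through `utils.get_pipeline_class`. A tentative command-line sketch for selecting the new voice activity detection pipeline (flag names taken from the argparse definitions in the patch above; the default remains `SpeakerDiarization`):

```shell
# Stream VAD predictions from the microphone instead of running diarization
diart.stream microphone --pipeline VoiceActivityDetection

# Serve the VAD pipeline over websockets
diart.serve --host 0.0.0.0 --port 7007 --pipeline VoiceActivityDetection
```
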
For customized solutions, a server can also be created in python using the `WebSocketAudioSource`: ```python -from diart import OnlineSpeakerDiarization +from diart import SpeakerDiarization from diart.sources import WebSocketAudioSource -from diart.inference import RealTimeInference +from diart.inference import StreamingInference -pipeline = OnlineSpeakerDiarization() +pipeline = SpeakerDiarization() source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007) -inference = RealTimeInference(pipeline, source) +inference = StreamingInference(pipeline, source) inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm())) prediction = inference() ``` @@ -354,14 +354,14 @@ or using the inference API: ```python from diart.inference import Benchmark, Parallelize -from diart import OnlineSpeakerDiarization, PipelineConfig +from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart.models import SegmentationModel benchmark = Benchmark("/wav/dir", "/rttm/dir") name = "pyannote/segmentation@Interspeech2021" segmentation = SegmentationModel.from_pyannote(name) -config = PipelineConfig( +config = SpeakerDiarizationConfig( # Set the model used in the paper segmentation=segmentation, step=0.5, @@ -370,12 +370,12 @@ config = PipelineConfig( rho_update=0.422, delta_new=1.517 ) -benchmark(OnlineSpeakerDiarization, config) +benchmark(SpeakerDiarization, config) # Run the same benchmark in parallel p_benchmark = Parallelize(benchmark, num_workers=4) if __name__ == "__main__": # Needed for multiprocessing - p_benchmark(OnlineSpeakerDiarization, config) + p_benchmark(SpeakerDiarization, config) ``` This pre-calculates model outputs in batches, so it runs a lot faster. From 540ad0a97be45faaab2936b85cedbd18c9e456cc Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Apr 2023 21:18:36 +0200 Subject: [PATCH 04/23] Add minor README modifications --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 57ca293a..ae13059f 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ | - 🤖 Custom models + 🤖 Add your model | @@ -127,7 +127,7 @@ prediction = inference() For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)). -## 🤖 Custom models +## 🤖 Add your model Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`): From 8cc9925455a73f20d231560abf9833b17258db96 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Fri, 21 Apr 2023 12:23:02 +0200 Subject: [PATCH 05/23] Initial ASR implementation. Broken stuff --- src/diart/blocks/asr.py | 218 ++++++++++++++++++++++++++++++++ src/diart/blocks/base.py | 3 +- src/diart/blocks/diarization.py | 3 +- src/diart/inference.py | 39 +++--- src/diart/models.py | 136 +++++++++++++++++++- src/diart/sinks.py | 4 +- 6 files changed, 381 insertions(+), 22 deletions(-) create mode 100644 src/diart/blocks/asr.py diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py new file mode 100644 index 00000000..641096dd --- /dev/null +++ b/src/diart/blocks/asr.py @@ -0,0 +1,218 @@ +import math +from pathlib import Path +from typing import Sequence, Optional, Any, Union, List, Text, Tuple, Dict, Hashable + +import numpy as np +import torch +from einops import rearrange +from pyannote.core import SlidingWindowFeature, Annotation, Segment +from pyannote.metrics.base import BaseMetric + +from . import base +from .. import models as m +from .. 
import utils +from ..blocks.base import HyperParameter +from ..features import TemporalFeatureFormatter, TemporalFeatures + + +BeamSize = HyperParameter("beam_size", low=1, high=20) + + +class SpeechRecognition: + def __init__(self, model: m.SpeechRecognitionModel, device: Optional[torch.device] = None): + self.model = model + self.model.eval() + self.device = device + if self.device is None: + self.device = torch.device("cpu") + self.model.to(self.device) + self.formatter = TemporalFeatureFormatter() + + @staticmethod + def from_whisper( + name: Text, + download_path: Optional[Union[Text, Path]] = None, + in_memory: bool = False, + remember_transcriptions: bool = True, + device: Optional[Union[Text, torch.device]] = None, + ) -> 'SpeechRecognition': + asr_model = m.SpeechRecognitionModel.from_whisper( + name, download_path, in_memory, remember_transcriptions + ) + return SpeechRecognition(asr_model, device) + + def __call__(self, waveform: TemporalFeatures) -> List[m.Transcription]: + """ + Compute the transcription of input audio. + + Parameters + ---------- + waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels) + Audio to transcribe + + Returns + ------- + transcriptions: List[Transcription] + A list of timestamped transcriptions + """ + with torch.no_grad(): + wave = rearrange( + self.formatter.cast(waveform), + "batch sample channel -> batch channel sample" + ) + # output = self.model(wave.to(self.device)).cpu() + output = self.model(wave.to(self.device)) + return output + + +class TranscriptionConfig(base.StreamingConfig): + def __init__( + self, + asr: Optional[m.SpeechRecognitionModel] = None, + duration: Optional[float] = None, + language: Optional[Text] = None, + beam_size: int = 5, + device: Optional[torch.device] = None, + ): + self.device = device + if self.device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Default ASR model is Whisper small (244M parameters) + self.asr = asr + if self.asr is None: + self.asr = m.SpeechRecognitionModel.from_whisper("small") + self.asr.set_language(language) + + self._duration = duration + self._sample_rate: Optional[int] = None + + self.beam_size = beam_size + + @property + def duration(self) -> float: + if self._duration is None: + self._duration = self.asr.duration + return self._duration + + @property + def step(self) -> float: + return self.duration + + @property + def latency(self) -> float: + return self.duration + + @property + def sample_rate(self) -> int: + if self._sample_rate is None: + self._sample_rate = self.asr.sample_rate + return self._sample_rate + + @staticmethod + def from_dict(data: Any) -> 'TranscriptionConfig': + # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None + device = utils.get(data, "device", None) + if device is None: + device = torch.device("cpu") if utils.get(data, "cpu", False) else None + + name = utils.get(data, "whisper", "small") + asr = m.SpeechRecognitionModel.from_whisper(name) + + return TranscriptionConfig( + asr=asr, + duration=utils.get(data, "duration", None), + language=utils.get(data, "language", None), + beam_size=utils.get(data, "beam_size", 5), + device=device, + ) + + +class Transcription(base.StreamingPipeline): + def __init__( + self, + config: Optional[TranscriptionConfig] = None, + ): + self._config = TranscriptionConfig() if config is None else config + self.asr = SpeechRecognition(self.config.asr, self.config.device) + + @staticmethod + def get_config_class() -> type: + return 
TranscriptionConfig + + @staticmethod + def suggest_metric() -> BaseMetric: + # TODO word error rate + raise NotImplementedError + + @staticmethod + def hyper_parameters() -> Sequence[HyperParameter]: + return [BeamSize] + + @property + def config(self) -> TranscriptionConfig: + return self._config + + def reset(self): + # No internal state. Nothing to do + pass + + def set_timestamp_shift(self, shift: float): + # No timestamped output. Nothing to do + pass + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature], + diarization: Optional[Sequence[Annotation]] = None, + **kwargs + ) -> Sequence[Tuple[Text, SlidingWindowFeature]]: + # TODO implement batched inference + too_many_dia = diarization is not None and len(diarization) > 1 + msg = "Batched inference is not yet supported for 'Transcription'. Please set batch size to 1" + if len(waveforms) > 1 or too_many_dia: + print(msg) + exit(1) + + waveform = waveforms[0] + + # Add fake batch dimension, shape (1, samples, channels) + batch = torch.from_numpy(waveform.data).unsqueeze(0) + + expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate)) + msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" + assert batch.shape[1] == expected_num_samples, msg + + # Transcribe batch + # TODO only transcribe if there's speech + output = self.asr(batch)[0] + + if diarization is None: + return [(output.text, waveform)] + + diarization = diarization[0] + + # Align transcription with diarization to determine speakers + full_transcription = [] + buffer_shift = waveform.sliding_window.start + for text, timestamp in zip(output.chunks, output.timestamps): + target_region = Segment( + buffer_shift + timestamp.start, + buffer_shift + timestamp.end + ) + dia = diarization.crop(target_region) + speakers = dia.labels() + num_speakers = len(speakers) + if num_speakers == 0: + # Include transcription but don't assign a speaker + full_transcription.append(text) + elif num_speakers == 1: + # Typical case, annotate text with the only speaker + full_transcription.append(f"[{speakers[0]}]{text}") + else: + # Multiple speakers for the same text block, choose the most active one + # TODO match at the level of words? 
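+                #  (A hypothetical refinement, not part of this patch: whisper-timestamped
+                #   can emit word-level timestamps, so each word could be cropped against
+                #   the diarization and labeled separately. Below, the whole block is
+                #   assigned to the label with the longest speech duration inside it.)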
+ max_spk = np.argmax([dia.label_duration(spk) for spk in speakers]) + full_transcription.append(f"[{speakers[max_spk]}]{text}") + + return [(" ".join(full_transcription), waveform)] diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index 28f313eb..d1a372c1 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -87,6 +87,7 @@ def set_timestamp_shift(self, shift: float): def __call__( self, - waveforms: Sequence[SlidingWindowFeature] + waveforms: Sequence[SlidingWindowFeature], + **kwargs, ) -> Sequence[Tuple[Any, SlidingWindowFeature]]: raise NotImplementedError diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index f2a25119..03169077 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -192,7 +192,8 @@ def reset(self): def __call__( self, - waveforms: Sequence[SlidingWindowFeature] + waveforms: Sequence[SlidingWindowFeature], + **kwargs, ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" diff --git a/src/diart/inference.py b/src/diart/inference.py index 6afda89e..14ab4736 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -2,7 +2,7 @@ from multiprocessing import Pool, freeze_support, RLock, current_process from pathlib import Path from traceback import print_exc -from typing import Union, Text, Optional, Callable, Tuple, List +from typing import Union, Text, Optional, Callable, Tuple, List, Any import numpy as np import pandas as pd @@ -20,7 +20,7 @@ from . import sources as src from . import utils from .progress import ProgressBar, RichProgressBar, TQDMProgressBar -from .sinks import PredictionAccumulator, StreamingPlot, WindowClosedException +from .sinks import DiarizationAccumulator, StreamingPlot, WindowClosedException class StreamingInference: @@ -67,9 +67,9 @@ def __init__( self.do_profile = do_profile self.do_plot = do_plot self.show_progress = show_progress - self.accumulator = PredictionAccumulator(self.source.uri) self.unit = "chunk" if self.batch_size == 1 else "batch" self._observers = [] + self._predictions = [] chunk_duration = self.pipeline.config.duration step_duration = self.pipeline.config.step @@ -123,7 +123,7 @@ def __init__( self.stream = self.stream.pipe( ops.flat_map(lambda results: rx.from_iterable(results)), - ops.do(self.accumulator), + ops.do_action(lambda pred_wav: self._predictions.append(pred_wav[0])), ) if show_progress: @@ -141,13 +141,13 @@ def _close_chronometer(self): self._chrono.stop(do_count=False) self._chrono.report() - def attach_hooks(self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None]): + def attach_hooks(self, *hooks: Callable[[Tuple[Any, SlidingWindowFeature]], None]): """Attach hooks to the pipeline. Parameters ---------- - *hooks: (Tuple[Annotation, SlidingWindowFeature]) -> None - Hook functions to consume emitted annotations and audio. + *hooks: (Tuple[Any, SlidingWindowFeature]) -> None + Hook functions to consume emitted predictions and audio. """ self.stream = self.stream.pipe(*[ops.do_action(hook) for hook in hooks]) @@ -157,7 +157,7 @@ def attach_observers(self, *observers: Observer): Parameters ---------- *observers: Observer - Observers to consume emitted annotations and audio. + Observers to consume emitted predictions and audio. 
""" self.stream = self.stream.pipe(*[ops.do(sink) for sink in observers]) self._observers.extend(observers) @@ -182,13 +182,13 @@ def _handle_completion(self): self._close_pbar() self._close_chronometer() - def __call__(self) -> Annotation: - """Stream audio chunks from `source` to `pipeline`. + def __call__(self) -> List[Any]: + """Stream audio chunks from a source to a pipeline. Returns ------- - predictions: Annotation - Speaker diarization pipeline predictions + predictions: List[Any] + Streaming pipeline predictions """ if self.show_progress: self._pbar.start() @@ -209,9 +209,9 @@ def __call__(self) -> Annotation: on_error=self._handle_error, on_completed=self._handle_completion, ) - # FIXME if read() isn't blocking, the prediction returned is empty + # FIXME if read() isn't blocking, predictions are empty self.source.read() - return self.accumulator.get_prediction() + return self._predictions class Benchmark: @@ -329,8 +329,15 @@ def run_single( progress_bar=progress_bar, ) - pred = inference() - pred.uri = source.uri + # Accumulate predictions in memory + pred_accumulator = DiarizationAccumulator(source.uri) + inference.attach_observers(pred_accumulator) + + # Run the pipeline on this file + inference() + + # Extract prediction + pred = pred_accumulator.get_prediction() if self.output_path is not None: with open(self.output_path / f"{source.uri}.rttm", "w") as out_file: diff --git a/src/diart/models.py b/src/diart/models.py index df66e166..7a653837 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,7 +1,10 @@ -from typing import Optional, Text, Union, Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Text, Union, Callable, List, Tuple, Dict import torch import torch.nn as nn +from pyannote.core import Segment try: import pyannote.audio.pipelines.utils as pyannote_loader @@ -9,6 +12,12 @@ except ImportError: _has_pyannote = False +try: + import whisper + _has_whisper = True +except ImportError: + _has_whisper = False + class PyannoteLoader: def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): @@ -20,6 +29,25 @@ def __call__(self) -> nn.Module: return pyannote_loader.get_model(self.model_info, self.hf_token) +class WhisperLoader: + def __init__( + self, + name: Text, + download_path: Optional[Union[Text, Path]] = None, + in_memory: bool = False, + ): + self.name = name + self.download_path = download_path + self.in_memory = in_memory + + def __call__(self) -> nn.Module: + return whisper.load_model( + name=self.name, + download_root=self.download_path, + in_memory=self.in_memory, + ) + + class LazyModel(nn.Module): def __init__(self, loader: Callable[[], nn.Module]): super().__init__() @@ -163,3 +191,109 @@ def forward( weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(waveform, weights=weights) + + +@dataclass(frozen=True) +class Transcription: + text: Text + chunks: List[Text] + timestamps: List[Segment] + + +class SpeechRecognitionModel(LazyModel): + @staticmethod + def from_whisper( + name: Text, + download_path: Optional[Union[Text, Path]] = None, + in_memory: bool = False, + remember_transcriptions: bool = True, + ) -> 'SpeechRecognitionModel': + msg = "No whisper-transcribed installation found. 
" \ + "Visit https://github.com/linto-ai/whisper-timestamped#installation to install" + assert _has_whisper, msg + return WhisperSpeechRecognitionModel( + name, download_path, in_memory, remember_transcriptions + ) + + @property + def duration(self) -> float: + raise NotImplementedError + + @property + def sample_rate(self) -> int: + raise NotImplementedError + + def set_language(self, language: Optional[Text] = None): + raise NotImplementedError + + def forward(self, waveform: torch.Tensor) -> List[Transcription]: + """ + Forward pass of the speech recognition model. + + Parameters + ---------- + waveform: torch.Tensor, shape (batch, channels, samples) + Batch of audio chunks to transcribe + + Returns + ------- + transcriptions: List[Transcription] + A list of timestamped transcriptions + """ + raise NotImplementedError + + +class WhisperSpeechRecognitionModel(SpeechRecognitionModel): + def __init__( + self, + name: Text, + download_path: Optional[Union[Text, Path]] = None, + in_memory: bool = False, + remember_transcriptions: bool = True, + ): + super().__init__(WhisperLoader(name, download_path, in_memory)) + self.remember_transcriptions = remember_transcriptions + self.language = None + self._cache = None + + @property + def duration(self) -> float: + # Whisper's maximum duration per input is 30s + return whisper.audio.CHUNK_LENGTH + + @property + def sample_rate(self) -> int: + return whisper.audio.SAMPLE_RATE + + def set_language(self, language: Optional[Text] = None): + self.language = language + + def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]: + results = [] + for waveform in waveform_batch: + audio = whisper.pad_or_trim(waveform.type(torch.float32).reshape(-1)) + transcription = whisper.transcribe( + self.model, + audio, + initial_prompt=self._cache, + verbose=None, + task="transcribe", + language=self.language, + ) + + # Extract chunks and timestamps + chunks, timestamps = [], [] + for chunk in transcription["segments"]: + chunks.append(chunk["text"]) + timestamps.append(Segment(chunk["start"], chunk["end"])) + + # Create transcription object + transcription = Transcription(transcription["text"], chunks, timestamps) + results.append(transcription) + + # Update transcription buffer + if self.remember_transcriptions: + # TODO handle overlapping transcriptions + self._cache = transcription.text + + return results diff --git a/src/diart/sinks.py b/src/diart/sinks.py index 63c170d0..8d9217a1 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -8,8 +8,6 @@ from rx.core import Observer from typing_extensions import Literal -from . import utils - class WindowClosedException(Exception): pass @@ -58,7 +56,7 @@ def on_completed(self): self.patch() -class PredictionAccumulator(Observer): +class DiarizationAccumulator(Observer): def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05): super().__init__() self.uri = uri From 1ae4934c11879c293defa9bc172df1e4f7b9763d Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Fri, 21 Apr 2023 16:52:41 +0200 Subject: [PATCH 06/23] First working transcription pipeline. 
Using diarization is possible but a bit quirky --- src/diart/blocks/asr.py | 28 +++++----- src/diart/blocks/base.py | 13 +++-- src/diart/blocks/diarization.py | 23 ++++++-- src/diart/blocks/vad.py | 26 ++++++--- src/diart/console/tune.py | 9 ++-- src/diart/inference.py | 92 +++++++++++++++++-------------- src/diart/metrics.py | 96 +++++++++++++++++++++++++++++++++ 7 files changed, 214 insertions(+), 73 deletions(-) create mode 100644 src/diart/metrics.py diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py index 641096dd..0e034a3e 100644 --- a/src/diart/blocks/asr.py +++ b/src/diart/blocks/asr.py @@ -1,19 +1,17 @@ -import math from pathlib import Path -from typing import Sequence, Optional, Any, Union, List, Text, Tuple, Dict, Hashable +from typing import Sequence, Optional, Any, Union, List, Text, Tuple import numpy as np import torch from einops import rearrange from pyannote.core import SlidingWindowFeature, Annotation, Segment -from pyannote.metrics.base import BaseMetric from . import base from .. import models as m from .. import utils from ..blocks.base import HyperParameter from ..features import TemporalFeatureFormatter, TemporalFeatures - +from ..metrics import Metric, WordErrorRate BeamSize = HyperParameter("beam_size", low=1, high=20) @@ -141,9 +139,8 @@ def get_config_class() -> type: return TranscriptionConfig @staticmethod - def suggest_metric() -> BaseMetric: - # TODO word error rate - raise NotImplementedError + def suggest_metric() -> Metric: + return WordErrorRate() @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: @@ -161,6 +158,13 @@ def set_timestamp_shift(self, shift: float): # No timestamped output. Nothing to do pass + def join_predictions(self, predictions: List[Text]) -> Text: + return "\n".join(predictions) + + def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]): + with open(Path(dir_path) / f"{uri}.txt", "w") as out_file: + out_file.write(prediction) + def __call__( self, waveforms: Sequence[SlidingWindowFeature], @@ -168,11 +172,9 @@ def __call__( **kwargs ) -> Sequence[Tuple[Text, SlidingWindowFeature]]: # TODO implement batched inference - too_many_dia = diarization is not None and len(diarization) > 1 + only_one_dia = diarization is None or len(diarization) == 1 msg = "Batched inference is not yet supported for 'Transcription'. Please set batch size to 1" - if len(waveforms) > 1 or too_many_dia: - print(msg) - exit(1) + assert len(waveforms) == 1 and only_one_dia, msg waveform = waveforms[0] @@ -188,7 +190,7 @@ def __call__( output = self.asr(batch)[0] if diarization is None: - return [(output.text, waveform)] + return [(output.text.strip(), waveform)] diarization = diarization[0] @@ -215,4 +217,4 @@ def __call__( max_spk = np.argmax([dia.label_duration(spk) for spk in speakers]) full_transcription.append(f"[{speakers[max_spk]}]{text}") - return [(" ".join(full_transcription), waveform)] + return [(" ".join(full_transcription).strip(), waveform)] diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index d1a372c1..40d1d22d 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -1,12 +1,13 @@ -from typing import Any, Tuple, Sequence, Text from dataclasses import dataclass +from typing import Any, Tuple, Sequence, Text, List, Union +from pathlib import Path import numpy as np from pyannote.core import SlidingWindowFeature -from pyannote.metrics.base import BaseMetric from .. 
import utils from ..audio import FilePath, AudioLoader +from ..metrics import Metric @dataclass @@ -68,7 +69,7 @@ def get_config_class() -> type: raise NotImplementedError @staticmethod - def suggest_metric() -> BaseMetric: + def suggest_metric() -> Metric: raise NotImplementedError @staticmethod @@ -85,6 +86,12 @@ def reset(self): def set_timestamp_shift(self, shift: float): raise NotImplementedError + def join_predictions(self, predictions: List[Any]) -> Any: + raise NotImplementedError + + def write_prediction(self, uri: Text, prediction: Any, dir_path: Union[Text, Path]): + raise NotImplementedError + def __call__( self, waveforms: Sequence[SlidingWindowFeature], diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index 03169077..2f8de3f5 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -1,20 +1,20 @@ -from typing import Optional, Tuple, Sequence, Union, Any +from pathlib import Path +from typing import Optional, Tuple, Sequence, Union, Any, Text, List import numpy as np import torch from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment -from pyannote.metrics.base import BaseMetric -from pyannote.metrics.diarization import DiarizationErrorRate from typing_extensions import Literal -from .aggregation import DelayedAggregation from . import base +from .aggregation import DelayedAggregation from .clustering import OnlineSpeakerClustering from .embedding import OverlapAwareSpeakerEmbedding from .segmentation import SpeakerSegmentation from .utils import Binarize from .. import models as m from .. import utils +from ..metrics import Metric, DiarizationErrorRate class SpeakerDiarizationConfig(base.StreamingConfig): @@ -31,6 +31,7 @@ def __init__( gamma: float = 3, beta: float = 10, max_speakers: int = 20, + merge_collar: float = 0.05, device: Optional[torch.device] = None, **kwargs, ): @@ -61,6 +62,7 @@ def __init__( self.gamma = gamma self.beta = beta self.max_speakers = max_speakers + self.merge_collar = merge_collar self.device = device if self.device is None: @@ -103,6 +105,7 @@ def from_dict(data: Any) -> 'SpeakerDiarizationConfig': gamma=utils.get(data, "gamma", 3), beta=utils.get(data, "beta", 10), max_speakers=utils.get(data, "max_speakers", 20), + merge_collar=utils.get(data, "merge_collar", 0.05), device=device, ) @@ -165,7 +168,7 @@ def get_config_class() -> type: return SpeakerDiarizationConfig @staticmethod - def suggest_metric() -> BaseMetric: + def suggest_metric() -> Metric: return DiarizationErrorRate(collar=0, skip_overlap=False) @staticmethod @@ -179,6 +182,16 @@ def config(self) -> SpeakerDiarizationConfig: def set_timestamp_shift(self, shift: float): self.timestamp_shift = shift + def join_predictions(self, predictions: List[Annotation]) -> Annotation: + result = Annotation(uri=predictions[0].uri) + for pred in predictions: + result.update(pred) + return result.support(self.config.merge_collar) + + def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Text, Path]): + with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file: + prediction.write_rttm(out_file) + def reset(self): self.set_timestamp_shift(0) self.clustering = OnlineSpeakerClustering( diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py index def833b6..47afb4da 100644 --- a/src/diart/blocks/vad.py +++ b/src/diart/blocks/vad.py @@ -1,18 +1,18 @@ -from typing import Any, Optional, Union, Sequence, Tuple +from pathlib import Path +from typing import Any, Optional, Union, Sequence, 
Tuple, Text, List import numpy as np import torch from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment -from pyannote.metrics.base import BaseMetric -from pyannote.metrics.detection import DetectionErrorRate from typing_extensions import Literal -from .aggregation import DelayedAggregation from . import base +from .aggregation import DelayedAggregation from .segmentation import SpeakerSegmentation from .utils import Binarize from .. import models as m from .. import utils +from ..metrics import Metric, DetectionErrorRate class VoiceActivityDetectionConfig(base.StreamingConfig): @@ -23,6 +23,7 @@ def __init__( step: float = 0.5, latency: Optional[Union[float, Literal["max", "min"]]] = None, tau_active: float = 0.6, + merge_collar: float = 0.05, device: Optional[torch.device] = None, **kwargs, ): @@ -43,6 +44,7 @@ def __init__( self._latency = self._duration self.tau_active = tau_active + self.merge_collar = merge_collar self.device = device if self.device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -92,6 +94,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': step=utils.get(data, "step", 0.5), latency=utils.get(data, "latency", None), tau_active=tau, + merge_collar=utils.get(data, "merge_collar", 0.05), device=device, ) @@ -127,7 +130,7 @@ def get_config_class() -> type: return VoiceActivityDetectionConfig @staticmethod - def suggest_metric() -> BaseMetric: + def suggest_metric() -> Metric: return DetectionErrorRate(collar=0, skip_overlap=False) @staticmethod @@ -135,7 +138,7 @@ def hyper_parameters() -> Sequence[base.HyperParameter]: return [base.TauActive] @property - def config(self) -> base.StreamingConfig: + def config(self) -> VoiceActivityDetectionConfig: return self._config def reset(self): @@ -145,9 +148,20 @@ def reset(self): def set_timestamp_shift(self, shift: float): self.timestamp_shift = shift + def join_predictions(self, predictions: List[Annotation]) -> Annotation: + result = Annotation(uri=predictions[0].uri) + for pred in predictions: + result.update(pred) + return result.support(self.config.merge_collar) + + def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Text, Path]): + with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file: + prediction.write_rttm(out_file) + def __call__( self, waveforms: Sequence[SlidingWindowFeature], + **kwargs, ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index a1f1b63a..6affda50 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -51,12 +51,9 @@ def run(): possible_hparams = pipeline_class.hyper_parameters() hparams = [HyperParameter.from_name(name) for name in args.hparams] hparams = [hp for hp in hparams if hp in possible_hparams] - if not hparams: - print( - f"No hyper-parameters to optimize. " - f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}" - ) - exit(1) + msg = f"No hyper-parameters to optimize. 
" \ + f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}" + assert hparams, msg # Use a custom storage if given if args.output is not None: diff --git a/src/diart/inference.py b/src/diart/inference.py index 14ab4736..ee22a3cb 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -2,7 +2,7 @@ from multiprocessing import Pool, freeze_support, RLock, current_process from pathlib import Path from traceback import print_exc -from typing import Union, Text, Optional, Callable, Tuple, List, Any +from typing import Union, Text, Optional, Callable, Tuple, List, Any, Dict import numpy as np import pandas as pd @@ -10,8 +10,6 @@ import rx.operators as ops import torch from pyannote.core import Annotation, SlidingWindowFeature -from pyannote.database.util import load_rttm -from pyannote.metrics.base import BaseMetric from rx.core import Observer from tqdm import tqdm @@ -19,8 +17,9 @@ from . import operators as dops from . import sources as src from . import utils +from .metrics import Metric from .progress import ProgressBar, RichProgressBar, TQDMProgressBar -from .sinks import DiarizationAccumulator, StreamingPlot, WindowClosedException +from .sinks import StreamingPlot, WindowClosedException class StreamingInference: @@ -292,7 +291,7 @@ def run_single( pipeline: blocks.StreamingPipeline, filepath: Path, progress_bar: ProgressBar, - ) -> Annotation: + ) -> Tuple[Text, Any]: """Run a given pipeline on a given file. Note that this method does NOT reset the state of the pipeline before execution. @@ -329,36 +328,29 @@ def run_single( progress_bar=progress_bar, ) - # Accumulate predictions in memory - pred_accumulator = DiarizationAccumulator(source.uri) - inference.attach_observers(pred_accumulator) - - # Run the pipeline on this file - inference() - - # Extract prediction - pred = pred_accumulator.get_prediction() + # Run the pipeline and concatenate predictions + pred = pipeline.join_predictions(inference()) + # Write prediction to disk if required if self.output_path is not None: - with open(self.output_path / f"{source.uri}.rttm", "w") as out_file: - pred.write_rttm(out_file) + pipeline.write_prediction(source.uri, pred, self.output_path) - return pred + return source.uri, pred def evaluate( self, - predictions: List[Annotation], - metric: BaseMetric, - ) -> Union[pd.DataFrame, List[Annotation]]: + predictions: Dict[Text, Any], + metric: Metric, + ) -> Union[pd.DataFrame, Dict[Text, Any]]: """If a reference path was provided, compute the diarization error rate of a list of predictions. Parameters ---------- - predictions: List[Annotation] + predictions: List[Any] Predictions to evaluate. - metric: BaseMetric - Evaluation metric from pyannote.metrics. + metric: Metric + Evaluation metric. Returns ------- @@ -367,23 +359,37 @@ def evaluate( reference path was given. Otherwise return the same predictions. """ if self.reference_path is not None: + # Initialize progress bar progress_bar = TQDMProgressBar(f"Computing {metric.name}", leave=False) progress_bar.create(total=len(predictions), unit="file") progress_bar.start() - for hyp in predictions: - ref = load_rttm(self.reference_path / f"{hyp.uri}.rttm").popitem()[1] - metric(ref, hyp) + + # Evaluate each prediction + uris = [] + for uri, pred in predictions.items(): + ref_file = list(self.reference_path.glob(f"{uri}.*")) + if ref_file: + ref = metric.load_reference(ref_file[0]) + metric(ref, pred) + uris.append(uri) + else: + msg = f"Reference file for {uri} not found. Skipping evaluation." 
+ logging.warning(msg) progress_bar.update() + + # Close progress bar safely progress_bar.close() - return metric.report(display=self.show_report) + # Return performance report + return metric.report(uris, self.show_report) + return predictions def __call__( self, pipeline_class: type, config: blocks.StreamingConfig, - metric: Optional[BaseMetric] = None, - ) -> Union[pd.DataFrame, List[Annotation]]: + metric: Optional[Metric] = None, + ) -> Union[pd.DataFrame, Dict[Text, Any]]: """Run a given pipeline on a set of audio files. The internal state of the pipeline is reset before benchmarking. @@ -394,8 +400,8 @@ def __call__( A pipeline from this class will be instantiated by each worker. config: StreamingConfig Streaming pipeline configuration. - metric: Optional[BaseMetric] - Evaluation metric from pyannote.metrics. + metric: Optional[Metric] + Evaluation metric. Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns @@ -410,12 +416,13 @@ def __call__( num_audio_files = len(audio_file_paths) pipeline = pipeline_class(config) - predictions = [] + predictions = {} for i, filepath in enumerate(audio_file_paths): pipeline.reset() desc = f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})" progress = TQDMProgressBar(desc, leave=False, do_close=True) - predictions.append(self.run_single(pipeline, filepath, progress)) + uri, pred = self.run_single(pipeline, filepath, progress) + predictions[uri] = pred metric = pipeline.suggest_metric() if metric is None else metric return self.evaluate(predictions, metric) @@ -447,7 +454,7 @@ def run_single_job( config: blocks.StreamingConfig, filepath: Path, description: Text, - ) -> Annotation: + ) -> Tuple[Text, Any]: """Build and run a pipeline on a single file. Configure execution to show progress alongside parallel runs. @@ -482,8 +489,8 @@ def __call__( self, pipeline_class: type, config: blocks.StreamingConfig, - metric: Optional[BaseMetric] = None, - ) -> Union[pd.DataFrame, List[Annotation]]: + metric: Optional[Metric] = None, + ) -> Union[pd.DataFrame, Dict[Text, Any]]: """Run a given pipeline on a set of audio files in parallel. Each worker will build and run the pipeline on a different file. @@ -494,8 +501,8 @@ def __call__( A pipeline from this class will be instantiated by each worker. config: StreamingConfig Streaming pipeline configuration. - metric: Optional[BaseMetric] - Evaluation metric from pyannote.metrics. + metric: Optional[Metric] + Evaluation metric. 
Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns @@ -529,9 +536,14 @@ def __call__( # Submit all jobs jobs = [pool.apply_async(self.run_single_job, args=args) for args in arg_list] - # Wait and collect results + # Wait for all jobs to finish pool.close() - predictions = [job.get() for job in jobs] + + # Collect results + predictions = {} + for job in jobs: + uri, pred = job.get() + predictions[uri] = pred # Evaluate results metric = pipeline_class.suggest_metric() if metric is None else metric diff --git a/src/diart/metrics.py b/src/diart/metrics.py new file mode 100644 index 00000000..d5ae56e4 --- /dev/null +++ b/src/diart/metrics.py @@ -0,0 +1,96 @@ +from pathlib import Path +from typing import Text, Any, List, Union + +import pandas as pd +from pyannote.core import Annotation +from pyannote.metrics import diarization as dia, detection as det +from pyannote.metrics.base import BaseMetric as PyannoteBaseMetric +from pyannote.database.util import load_rttm +from torchmetrics import text + + +class Metric: + @property + def name(self) -> Text: + raise NotImplementedError + + def __call__(self, reference: Any, prediction: Any) -> float: + raise NotImplementedError + + def report(self, uris: List[Text], display: bool = False) -> pd.DataFrame: + raise NotImplementedError + + def load_reference(self, filepath: Union[Text, Path]) -> Any: + raise NotImplementedError + + +class PyannoteMetric(Metric): + def __init__(self, metric: PyannoteBaseMetric): + self._metric = metric + + @property + def name(self) -> Text: + return self._metric.name + + def __call__(self, reference: Annotation, prediction: Annotation) -> float: + return self._metric(reference, prediction) + + def report(self, uris: List[Text], display: bool = False) -> pd.DataFrame: + return self._metric.report(display) + + def load_reference(self, filepath: Union[Text, Path]) -> Annotation: + return load_rttm(filepath).popitem()[1] + + +class DiarizationErrorRate(PyannoteMetric): + def __init__(self, collar: float = 0, skip_overlap: bool = False): + super().__init__(dia.DiarizationErrorRate(collar, skip_overlap)) + + +class DetectionErrorRate(PyannoteMetric): + def __init__(self, collar: float = 0, skip_overlap: bool = False): + super().__init__(det.DetectionErrorRate(collar, skip_overlap)) + + +class WordErrorRate(Metric): + def __init__(self, unify_case: bool = False): + self.unify_case = unify_case + self._metric = text.WordErrorRate() + self._values = [] + + @property + def name(self) -> Text: + return "word error rate" + + def __call__(self, reference: Text, prediction: Text) -> float: + if self.unify_case: + prediction = prediction.lower() + reference = reference.lower() + # Torchmetrics requires predictions first, then reference + value = self._metric(prediction, reference).item() + self._values.append(value) + return value + + def report(self, uris: List[Text], display: bool = False) -> pd.DataFrame: + num_uris, num_values = len(uris), len(self._values) + msg = f"URI list size must match values. 
Found {num_uris} but expected {num_values}" + assert num_uris == num_values, msg + + rows = self._values + [self._metric.compute().item()] + index = uris + ["TOTAL"] + report = pd.DataFrame(rows, index=index, columns=[self.name]) + + if display: + print(report.to_string( + index=True, + sparsify=False, + justify="right", + float_format=lambda f: "{0:.2f}".format(f), + )) + + return report + + def load_reference(self, filepath: Union[Text, Path]) -> Text: + with open(filepath, "r") as file: + lines = [line.strip() for line in file.readlines()] + return " ".join(lines) From d8d73428fd475b1cc9515862188c2ec634abc6e3 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Fri, 21 Apr 2023 17:20:06 +0200 Subject: [PATCH 07/23] Reduce Whisper VRAM footprint (around 400Mb). Add fp16 option --- src/diart/blocks/asr.py | 4 +++- src/diart/models.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py index 0e034a3e..976ddb1d 100644 --- a/src/diart/blocks/asr.py +++ b/src/diart/blocks/asr.py @@ -13,6 +13,7 @@ from ..features import TemporalFeatureFormatter, TemporalFeatures from ..metrics import Metric, WordErrorRate + BeamSize = HyperParameter("beam_size", low=1, high=20) @@ -32,10 +33,11 @@ def from_whisper( download_path: Optional[Union[Text, Path]] = None, in_memory: bool = False, remember_transcriptions: bool = True, + fp16: bool = False, device: Optional[Union[Text, torch.device]] = None, ) -> 'SpeechRecognition': asr_model = m.SpeechRecognitionModel.from_whisper( - name, download_path, in_memory, remember_transcriptions + name, download_path, in_memory, remember_transcriptions, fp16 ) return SpeechRecognition(asr_model, device) diff --git a/src/diart/models.py b/src/diart/models.py index 7a653837..bffebe0d 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -43,6 +43,7 @@ def __init__( def __call__(self) -> nn.Module: return whisper.load_model( name=self.name, + device="cpu", download_root=self.download_path, in_memory=self.in_memory, ) @@ -207,12 +208,13 @@ def from_whisper( download_path: Optional[Union[Text, Path]] = None, in_memory: bool = False, remember_transcriptions: bool = True, + fp16: bool = False, ) -> 'SpeechRecognitionModel': msg = "No whisper-transcribed installation found. 
" \ "Visit https://github.com/linto-ai/whisper-timestamped#installation to install" assert _has_whisper, msg return WhisperSpeechRecognitionModel( - name, download_path, in_memory, remember_transcriptions + name, download_path, in_memory, remember_transcriptions, fp16 ) @property @@ -250,9 +252,11 @@ def __init__( download_path: Optional[Union[Text, Path]] = None, in_memory: bool = False, remember_transcriptions: bool = True, + fp16: bool = False, ): super().__init__(WhisperLoader(name, download_path, in_memory)) self.remember_transcriptions = remember_transcriptions + self.fp16 = fp16 self.language = None self._cache = None @@ -279,6 +283,7 @@ def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]: verbose=None, task="transcribe", language=self.language, + fp16=self.fp16, ) # Extract chunks and timestamps From 2cfc35da8035f07c27c3336e0d44ee5e499cd983 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Fri, 21 Apr 2023 17:33:06 +0200 Subject: [PATCH 08/23] Change whisper input type based on fp16 parameter --- src/diart/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diart/models.py b/src/diart/models.py index bffebe0d..57921df6 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -275,7 +275,8 @@ def set_language(self, language: Optional[Text] = None): def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]: results = [] for waveform in waveform_batch: - audio = whisper.pad_or_trim(waveform.type(torch.float32).reshape(-1)) + dtype = torch.float16 if self.fp16 else torch.float32 + audio = whisper.pad_or_trim(waveform.type(dtype).reshape(-1)) transcription = whisper.transcribe( self.model, audio, From a40112c42543e4d70d37c2fe999530a161e2b7a9 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Sat, 22 Apr 2023 17:40:35 +0200 Subject: [PATCH 09/23] Implement batched inference for whisper. Re-implement decoding. --- src/diart/blocks/asr.py | 78 +++++++-------- src/diart/models.py | 214 +++++++++++++++++++++++++++++++++------- 2 files changed, 216 insertions(+), 76 deletions(-) diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py index 976ddb1d..c485724e 100644 --- a/src/diart/blocks/asr.py +++ b/src/diart/blocks/asr.py @@ -32,12 +32,11 @@ def from_whisper( name: Text, download_path: Optional[Union[Text, Path]] = None, in_memory: bool = False, - remember_transcriptions: bool = True, fp16: bool = False, device: Optional[Union[Text, torch.device]] = None, ) -> 'SpeechRecognition': asr_model = m.SpeechRecognitionModel.from_whisper( - name, download_path, in_memory, remember_transcriptions, fp16 + name, download_path, in_memory, fp16 ) return SpeechRecognition(asr_model, device) @@ -173,15 +172,12 @@ def __call__( diarization: Optional[Sequence[Annotation]] = None, **kwargs ) -> Sequence[Tuple[Text, SlidingWindowFeature]]: - # TODO implement batched inference - only_one_dia = diarization is None or len(diarization) == 1 - msg = "Batched inference is not yet supported for 'Transcription'. 
Please set batch size to 1" - assert len(waveforms) == 1 and only_one_dia, msg + batch_size = len(waveforms) + msg = "Pipeline expected at least 1 input" + assert batch_size >= 1, msg - waveform = waveforms[0] - - # Add fake batch dimension, shape (1, samples, channels) - batch = torch.from_numpy(waveform.data).unsqueeze(0) + # Create batch from chunk sequence, shape (batch, samples, channels) + batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate)) msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" @@ -189,34 +185,34 @@ def __call__( # Transcribe batch # TODO only transcribe if there's speech - output = self.asr(batch)[0] - - if diarization is None: - return [(output.text.strip(), waveform)] - - diarization = diarization[0] - - # Align transcription with diarization to determine speakers - full_transcription = [] - buffer_shift = waveform.sliding_window.start - for text, timestamp in zip(output.chunks, output.timestamps): - target_region = Segment( - buffer_shift + timestamp.start, - buffer_shift + timestamp.end - ) - dia = diarization.crop(target_region) - speakers = dia.labels() - num_speakers = len(speakers) - if num_speakers == 0: - # Include transcription but don't assign a speaker - full_transcription.append(text) - elif num_speakers == 1: - # Typical case, annotate text with the only speaker - full_transcription.append(f"[{speakers[0]}]{text}") - else: - # Multiple speakers for the same text block, choose the most active one - # TODO match at the level of words? - max_spk = np.argmax([dia.label_duration(spk) for spk in speakers]) - full_transcription.append(f"[{speakers[max_spk]}]{text}") - - return [(" ".join(full_transcription).strip(), waveform)] + outputs = self.asr(batch) + + return [(out.text, wav) for out, wav in zip(outputs, waveforms)] + + # TODO align text with speakers if diarization is not None + + # diarization = diarization[0] + # + # # Align transcription with diarization to determine speakers + # full_transcription = [] + # buffer_shift = waveform.sliding_window.start + # for text, timestamp in zip(outputs.chunks, outputs.timestamps): + # target_region = Segment( + # buffer_shift + timestamp.start, + # buffer_shift + timestamp.end + # ) + # dia = diarization.crop(target_region) + # speakers = dia.labels() + # num_speakers = len(speakers) + # if num_speakers == 0: + # # Include transcription but don't assign a speaker + # full_transcription.append(text) + # elif num_speakers == 1: + # # Typical case, annotate text with the only speaker + # full_transcription.append(f"[{speakers[0]}]{text}") + # else: + # # Multiple speakers for the same text block, choose the most active one + # max_spk = np.argmax([dia.label_duration(spk) for spk in speakers]) + # full_transcription.append(f"[{speakers[max_spk]}]{text}") + # + # return [(" ".join(full_transcription).strip(), waveform)] diff --git a/src/diart/models.py b/src/diart/models.py index 57921df6..0879ef9a 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,7 +1,9 @@ +import time from dataclasses import dataclass from pathlib import Path -from typing import Optional, Text, Union, Callable, List, Tuple, Dict +from typing import Optional, Text, Union, Callable, List, Any +import numpy as np import torch import torch.nn as nn from pyannote.core import Segment @@ -14,9 +16,16 @@ try: import whisper + from whisper.tokenizer import get_tokenizer _has_whisper = True + DecodingResult = 
whisper.DecodingResult + DecodingOptions = whisper.DecodingOptions + Tokenizer = whisper.tokenizer.Tokenizer except ImportError: _has_whisper = False + DecodingResult = Any + DecodingOptions = Any + Tokenizer = Any class PyannoteLoader: @@ -207,15 +216,12 @@ def from_whisper( name: Text, download_path: Optional[Union[Text, Path]] = None, in_memory: bool = False, - remember_transcriptions: bool = True, fp16: bool = False, ) -> 'SpeechRecognitionModel': msg = "No whisper-transcribed installation found. " \ "Visit https://github.com/linto-ai/whisper-timestamped#installation to install" assert _has_whisper, msg - return WhisperSpeechRecognitionModel( - name, download_path, in_memory, remember_transcriptions, fp16 - ) + return WhisperSpeechRecognitionModel(name, download_path, in_memory, fp16) @property def duration(self) -> float: @@ -245,20 +251,141 @@ def forward(self, waveform: torch.Tensor) -> List[Transcription]: raise NotImplementedError +class WhisperDecoder: + def __init__( + self, + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1, + ): + self.compression_ratio_threshold = compression_ratio_threshold + self.logprob_threshold = logprob_threshold + self.temperatures = (0, 0.2, 0.4, 0.6, 0.8, 1) + + @staticmethod + def get_temperature_options(initial: DecodingOptions, t: float) -> DecodingOptions: + t_options = {**vars(initial)} + if t > 0: + t_options.pop("beam_size", None) + t_options.pop("patience", None) + else: + t_options.pop("best_of", None) + t_options["temperature"] = t + return DecodingOptions(**t_options) + + @staticmethod + def decode( + model, + batch: torch.Tensor, + options: DecodingOptions + ) -> DecodingResult: + return model.decode(batch, options) + + def check_compression(self) -> bool: + return self.compression_ratio_threshold is not None + + def check_logprob(self) -> bool: + return self.logprob_threshold is not None + + def needs_fallback(self, output: DecodingResult) -> bool: + if self.check_compression and output.compression_ratio > self.compression_ratio_threshold: + # Transcription is too repetitive + return True + if self.check_logprob and output.avg_logprob < self.logprob_threshold: + # Average log probability is too low + return True + return False + + def decode_with_fallback( + self, + model, + batch: torch.Tensor, + options: DecodingOptions, + ) -> DecodingResult: + batch_size = batch.shape[0] + results = [None] * batch_size + retry_idx = torch.ones(batch_size).type(torch.bool) + + for t in self.temperatures: + # Transcribe with the given temperature + t_options = self.get_temperature_options(options, t) + outputs = model.decode(batch[retry_idx], t_options) + + # Determine which outputs need to be transcribed again + # based on quality estimates + output_idx = torch.where(retry_idx)[0] + for idx, out in zip(output_idx, outputs): + results[idx] = out + if not self.needs_fallback(out): + retry_idx[idx] = False + + # No output needs fallback, get out of the loop + if torch.sum(retry_idx).item() == 0: + break + + return results + + @staticmethod + def split_with_timestamps( + result: DecodingResult, + tokenizer: Tokenizer, + chunk_duration: float, + token_duration: float, + ) -> Transcription: + tokens = torch.tensor(result.tokens) + chunks, timestamps = [], [] + ts_tokens = tokens.ge(tokenizer.timestamp_begin) + single_ts_ending = ts_tokens[-2:].tolist() == [False, True] + consecutive = torch.where(ts_tokens[:-1] & ts_tokens[1:])[0] + 1 + if len(consecutive) > 0: + # Output contains two consecutive timestamp tokens 
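+            #  In Whisper's output, adjacent timestamp tokens mark a segment boundary:
+            #  the first closes the previous segment and the second opens the next one,
+            #  so slicing the token sequence at these positions yields one text chunk
+            #  per (start, end) timestamp pair.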
+ slices = consecutive.tolist() + if single_ts_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin + end_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin + text_tokens = [token for token in sliced_tokens if token < tokenizer.eot] + text = tokenizer.decode(text_tokens).strip() + timestamp = Segment(start_pos * token_duration, end_pos * token_duration) + if text and timestamp.start != timestamp.end: + chunks.append(text) + timestamps.append(timestamp) + last_slice = current_slice + else: + duration = chunk_duration + ts = tokens[ts_tokens.nonzero().flatten()] + if len(ts) > 0 and ts[-1].item() != tokenizer.timestamp_begin: + # Use last timestamp as end time for the unique chunk + last_ts_pos = ts[-1].item() - tokenizer.timestamp_begin + duration = last_ts_pos * token_duration + text_tokens = [token for token in tokens if token < tokenizer.eot] + text = tokenizer.decode(text_tokens).strip() + if text: + chunks.append(text) + timestamps.append(Segment(0, duration)) + + return Transcription(result.text, chunks, timestamps) + + class WhisperSpeechRecognitionModel(SpeechRecognitionModel): def __init__( self, name: Text, download_path: Optional[Union[Text, Path]] = None, in_memory: bool = False, - remember_transcriptions: bool = True, fp16: bool = False, + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1, ): super().__init__(WhisperLoader(name, download_path, in_memory)) - self.remember_transcriptions = remember_transcriptions self.fp16 = fp16 + self.beam_size = None self.language = None - self._cache = None + self._token_duration: Optional[float] = None + self.decoder = WhisperDecoder(compression_ratio_threshold, logprob_threshold) @property def duration(self) -> float: @@ -269,37 +396,54 @@ def duration(self) -> float: def sample_rate(self) -> int: return whisper.audio.SAMPLE_RATE + @property + def token_duration(self) -> float: + if self._token_duration is None: + # 2 mel frames per output token + input_stride = int(np.rint(whisper.audio.N_FRAMES / self.model.dims.n_audio_ctx)) + # Output token duration is 0.02 seconds + self._token_duration = input_stride * whisper.audio.HOP_LENGTH / self.sample_rate + return self._token_duration + def set_language(self, language: Optional[Text] = None): self.language = language - def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]: - results = [] - for waveform in waveform_batch: - dtype = torch.float16 if self.fp16 else torch.float32 - audio = whisper.pad_or_trim(waveform.type(dtype).reshape(-1)) - transcription = whisper.transcribe( - self.model, - audio, - initial_prompt=self._cache, - verbose=None, - task="transcribe", - language=self.language, - fp16=self.fp16, - ) + def set_beam_size(self, size: int): + self.beam_size = size - # Extract chunks and timestamps - chunks, timestamps = [], [] - for chunk in transcription["segments"]: - chunks.append(chunk["text"]) - timestamps.append(Segment(chunk["start"], chunk["end"])) + def set_fp16(self, value: bool): + self.fp16 = value - # Create transcription object - transcription = Transcription(transcription["text"], chunks, timestamps) - results.append(transcription) + def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]: + # Remove channel dimension + batch = waveform_batch.squeeze(1) + num_chunk_samples = batch.shape[-1] + # Compute log mel spectrogram + batch = 
whisper.log_mel_spectrogram(batch) + # Add padding + dtype = torch.float16 if self.fp16 else torch.float32 + batch = whisper.pad_or_trim(batch, whisper.audio.N_FRAMES).to(batch.device).type(dtype) + + # Transcribe batch + options = whisper.DecodingOptions( + task="transcribe", + language=self.language, + beam_size=self.beam_size, + fp16=self.fp16, + ) + results = self.decoder.decode_with_fallback(self.model, batch, options) + tokenizer = get_tokenizer( + self.model.is_multilingual, + language=options.language, + task=options.task, + ) - # Update transcription buffer - if self.remember_transcriptions: - # TODO handle overlapping transcriptions - self._cache = transcription.text + chunk_duration = int(np.rint(num_chunk_samples / self.sample_rate)) + transcriptions = [ + self.decoder.split_with_timestamps( + res, tokenizer, chunk_duration, self.token_duration + ) + for res in results + ] - return results + return transcriptions From e8196a7ce8c7cc75f769f706265c735b539f770b Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Sat, 22 Apr 2023 18:31:48 +0200 Subject: [PATCH 10/23] Minor changes in transcription arguments --- src/diart/blocks/asr.py | 17 +++++------------ src/diart/blocks/base.py | 1 - src/diart/blocks/diarization.py | 1 - src/diart/blocks/vad.py | 1 - src/diart/models.py | 8 ++++---- 5 files changed, 9 insertions(+), 19 deletions(-) diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py index c485724e..ecd13b8e 100644 --- a/src/diart/blocks/asr.py +++ b/src/diart/blocks/asr.py @@ -4,7 +4,7 @@ import numpy as np import torch from einops import rearrange -from pyannote.core import SlidingWindowFeature, Annotation, Segment +from pyannote.core import SlidingWindowFeature from . import base from .. import models as m @@ -13,7 +13,6 @@ from ..features import TemporalFeatureFormatter, TemporalFeatures from ..metrics import Metric, WordErrorRate - BeamSize = HyperParameter("beam_size", low=1, high=20) @@ -70,7 +69,7 @@ def __init__( asr: Optional[m.SpeechRecognitionModel] = None, duration: Optional[float] = None, language: Optional[Text] = None, - beam_size: int = 5, + beam_size: int = None, device: Optional[torch.device] = None, ): self.device = device @@ -82,12 +81,11 @@ def __init__( if self.asr is None: self.asr = m.SpeechRecognitionModel.from_whisper("small") self.asr.set_language(language) + self.asr.set_beam_size(beam_size) self._duration = duration self._sample_rate: Optional[int] = None - self.beam_size = beam_size - @property def duration(self) -> float: if self._duration is None: @@ -122,16 +120,13 @@ def from_dict(data: Any) -> 'TranscriptionConfig': asr=asr, duration=utils.get(data, "duration", None), language=utils.get(data, "language", None), - beam_size=utils.get(data, "beam_size", 5), + beam_size=utils.get(data, "beam_size", None), device=device, ) class Transcription(base.StreamingPipeline): - def __init__( - self, - config: Optional[TranscriptionConfig] = None, - ): + def __init__(self, config: Optional[TranscriptionConfig] = None): self._config = TranscriptionConfig() if config is None else config self.asr = SpeechRecognition(self.config.asr, self.config.device) @@ -169,8 +164,6 @@ def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Pa def __call__( self, waveforms: Sequence[SlidingWindowFeature], - diarization: Optional[Sequence[Annotation]] = None, - **kwargs ) -> Sequence[Tuple[Text, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" diff --git a/src/diart/blocks/base.py 
b/src/diart/blocks/base.py index 40d1d22d..6494a9bf 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -95,6 +95,5 @@ def write_prediction(self, uri: Text, prediction: Any, dir_path: Union[Text, Pat def __call__( self, waveforms: Sequence[SlidingWindowFeature], - **kwargs, ) -> Sequence[Tuple[Any, SlidingWindowFeature]]: raise NotImplementedError diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index 2f8de3f5..fe3f4c98 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -206,7 +206,6 @@ def reset(self): def __call__( self, waveforms: Sequence[SlidingWindowFeature], - **kwargs, ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py index 47afb4da..42061c86 100644 --- a/src/diart/blocks/vad.py +++ b/src/diart/blocks/vad.py @@ -161,7 +161,6 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te def __call__( self, waveforms: Sequence[SlidingWindowFeature], - **kwargs, ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: batch_size = len(waveforms) msg = "Pipeline expected at least 1 input" diff --git a/src/diart/models.py b/src/diart/models.py index 0879ef9a..274478c8 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -234,6 +234,9 @@ def sample_rate(self) -> int: def set_language(self, language: Optional[Text] = None): raise NotImplementedError + def set_beam_size(self, size: Optional[int] = None): + raise NotImplementedError + def forward(self, waveform: torch.Tensor) -> List[Transcription]: """ Forward pass of the speech recognition model. @@ -408,12 +411,9 @@ def token_duration(self) -> float: def set_language(self, language: Optional[Text] = None): self.language = language - def set_beam_size(self, size: int): + def set_beam_size(self, size: Optional[int] = None): self.beam_size = size - def set_fp16(self, value: bool): - self.fp16 = value - def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]: # Remove channel dimension batch = waveform_batch.squeeze(1) From 07dd9ae36781cdc25bd379c1f4f1cdaa4d3862da Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Sun, 23 Apr 2023 15:15:54 +0200 Subject: [PATCH 11/23] Greatly improve transcription pipeline by adding optional VAD --- src/diart/blocks/asr.py | 57 ++++++++++++++++++----- src/diart/models.py | 101 ++++++++++++++++++++++++++++++++++------ 2 files changed, 132 insertions(+), 26 deletions(-) diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py index ecd13b8e..fdef7a15 100644 --- a/src/diart/blocks/asr.py +++ b/src/diart/blocks/asr.py @@ -9,6 +9,7 @@ from . import base from .. import models as m from .. 
import utils +from ..blocks import SpeakerSegmentation from ..blocks.base import HyperParameter from ..features import TemporalFeatureFormatter, TemporalFeatures from ..metrics import Metric, WordErrorRate @@ -32,14 +33,25 @@ def from_whisper( download_path: Optional[Union[Text, Path]] = None, in_memory: bool = False, fp16: bool = False, + no_speech_threshold: float = 0.6, + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1, + decode_with_fallback: bool = False, device: Optional[Union[Text, torch.device]] = None, ) -> 'SpeechRecognition': asr_model = m.SpeechRecognitionModel.from_whisper( - name, download_path, in_memory, fp16 + name, + download_path, + in_memory, + fp16, + no_speech_threshold, + compression_ratio_threshold, + logprob_threshold, + decode_with_fallback, ) return SpeechRecognition(asr_model, device) - def __call__(self, waveform: TemporalFeatures) -> List[m.Transcription]: + def __call__(self, waveform: TemporalFeatures) -> List[m.TranscriptionResult]: """ Compute the transcription of input audio. @@ -67,6 +79,8 @@ class TranscriptionConfig(base.StreamingConfig): def __init__( self, asr: Optional[m.SpeechRecognitionModel] = None, + vad: Optional[m.SegmentationModel] = None, + tau_active: float = 0.5, duration: Optional[float] = None, language: Optional[Text] = None, beam_size: int = None, @@ -83,6 +97,9 @@ def __init__( self.asr.set_language(language) self.asr.set_beam_size(beam_size) + self.vad = vad + self.tau_active = tau_active + self._duration = duration self._sample_rate: Optional[int] = None @@ -112,12 +129,10 @@ def from_dict(data: Any) -> 'TranscriptionConfig': device = utils.get(data, "device", None) if device is None: device = torch.device("cpu") if utils.get(data, "cpu", False) else None - - name = utils.get(data, "whisper", "small") - asr = m.SpeechRecognitionModel.from_whisper(name) - return TranscriptionConfig( - asr=asr, + asr=utils.get(data, "asr", None), + vad=utils.get(data, "vad", None), + tau_active=utils.get(data, "tau_active", None), duration=utils.get(data, "duration", None), language=utils.get(data, "language", None), beam_size=utils.get(data, "beam_size", None), @@ -129,6 +144,9 @@ class Transcription(base.StreamingPipeline): def __init__(self, config: Optional[TranscriptionConfig] = None): self._config = TranscriptionConfig() if config is None else config self.asr = SpeechRecognition(self.config.asr, self.config.device) + self.segmentation = None + if self.config.vad is not None: + self.segmentation = SpeakerSegmentation(self.config.vad, self.config.device) @staticmethod def get_config_class() -> type: @@ -176,11 +194,28 @@ def __call__( msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" assert batch.shape[1] == expected_num_samples, msg - # Transcribe batch - # TODO only transcribe if there's speech - outputs = self.asr(batch) + # Run voice detection if required + if self.segmentation is None: + has_voice = torch.arange(0, batch_size) + else: + segmentations = self.segmentation(batch) # shape (batch, frames, speakers) + has_voice = torch.max(segmentations, dim=-1)[0] # shape (batch, frames) + has_voice = torch.any(has_voice >= self.config.tau_active, dim=-1) # shape (batch,) + has_voice = torch.where(has_voice)[0] + + # Return empty strings if no speech in the entire batch + if len(has_voice) == 0: + return [("", wav) for wav in waveforms] - return [(out.text, wav) for out, wav in zip(outputs, waveforms)] + # Transcribe batch + outputs = self.asr(batch[has_voice]) + 
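+        #  Only the chunks flagged as containing speech were sent to the ASR model,
+        #  so map each original batch position back to its index in `outputs`;
+        #  positions without speech fall through to an empty transcription below.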
mapping = {i_voice.item(): i_output for i_output, i_voice in enumerate(has_voice)} + + # No-speech outputs are empty strings + return [ + (outputs[mapping[i]].text if i in has_voice else "", waveforms[i]) + for i in range(batch_size) + ] # TODO align text with speakers if diarization is not None diff --git a/src/diart/models.py b/src/diart/models.py index 274478c8..5afbec45 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -204,7 +204,7 @@ def forward( @dataclass(frozen=True) -class Transcription: +class TranscriptionResult: text: Text chunks: List[Text] timestamps: List[Segment] @@ -217,11 +217,24 @@ def from_whisper( download_path: Optional[Union[Text, Path]] = None, in_memory: bool = False, fp16: bool = False, + no_speech_threshold: float = 0.6, + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1, + decode_with_fallback: bool = False, ) -> 'SpeechRecognitionModel': msg = "No whisper-transcribed installation found. " \ "Visit https://github.com/linto-ai/whisper-timestamped#installation to install" assert _has_whisper, msg - return WhisperSpeechRecognitionModel(name, download_path, in_memory, fp16) + return WhisperSpeechRecognitionModel( + name, + download_path, + in_memory, + fp16, + no_speech_threshold, + compression_ratio_threshold, + logprob_threshold, + decode_with_fallback, + ) @property def duration(self) -> float: @@ -237,7 +250,7 @@ def set_language(self, language: Optional[Text] = None): def set_beam_size(self, size: Optional[int] = None): raise NotImplementedError - def forward(self, waveform: torch.Tensor) -> List[Transcription]: + def forward(self, waveform: torch.Tensor) -> List[TranscriptionResult]: """ Forward pass of the speech recognition model. @@ -248,7 +261,7 @@ def forward(self, waveform: torch.Tensor) -> List[Transcription]: Returns ------- - transcriptions: List[Transcription] + transcriptions: List[TranscriptionResult] A list of timestamped transcriptions """ raise NotImplementedError @@ -257,9 +270,11 @@ def forward(self, waveform: torch.Tensor) -> List[Transcription]: class WhisperDecoder: def __init__( self, + no_speech_threshold: float = 0.6, compression_ratio_threshold: Optional[float] = 2.4, logprob_threshold: Optional[float] = -1, ): + self.no_speech_threshold = no_speech_threshold self.compression_ratio_threshold = compression_ratio_threshold self.logprob_threshold = logprob_threshold self.temperatures = (0, 0.2, 0.4, 0.6, 0.8, 1) @@ -303,7 +318,24 @@ def decode_with_fallback( model, batch: torch.Tensor, options: DecodingOptions, - ) -> DecodingResult: + ) -> List[DecodingResult]: + """Transcribe batch and retry with ever-increasing + temperatures if the estimated quality of the transcription is not good. + + Parameters + ---------- + model: whisper.Whisper + Whisper ASR model (contains 'decode' method). + batch: torch.Tensor, shape (batch, channel, samples) + Log mel spectrogram batch. + options: whisper.DecodingOptions + Configuration to decode transcription. + + Returns + ------- + result: List[whisper.DecodingResult] + Transcription results for this batch. 
+ """ batch_size = batch.shape[0] results = [None] * batch_size retry_idx = torch.ones(batch_size).type(torch.bool) @@ -314,26 +346,51 @@ def decode_with_fallback( outputs = model.decode(batch[retry_idx], t_options) # Determine which outputs need to be transcribed again - # based on quality estimates output_idx = torch.where(retry_idx)[0] for idx, out in zip(output_idx, outputs): results[idx] = out if not self.needs_fallback(out): retry_idx[idx] = False - # No output needs fallback, get out of the loop + # No output needs fallback, get out of the loop early if torch.sum(retry_idx).item() == 0: break return results - @staticmethod def split_with_timestamps( + self, result: DecodingResult, tokenizer: Tokenizer, chunk_duration: float, token_duration: float, - ) -> Transcription: + ) -> TranscriptionResult: + """Split a Whisper transcription into segments with their respective timestamps. + Replace with empty string if no-speech probability is high. + + Parameters + ---------- + result: whisper.DecodingResult + A single transcription output from Whisper. + tokenizer: whisper.tokenizer.Tokenizer + Tokenizer needed to decode outputs. + chunk_duration: float + Actual duration of each input chunk. + token_duration: float + Duration of each output token. + + Returns + ------- + result: TranscriptionResult + Transcription with identified segments and timestamps. + """ + # Check if the model detects no speech and do not decode + if self.no_speech_threshold is not None: + no_speech = result.no_speech_prob > self.no_speech_threshold + low_confidence = self.logprob_threshold is None or result.avg_logprob < self.logprob_threshold + if no_speech and low_confidence: + return TranscriptionResult("", [""], [Segment(0, chunk_duration)]) + tokens = torch.tensor(result.tokens) chunks, timestamps = [], [] ts_tokens = tokens.ge(tokenizer.timestamp_begin) @@ -345,6 +402,7 @@ def split_with_timestamps( if single_ts_ending: slices.append(len(tokens)) + # Split into segments based on timestamp tokens last_slice = 0 for current_slice in slices: sliced_tokens = tokens[last_slice:current_slice] @@ -358,6 +416,7 @@ def split_with_timestamps( timestamps.append(timestamp) last_slice = current_slice else: + # There is a single segment, identify timestamps duration = chunk_duration ts = tokens[ts_tokens.nonzero().flatten()] if len(ts) > 0 and ts[-1].item() != tokenizer.timestamp_begin: @@ -370,7 +429,7 @@ def split_with_timestamps( chunks.append(text) timestamps.append(Segment(0, duration)) - return Transcription(result.text, chunks, timestamps) + return TranscriptionResult(result.text, chunks, timestamps) class WhisperSpeechRecognitionModel(SpeechRecognitionModel): @@ -380,15 +439,20 @@ def __init__( download_path: Optional[Union[Text, Path]] = None, in_memory: bool = False, fp16: bool = False, + no_speech_threshold: float = 0.6, compression_ratio_threshold: Optional[float] = 2.4, logprob_threshold: Optional[float] = -1, + decode_with_fallback: bool = False, ): super().__init__(WhisperLoader(name, download_path, in_memory)) self.fp16 = fp16 self.beam_size = None self.language = None + self.decode_with_fallback = decode_with_fallback + self.decoder = WhisperDecoder( + no_speech_threshold, compression_ratio_threshold, logprob_threshold + ) self._token_duration: Optional[float] = None - self.decoder = WhisperDecoder(compression_ratio_threshold, logprob_threshold) @property def duration(self) -> float: @@ -414,7 +478,7 @@ def set_language(self, language: Optional[Text] = None): def set_beam_size(self, size: Optional[int] = 
None): self.beam_size = size - def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]: + def forward(self, waveform_batch: torch.Tensor) -> List[TranscriptionResult]: # Remove channel dimension batch = waveform_batch.squeeze(1) num_chunk_samples = batch.shape[-1] @@ -424,20 +488,27 @@ def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]: dtype = torch.float16 if self.fp16 else torch.float32 batch = whisper.pad_or_trim(batch, whisper.audio.N_FRAMES).to(batch.device).type(dtype) - # Transcribe batch + # Configure transcription decoding options = whisper.DecodingOptions( task="transcribe", language=self.language, beam_size=self.beam_size, fp16=self.fp16, ) - results = self.decoder.decode_with_fallback(self.model, batch, options) + + # Transcribe batch with fallback if required + if self.decode_with_fallback: + decode_fn = self.decoder.decode_with_fallback + else: + decode_fn = self.decoder.decode + results = decode_fn(self.model, batch, options) + + # Split into segments and add timestamps tokenizer = get_tokenizer( self.model.is_multilingual, language=options.language, task=options.task, ) - chunk_duration = int(np.rint(num_chunk_samples / self.sample_rate)) transcriptions = [ self.decoder.split_with_timestamps( From 0bf25228d6b280e5be589fc93d7a1c1f7547d03f Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Sun, 23 Apr 2023 16:30:25 +0200 Subject: [PATCH 12/23] Move pipelines to diart.pipelines. Add torchmetrics as a dependency --- requirements.txt | 1 + setup.cfg | 1 + src/diart/__init__.py | 8 +- src/diart/blocks/__init__.py | 6 +- src/diart/blocks/asr.py | 182 +---------------- src/diart/blocks/clustering.py | 2 +- src/diart/console/tune.py | 2 +- src/diart/inference.py | 11 +- src/diart/models.py | 1 - src/diart/optim.py | 7 +- src/diart/pipelines/__init__.py | 4 + src/diart/{blocks => pipelines}/base.py | 26 +-- .../{blocks => pipelines}/diarization.py | 23 +-- src/diart/pipelines/hparams.py | 24 +++ src/diart/pipelines/transcription.py | 184 ++++++++++++++++++ .../{blocks/vad.py => pipelines/voice.py} | 17 +- src/diart/utils.py | 4 +- 17 files changed, 256 insertions(+), 247 deletions(-) create mode 100644 src/diart/pipelines/__init__.py rename src/diart/{blocks => pipelines}/base.py (77%) rename src/diart/{blocks => pipelines}/diarization.py (93%) create mode 100644 src/diart/pipelines/hparams.py create mode 100644 src/diart/pipelines/transcription.py rename src/diart/{blocks/vad.py => pipelines/voice.py} (94%) diff --git a/requirements.txt b/requirements.txt index 50662023..3241ddc4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ pandas>=1.4.2 torch>=1.12.1 torchvision>=0.14.0 torchaudio>=0.12.1,<1.0 +torchmetrics>=0.11.1 pyannote.audio>=2.1.1 pyannote.core>=4.5 pyannote.database>=4.1.1 diff --git a/setup.cfg b/setup.cfg index e67e4426..c70eac0b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,7 @@ install_requires= torch>=1.12.1 torchvision>=0.14.0 torchaudio>=0.12.1,<1.0 + torchmetrics>=0.11.1 pyannote.audio>=2.1.1 pyannote.core>=4.5 pyannote.database>=4.1.1 diff --git a/src/diart/__init__.py b/src/diart/__init__.py index e29287a0..0d67c9c5 100644 --- a/src/diart/__init__.py +++ b/src/diart/__init__.py @@ -1,8 +1,10 @@ -from .blocks import ( - SpeakerDiarization, +from .pipelines import ( StreamingPipeline, - SpeakerDiarizationConfig, StreamingConfig, + SpeakerDiarization, + SpeakerDiarizationConfig, VoiceActivityDetection, VoiceActivityDetectionConfig, + Transcription, + TranscriptionConfig, ) diff --git 
a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index e6e8c479..96fae0e7 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -5,7 +5,7 @@ FirstOnlyStrategy, DelayedAggregation, ) -from .clustering import OnlineSpeakerClustering +from .clustering import IncrementalSpeakerClustering from .embedding import ( SpeakerEmbedding, OverlappedSpeechPenalty, @@ -13,7 +13,5 @@ OverlapAwareSpeakerEmbedding, ) from .segmentation import SpeakerSegmentation -from .diarization import SpeakerDiarization, SpeakerDiarizationConfig -from .base import StreamingConfig, StreamingPipeline from .utils import Binarize, Resample, AdjustVolume -from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig +from .asr import SpeechRecognition diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py index fdef7a15..83dc0d90 100644 --- a/src/diart/blocks/asr.py +++ b/src/diart/blocks/asr.py @@ -1,20 +1,11 @@ from pathlib import Path -from typing import Sequence, Optional, Any, Union, List, Text, Tuple +from typing import Optional, Union, List, Text -import numpy as np import torch from einops import rearrange -from pyannote.core import SlidingWindowFeature -from . import base from .. import models as m -from .. import utils -from ..blocks import SpeakerSegmentation -from ..blocks.base import HyperParameter from ..features import TemporalFeatureFormatter, TemporalFeatures -from ..metrics import Metric, WordErrorRate - -BeamSize = HyperParameter("beam_size", low=1, high=20) class SpeechRecognition: @@ -73,174 +64,3 @@ def __call__(self, waveform: TemporalFeatures) -> List[m.TranscriptionResult]: # output = self.model(wave.to(self.device)).cpu() output = self.model(wave.to(self.device)) return output - - -class TranscriptionConfig(base.StreamingConfig): - def __init__( - self, - asr: Optional[m.SpeechRecognitionModel] = None, - vad: Optional[m.SegmentationModel] = None, - tau_active: float = 0.5, - duration: Optional[float] = None, - language: Optional[Text] = None, - beam_size: int = None, - device: Optional[torch.device] = None, - ): - self.device = device - if self.device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Default ASR model is Whisper small (244M parameters) - self.asr = asr - if self.asr is None: - self.asr = m.SpeechRecognitionModel.from_whisper("small") - self.asr.set_language(language) - self.asr.set_beam_size(beam_size) - - self.vad = vad - self.tau_active = tau_active - - self._duration = duration - self._sample_rate: Optional[int] = None - - @property - def duration(self) -> float: - if self._duration is None: - self._duration = self.asr.duration - return self._duration - - @property - def step(self) -> float: - return self.duration - - @property - def latency(self) -> float: - return self.duration - - @property - def sample_rate(self) -> int: - if self._sample_rate is None: - self._sample_rate = self.asr.sample_rate - return self._sample_rate - - @staticmethod - def from_dict(data: Any) -> 'TranscriptionConfig': - # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None - device = utils.get(data, "device", None) - if device is None: - device = torch.device("cpu") if utils.get(data, "cpu", False) else None - return TranscriptionConfig( - asr=utils.get(data, "asr", None), - vad=utils.get(data, "vad", None), - tau_active=utils.get(data, "tau_active", None), - duration=utils.get(data, "duration", None), - language=utils.get(data, "language", None), - beam_size=utils.get(data, 
"beam_size", None), - device=device, - ) - - -class Transcription(base.StreamingPipeline): - def __init__(self, config: Optional[TranscriptionConfig] = None): - self._config = TranscriptionConfig() if config is None else config - self.asr = SpeechRecognition(self.config.asr, self.config.device) - self.segmentation = None - if self.config.vad is not None: - self.segmentation = SpeakerSegmentation(self.config.vad, self.config.device) - - @staticmethod - def get_config_class() -> type: - return TranscriptionConfig - - @staticmethod - def suggest_metric() -> Metric: - return WordErrorRate() - - @staticmethod - def hyper_parameters() -> Sequence[HyperParameter]: - return [BeamSize] - - @property - def config(self) -> TranscriptionConfig: - return self._config - - def reset(self): - # No internal state. Nothing to do - pass - - def set_timestamp_shift(self, shift: float): - # No timestamped output. Nothing to do - pass - - def join_predictions(self, predictions: List[Text]) -> Text: - return "\n".join(predictions) - - def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]): - with open(Path(dir_path) / f"{uri}.txt", "w") as out_file: - out_file.write(prediction) - - def __call__( - self, - waveforms: Sequence[SlidingWindowFeature], - ) -> Sequence[Tuple[Text, SlidingWindowFeature]]: - batch_size = len(waveforms) - msg = "Pipeline expected at least 1 input" - assert batch_size >= 1, msg - - # Create batch from chunk sequence, shape (batch, samples, channels) - batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) - - expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate)) - msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" - assert batch.shape[1] == expected_num_samples, msg - - # Run voice detection if required - if self.segmentation is None: - has_voice = torch.arange(0, batch_size) - else: - segmentations = self.segmentation(batch) # shape (batch, frames, speakers) - has_voice = torch.max(segmentations, dim=-1)[0] # shape (batch, frames) - has_voice = torch.any(has_voice >= self.config.tau_active, dim=-1) # shape (batch,) - has_voice = torch.where(has_voice)[0] - - # Return empty strings if no speech in the entire batch - if len(has_voice) == 0: - return [("", wav) for wav in waveforms] - - # Transcribe batch - outputs = self.asr(batch[has_voice]) - mapping = {i_voice.item(): i_output for i_output, i_voice in enumerate(has_voice)} - - # No-speech outputs are empty strings - return [ - (outputs[mapping[i]].text if i in has_voice else "", waveforms[i]) - for i in range(batch_size) - ] - - # TODO align text with speakers if diarization is not None - - # diarization = diarization[0] - # - # # Align transcription with diarization to determine speakers - # full_transcription = [] - # buffer_shift = waveform.sliding_window.start - # for text, timestamp in zip(outputs.chunks, outputs.timestamps): - # target_region = Segment( - # buffer_shift + timestamp.start, - # buffer_shift + timestamp.end - # ) - # dia = diarization.crop(target_region) - # speakers = dia.labels() - # num_speakers = len(speakers) - # if num_speakers == 0: - # # Include transcription but don't assign a speaker - # full_transcription.append(text) - # elif num_speakers == 1: - # # Typical case, annotate text with the only speaker - # full_transcription.append(f"[{speakers[0]}]{text}") - # else: - # # Multiple speakers for the same text block, choose the most active one - # max_spk = np.argmax([dia.label_duration(spk) for spk in 
speakers]) - # full_transcription.append(f"[{speakers[max_spk]}]{text}") - # - # return [(" ".join(full_transcription).strip(), waveform)] diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py index 882001b9..4b737175 100644 --- a/src/diart/blocks/clustering.py +++ b/src/diart/blocks/clustering.py @@ -7,7 +7,7 @@ from ..mapping import SpeakerMap, SpeakerMapBuilder -class OnlineSpeakerClustering: +class IncrementalSpeakerClustering: """Implements constrained incremental online clustering of speakers and manages cluster centers. Parameters diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index 6affda50..111d97c2 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -4,7 +4,7 @@ import optuna from diart import argdoc from diart import utils -from diart.blocks.base import HyperParameter +from diart.pipelines.hparams import HyperParameter from diart.optim import Optimizer from optuna.samplers import TPESampler diff --git a/src/diart/inference.py b/src/diart/inference.py index ee22a3cb..258f773a 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -18,6 +18,7 @@ from . import sources as src from . import utils from .metrics import Metric +from .pipelines import StreamingPipeline, StreamingConfig from .progress import ProgressBar, RichProgressBar, TQDMProgressBar from .sinks import StreamingPlot, WindowClosedException @@ -52,7 +53,7 @@ class StreamingInference: """ def __init__( self, - pipeline: blocks.StreamingPipeline, + pipeline: StreamingPipeline, source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, @@ -288,7 +289,7 @@ def get_file_paths(self) -> List[Path]: def run_single( self, - pipeline: blocks.StreamingPipeline, + pipeline: StreamingPipeline, filepath: Path, progress_bar: ProgressBar, ) -> Tuple[Text, Any]: @@ -387,7 +388,7 @@ def evaluate( def __call__( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: StreamingConfig, metric: Optional[Metric] = None, ) -> Union[pd.DataFrame, Dict[Text, Any]]: """Run a given pipeline on a set of audio files. @@ -451,7 +452,7 @@ def __init__( def run_single_job( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: StreamingConfig, filepath: Path, description: Text, ) -> Tuple[Text, Any]: @@ -488,7 +489,7 @@ def run_single_job( def __call__( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: StreamingConfig, metric: Optional[Metric] = None, ) -> Union[pd.DataFrame, Dict[Text, Any]]: """Run a given pipeline on a set of audio files in parallel. diff --git a/src/diart/models.py b/src/diart/models.py index 5afbec45..485cb43a 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -1,4 +1,3 @@ -import time from dataclasses import dataclass from pathlib import Path from typing import Optional, Text, Union, Callable, List, Any diff --git a/src/diart/optim.py b/src/diart/optim.py index f7a96a6e..0ea27910 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -9,9 +9,10 @@ from tqdm import trange, tqdm from typing_extensions import Literal -from . 
import blocks from .audio import FilePath from .inference import Benchmark +from .pipelines import StreamingConfig +from .pipelines.hparams import HyperParameter class Optimizer: @@ -22,8 +23,8 @@ def __init__( reference_path: Union[Text, Path], study_or_path: Union[FilePath, Study], batch_size: int = 32, - hparams: Optional[Sequence[blocks.base.HyperParameter]] = None, - base_config: Optional[blocks.StreamingConfig] = None, + hparams: Optional[Sequence[HyperParameter]] = None, + base_config: Optional[StreamingConfig] = None, do_kickstart_hparams: bool = True, metric: Optional[BaseMetric] = None, direction: Literal["minimize", "maximize"] = "minimize", diff --git a/src/diart/pipelines/__init__.py b/src/diart/pipelines/__init__.py new file mode 100644 index 00000000..f430e4f7 --- /dev/null +++ b/src/diart/pipelines/__init__.py @@ -0,0 +1,4 @@ +from .base import StreamingPipeline, StreamingConfig +from .diarization import SpeakerDiarization, SpeakerDiarizationConfig +from .transcription import Transcription, TranscriptionConfig +from .voice import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/blocks/base.py b/src/diart/pipelines/base.py similarity index 77% rename from src/diart/blocks/base.py rename to src/diart/pipelines/base.py index 6494a9bf..507bfe5d 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/pipelines/base.py @@ -1,37 +1,15 @@ -from dataclasses import dataclass -from typing import Any, Tuple, Sequence, Text, List, Union from pathlib import Path +from typing import Any, Tuple, Sequence, Text, List, Union import numpy as np from pyannote.core import SlidingWindowFeature +from .hparams import HyperParameter from .. import utils from ..audio import FilePath, AudioLoader from ..metrics import Metric -@dataclass -class HyperParameter: - name: Text - low: float - high: float - - @staticmethod - def from_name(name: Text) -> 'HyperParameter': - if name == "tau_active": - return TauActive - if name == "rho_update": - return RhoUpdate - if name == "delta_new": - return DeltaNew - raise ValueError(f"Hyper-parameter '{name}' not recognized") - - -TauActive = HyperParameter("tau_active", low=0, high=1) -RhoUpdate = HyperParameter("rho_update", low=0, high=1) -DeltaNew = HyperParameter("delta_new", low=0, high=2) - - class StreamingConfig: @property def duration(self) -> float: diff --git a/src/diart/blocks/diarization.py b/src/diart/pipelines/diarization.py similarity index 93% rename from src/diart/blocks/diarization.py rename to src/diart/pipelines/diarization.py index fe3f4c98..7c901799 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/pipelines/diarization.py @@ -7,11 +7,8 @@ from typing_extensions import Literal from . import base -from .aggregation import DelayedAggregation -from .clustering import OnlineSpeakerClustering -from .embedding import OverlapAwareSpeakerEmbedding -from .segmentation import SpeakerSegmentation -from .utils import Binarize +from .hparams import HyperParameter, TauActive, RhoUpdate, DeltaNew +from .. import blocks from .. import models as m from .. 
import utils from ..metrics import Metric, DiarizationErrorRate @@ -139,23 +136,23 @@ def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" assert self._config.step <= self._config.latency <= self._config.duration, msg - self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device) - self.embedding = OverlapAwareSpeakerEmbedding( + self.segmentation = blocks.SpeakerSegmentation(self._config.segmentation, self._config.device) + self.embedding = blocks.OverlapAwareSpeakerEmbedding( self._config.embedding, self._config.gamma, self._config.beta, norm=1, device=self._config.device ) - self.pred_aggregation = DelayedAggregation( + self.pred_aggregation = blocks.DelayedAggregation( self._config.step, self._config.latency, strategy="hamming", cropping_mode="loose", ) - self.audio_aggregation = DelayedAggregation( + self.audio_aggregation = blocks.DelayedAggregation( self._config.step, self._config.latency, strategy="first", cropping_mode="center", ) - self.binarize = Binarize(self._config.tau_active) + self.binarize = blocks.Binarize(self._config.tau_active) # Internal state, handle with care self.timestamp_shift = 0 @@ -172,8 +169,8 @@ def suggest_metric() -> Metric: return DiarizationErrorRate(collar=0, skip_overlap=False) @staticmethod - def hyper_parameters() -> Sequence[base.HyperParameter]: - return [base.TauActive, base.RhoUpdate, base.DeltaNew] + def hyper_parameters() -> Sequence[HyperParameter]: + return [TauActive, RhoUpdate, DeltaNew] @property def config(self) -> SpeakerDiarizationConfig: @@ -194,7 +191,7 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te def reset(self): self.set_timestamp_shift(0) - self.clustering = OnlineSpeakerClustering( + self.clustering = blocks.IncrementalSpeakerClustering( self.config.tau_active, self.config.rho_update, self.config.delta_new, diff --git a/src/diart/pipelines/hparams.py b/src/diart/pipelines/hparams.py new file mode 100644 index 00000000..740a1edf --- /dev/null +++ b/src/diart/pipelines/hparams.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from typing import Text + + +@dataclass +class HyperParameter: + name: Text + low: float + high: float + + @staticmethod + def from_name(name: Text) -> 'HyperParameter': + if name == "tau_active": + return TauActive + if name == "rho_update": + return RhoUpdate + if name == "delta_new": + return DeltaNew + raise ValueError(f"Hyper-parameter '{name}' not recognized") + + +TauActive = HyperParameter("tau_active", low=0, high=1) +RhoUpdate = HyperParameter("rho_update", low=0, high=1) +DeltaNew = HyperParameter("delta_new", low=0, high=2) \ No newline at end of file diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py new file mode 100644 index 00000000..2bf4a270 --- /dev/null +++ b/src/diart/pipelines/transcription.py @@ -0,0 +1,184 @@ +from pathlib import Path +from typing import Sequence, Optional, Any, Union, List, Text, Tuple + +import numpy as np +import torch +from pyannote.core import SlidingWindowFeature + +from . import base +from .hparams import HyperParameter, TauActive +from .. import blocks +from .. import models as m +from .. 
import utils +from ..metrics import Metric, WordErrorRate + + +class TranscriptionConfig(base.StreamingConfig): + def __init__( + self, + asr: Optional[m.SpeechRecognitionModel] = None, + vad: Optional[m.SegmentationModel] = None, + tau_active: float = 0.5, + duration: Optional[float] = None, + language: Optional[Text] = None, + beam_size: int = None, + device: Optional[torch.device] = None, + ): + self.device = device + if self.device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Default ASR model is Whisper small (244M parameters) + self.asr = asr + if self.asr is None: + self.asr = m.SpeechRecognitionModel.from_whisper("small") + self.asr.set_language(language) + self.asr.set_beam_size(beam_size) + + self.vad = vad + self.tau_active = tau_active + + self._duration = duration + self._sample_rate: Optional[int] = None + + @property + def duration(self) -> float: + if self._duration is None: + self._duration = self.asr.duration + return self._duration + + @property + def step(self) -> float: + return self.duration + + @property + def latency(self) -> float: + return self.duration + + @property + def sample_rate(self) -> int: + if self._sample_rate is None: + self._sample_rate = self.asr.sample_rate + return self._sample_rate + + @staticmethod + def from_dict(data: Any) -> 'TranscriptionConfig': + # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None + device = utils.get(data, "device", None) + if device is None: + device = torch.device("cpu") if utils.get(data, "cpu", False) else None + return TranscriptionConfig( + asr=utils.get(data, "asr", None), + vad=utils.get(data, "vad", None), + tau_active=utils.get(data, "tau_active", None), + duration=utils.get(data, "duration", None), + language=utils.get(data, "language", None), + beam_size=utils.get(data, "beam_size", None), + device=device, + ) + + +class Transcription(base.StreamingPipeline): + def __init__(self, config: Optional[TranscriptionConfig] = None): + self._config = TranscriptionConfig() if config is None else config + self.asr = blocks.SpeechRecognition(self.config.asr, self.config.device) + self.segmentation = None + if self.config.vad is not None: + self.segmentation = blocks.SpeakerSegmentation(self.config.vad, self.config.device) + + @staticmethod + def get_config_class() -> type: + return TranscriptionConfig + + @staticmethod + def suggest_metric() -> Metric: + return WordErrorRate() + + @staticmethod + def hyper_parameters() -> Sequence[HyperParameter]: + return [TauActive] + + @property + def config(self) -> TranscriptionConfig: + return self._config + + def reset(self): + # No internal state. Nothing to do + pass + + def set_timestamp_shift(self, shift: float): + # No timestamped output. 
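The __call__ method further below transcribes only the chunks in which the segmentation model detects speech, then maps the ASR outputs back to their original positions in the batch. A small self-contained sketch of that index bookkeeping, with made-up scores, threshold and transcripts (none of these values come from the patch):

import torch

# Suppose a batch of 4 chunks and one speech score per chunk (max speaker activity)
scores = torch.tensor([0.1, 0.9, 0.2, 0.7])
tau_active = 0.5

# Indices of the chunks that contain speech
has_voice = torch.where(scores >= tau_active)[0]   # tensor([1, 3])

# Only those chunks are sent to the ASR model; fake strings stand in for its outputs
outputs = [f"transcript of chunk {i.item()}" for i in has_voice]

# Map original batch positions to positions in the (smaller) list of ASR outputs
mapping = {i_voice.item(): i_output for i_output, i_voice in enumerate(has_voice)}

# Chunks without speech become empty strings, preserving batch order
texts = [outputs[mapping[i]] if i in has_voice else "" for i in range(4)]
print(texts)  # ['', 'transcript of chunk 1', '', 'transcript of chunk 3']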
Nothing to do + pass + + def join_predictions(self, predictions: List[Text]) -> Text: + return "\n".join(predictions) + + def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]): + with open(Path(dir_path) / f"{uri}.txt", "w") as out_file: + out_file.write(prediction) + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature], + ) -> Sequence[Tuple[Text, SlidingWindowFeature]]: + batch_size = len(waveforms) + msg = "Pipeline expected at least 1 input" + assert batch_size >= 1, msg + + # Create batch from chunk sequence, shape (batch, samples, channels) + batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) + + expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate)) + msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" + assert batch.shape[1] == expected_num_samples, msg + + # Run voice detection if required + if self.segmentation is None: + has_voice = torch.arange(0, batch_size) + else: + segmentations = self.segmentation(batch) # shape (batch, frames, speakers) + has_voice = torch.max(segmentations, dim=-1)[0] # shape (batch, frames) + has_voice = torch.any(has_voice >= self.config.tau_active, dim=-1) # shape (batch,) + has_voice = torch.where(has_voice)[0] + + # Return empty strings if no speech in the entire batch + if len(has_voice) == 0: + return [("", wav) for wav in waveforms] + + # Transcribe batch + outputs = self.asr(batch[has_voice]) + mapping = {i_voice.item(): i_output for i_output, i_voice in enumerate(has_voice)} + + # No-speech outputs are empty strings + return [ + (outputs[mapping[i]].text if i in has_voice else "", waveforms[i]) + for i in range(batch_size) + ] + + # TODO align text with speakers if diarization is not None + + # diarization = diarization[0] + # + # # Align transcription with diarization to determine speakers + # full_transcription = [] + # buffer_shift = waveform.sliding_window.start + # for text, timestamp in zip(outputs.chunks, outputs.timestamps): + # target_region = Segment( + # buffer_shift + timestamp.start, + # buffer_shift + timestamp.end + # ) + # dia = diarization.crop(target_region) + # speakers = dia.labels() + # num_speakers = len(speakers) + # if num_speakers == 0: + # # Include transcription but don't assign a speaker + # full_transcription.append(text) + # elif num_speakers == 1: + # # Typical case, annotate text with the only speaker + # full_transcription.append(f"[{speakers[0]}]{text}") + # else: + # # Multiple speakers for the same text block, choose the most active one + # max_spk = np.argmax([dia.label_duration(spk) for spk in speakers]) + # full_transcription.append(f"[{speakers[max_spk]}]{text}") + # + # return [(" ".join(full_transcription).strip(), waveform)] \ No newline at end of file diff --git a/src/diart/blocks/vad.py b/src/diart/pipelines/voice.py similarity index 94% rename from src/diart/blocks/vad.py rename to src/diart/pipelines/voice.py index 42061c86..f93c28a6 100644 --- a/src/diart/blocks/vad.py +++ b/src/diart/pipelines/voice.py @@ -7,9 +7,8 @@ from typing_extensions import Literal from . import base -from .aggregation import DelayedAggregation -from .segmentation import SpeakerSegmentation -from .utils import Binarize +from .hparams import HyperParameter, TauActive +from .. import blocks from .. import models as m from .. 
import utils from ..metrics import Metric, DetectionErrorRate @@ -106,20 +105,20 @@ def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" assert self._config.step <= self._config.latency <= self._config.duration, msg - self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device) - self.pred_aggregation = DelayedAggregation( + self.segmentation = blocks.SpeakerSegmentation(self._config.segmentation, self._config.device) + self.pred_aggregation = blocks.DelayedAggregation( self._config.step, self._config.latency, strategy="hamming", cropping_mode="loose", ) - self.audio_aggregation = DelayedAggregation( + self.audio_aggregation = blocks.DelayedAggregation( self._config.step, self._config.latency, strategy="first", cropping_mode="center", ) - self.binarize = Binarize(self._config.tau_active) + self.binarize = blocks.Binarize(self._config.tau_active) # Internal state, handle with care self.timestamp_shift = 0 @@ -134,8 +133,8 @@ def suggest_metric() -> Metric: return DetectionErrorRate(collar=0, skip_overlap=False) @staticmethod - def hyper_parameters() -> Sequence[base.HyperParameter]: - return [base.TauActive] + def hyper_parameters() -> Sequence[HyperParameter]: + return [TauActive] @property def config(self) -> VoiceActivityDetectionConfig: diff --git a/src/diart/utils.py b/src/diart/utils.py index e825ef29..3714a99d 100644 --- a/src/diart/utils.py +++ b/src/diart/utils.py @@ -7,7 +7,7 @@ from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook from .progress import ProgressBar -from . import blocks +from . import pipelines class Chronometer: @@ -82,7 +82,7 @@ def repeat_label(label: Text): def get_pipeline_class(class_name: Text) -> type: - pipeline_class = getattr(blocks, class_name, None) + pipeline_class = getattr(pipelines, class_name, None) msg = f"Pipeline '{class_name}' doesn't exist" assert pipeline_class is not None, msg return pipeline_class From 42fe5f7505e823166a335cd7d4142379f43e2ec7 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Sun, 23 Apr 2023 18:23:39 +0200 Subject: [PATCH 13/23] Add websocket compatibility to transcription pipeline --- src/diart/console/serve.py | 14 ++++++---- src/diart/pipelines/base.py | 5 ++++ src/diart/pipelines/diarization.py | 10 +++++-- src/diart/pipelines/transcription.py | 40 +++++++++++++++++++++------- src/diart/pipelines/voice.py | 10 +++++-- src/diart/sinks.py | 23 +++++++++++----- src/diart/sources.py | 5 ++-- src/diart/utils.py | 6 +++++ 8 files changed, 87 insertions(+), 26 deletions(-) diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index 46bb9328..4e7ca1ce 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -5,7 +5,7 @@ from diart import sources as src from diart import utils from diart.inference import StreamingInference -from diart.sinks import RTTMWriter +from diart.pipelines import StreamingPipeline def run(): @@ -14,6 +14,10 @@ def run(): parser.add_argument("--port", default=7007, type=int, help="Server port") parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") + parser.add_argument("--whisper", default="small", type=str, + help=f"Whisper model for transcription pipeline. Defaults to 'small'") + parser.add_argument("--language", default="en", type=str, + help=f"Transcribe in this language. 
Defaults to 'en' (English)") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -36,7 +40,7 @@ def run(): # Resolve pipeline pipeline_class = utils.get_pipeline_class(args.pipeline) config = pipeline_class.get_config_class().from_dict(vars(args)) - pipeline = pipeline_class(config) + pipeline: StreamingPipeline = pipeline_class(config) # Create websocket audio source audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port) @@ -53,10 +57,10 @@ def run(): # Write to disk if required if args.output is not None: - inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")) + inference.attach_observers(pipeline.suggest_writer(audio_source.uri, args.output)) - # Send back responses as RTTM text lines - inference.attach_hooks(lambda ann_wav: audio_source.send(ann_wav[0].to_rttm())) + # Send back responses as text + inference.attach_hooks(lambda pred_wav: audio_source.send(utils.serialize_prediction(pred_wav[0]))) # Run server and pipeline inference() diff --git a/src/diart/pipelines/base.py b/src/diart/pipelines/base.py index 507bfe5d..d47115d5 100644 --- a/src/diart/pipelines/base.py +++ b/src/diart/pipelines/base.py @@ -3,6 +3,7 @@ import numpy as np from pyannote.core import SlidingWindowFeature +from rx.core import Observer from .hparams import HyperParameter from .. import utils @@ -50,6 +51,10 @@ def get_config_class() -> type: def suggest_metric() -> Metric: raise NotImplementedError + @staticmethod + def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer: + raise NotImplementedError + @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: raise NotImplementedError diff --git a/src/diart/pipelines/diarization.py b/src/diart/pipelines/diarization.py index 7c901799..dd9ef070 100644 --- a/src/diart/pipelines/diarization.py +++ b/src/diart/pipelines/diarization.py @@ -4,12 +4,14 @@ import numpy as np import torch from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment +from rx.core import Observer from typing_extensions import Literal from . import base from .hparams import HyperParameter, TauActive, RhoUpdate, DeltaNew from .. import blocks from .. import models as m +from .. import sinks from .. 
import utils from ..metrics import Metric, DiarizationErrorRate @@ -22,7 +24,7 @@ def __init__( duration: Optional[float] = None, step: float = 0.5, latency: Optional[Union[float, Literal["max", "min"]]] = None, - tau_active: float = 0.6, + tau_active: float = 0.5, rho_update: float = 0.3, delta_new: float = 1, gamma: float = 3, @@ -82,7 +84,7 @@ def from_dict(data: Any) -> 'SpeakerDiarizationConfig': # Hyper-parameters and their aliases tau = utils.get(data, "tau_active", None) if tau is None: - tau = utils.get(data, "tau", 0.6) + tau = utils.get(data, "tau", 0.5) rho = utils.get(data, "rho_update", None) if rho is None: rho = utils.get(data, "rho", 0.3) @@ -168,6 +170,10 @@ def get_config_class() -> type: def suggest_metric() -> Metric: return DiarizationErrorRate(collar=0, skip_overlap=False) + @staticmethod + def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer: + return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm") + @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: return [TauActive, RhoUpdate, DeltaNew] diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py index 2bf4a270..6b3249c9 100644 --- a/src/diart/pipelines/transcription.py +++ b/src/diart/pipelines/transcription.py @@ -4,11 +4,13 @@ import numpy as np import torch from pyannote.core import SlidingWindowFeature +from rx.core import Observer from . import base from .hparams import HyperParameter, TauActive from .. import blocks from .. import models as m +from .. import sinks from .. import utils from ..metrics import Metric, WordErrorRate @@ -17,9 +19,9 @@ class TranscriptionConfig(base.StreamingConfig): def __init__( self, asr: Optional[m.SpeechRecognitionModel] = None, - vad: Optional[m.SegmentationModel] = None, + segmentation: Optional[m.SegmentationModel] = None, tau_active: float = 0.5, - duration: Optional[float] = None, + duration: Optional[float] = 3, language: Optional[Text] = None, beam_size: int = None, device: Optional[torch.device] = None, @@ -35,7 +37,7 @@ def __init__( self.asr.set_language(language) self.asr.set_beam_size(beam_size) - self.vad = vad + self.segmentation = segmentation self.tau_active = tau_active self._duration = duration @@ -67,11 +69,27 @@ def from_dict(data: Any) -> 'TranscriptionConfig': device = utils.get(data, "device", None) if device is None: device = torch.device("cpu") if utils.get(data, "cpu", False) else None + + # Default ASR model is Whisper small (244M parameters) + whisper_size = utils.get(data, "whisper", "small") + asr = m.SpeechRecognitionModel.from_whisper(whisper_size) + + # No VAD segmentation by default + segmentation = utils.get(data, "segmentation", None) + if segmentation is not None: + hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) + segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) + + # Tau hyper-parameter and its alias + tau = utils.get(data, "tau_active", None) + if tau is None: + tau = utils.get(data, "tau", 0.5) + return TranscriptionConfig( - asr=utils.get(data, "asr", None), - vad=utils.get(data, "vad", None), - tau_active=utils.get(data, "tau_active", None), - duration=utils.get(data, "duration", None), + asr=asr, + segmentation=segmentation, + tau_active=tau, + duration=utils.get(data, "duration", 3), language=utils.get(data, "language", None), beam_size=utils.get(data, "beam_size", None), device=device, @@ -83,8 +101,8 @@ def __init__(self, config: Optional[TranscriptionConfig] = None): self._config = TranscriptionConfig() if 
config is None else config self.asr = blocks.SpeechRecognition(self.config.asr, self.config.device) self.segmentation = None - if self.config.vad is not None: - self.segmentation = blocks.SpeakerSegmentation(self.config.vad, self.config.device) + if self.config.segmentation is not None: + self.segmentation = blocks.SpeakerSegmentation(self.config.segmentation, self.config.device) @staticmethod def get_config_class() -> type: @@ -94,6 +112,10 @@ def get_config_class() -> type: def suggest_metric() -> Metric: return WordErrorRate() + @staticmethod + def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer: + return sinks.TextWriter(Path(output_dir) / f"{uri}.txt") + @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: return [TauActive] diff --git a/src/diart/pipelines/voice.py b/src/diart/pipelines/voice.py index f93c28a6..f050a972 100644 --- a/src/diart/pipelines/voice.py +++ b/src/diart/pipelines/voice.py @@ -4,12 +4,14 @@ import numpy as np import torch from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment +from rx.core import Observer from typing_extensions import Literal from . import base from .hparams import HyperParameter, TauActive from .. import blocks from .. import models as m +from .. import sinks from .. import utils from ..metrics import Metric, DetectionErrorRate @@ -21,7 +23,7 @@ def __init__( duration: Optional[float] = None, step: float = 0.5, latency: Optional[Union[float, Literal["max", "min"]]] = None, - tau_active: float = 0.6, + tau_active: float = 0.5, merge_collar: float = 0.05, device: Optional[torch.device] = None, **kwargs, @@ -85,7 +87,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': # Tau active and its alias tau = utils.get(data, "tau_active", None) if tau is None: - tau = utils.get(data, "tau", 0.6) + tau = utils.get(data, "tau", 0.5) return VoiceActivityDetectionConfig( segmentation=segmentation, @@ -132,6 +134,10 @@ def get_config_class() -> type: def suggest_metric() -> Metric: return DetectionErrorRate(collar=0, skip_overlap=False) + @staticmethod + def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer: + return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm") + @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: return [TauActive] diff --git a/src/diart/sinks.py b/src/diart/sinks.py index 8d9217a1..5d77e09e 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Union, Text, Optional, Tuple +from typing import Union, Text, Optional, Tuple, Any import matplotlib.pyplot as plt from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook @@ -13,13 +13,10 @@ class WindowClosedException(Exception): pass -def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation: +def _extract_prediction(value: Union[Tuple, Any]) -> Any: if isinstance(value, tuple): return value[0] - if isinstance(value, Annotation): - return value - msg = f"Expected tuple or Annotation, but got {type(value)}" - raise ValueError(msg) + return value class RTTMWriter(Observer): @@ -56,6 +53,20 @@ def on_completed(self): self.patch() +class TextWriter(Observer): + def __init__(self, path: Union[Path, Text]): + super().__init__() + self.path = Path(path).expanduser() + if self.path.exists(): + self.path.unlink() + + def on_next(self, value: Union[Tuple, Text]): + # Write transcription to file + prediction = _extract_prediction(value) + with open(self.path, 'a') as file: + 
file.write(prediction + "\n") + + class DiarizationAccumulator(Observer): def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05): super().__init__() diff --git a/src/diart/sources.py b/src/diart/sources.py index b34d5cf3..76149bb4 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -247,8 +247,9 @@ def send(self, message: AnyStr): message: AnyStr Bytes or string to send. """ - if len(message) > 0: - self.server.send_message(self.client, message) + msg = message.strip() + if len(msg) > 0: + self.server.send_message(self.client, msg + "\n") class TorchStreamAudioSource(AudioSource): diff --git a/src/diart/utils.py b/src/diart/utils.py index 3714a99d..a5f06e21 100644 --- a/src/diart/utils.py +++ b/src/diart/utils.py @@ -92,6 +92,12 @@ def get_padding_right(latency: float, step: float) -> float: return latency - step +def serialize_prediction(value: Union[Annotation, Text]) -> Text: + if isinstance(value, Annotation): + return value.to_rttm() + return value + + def visualize_feature(duration: Optional[float] = None): def apply(feature: SlidingWindowFeature): if duration is None: From 49616e58bfafbf0edba2a92ce8cf3e95d9f2edce Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Sun, 23 Apr 2023 19:49:46 +0200 Subject: [PATCH 14/23] Transcription pipeline is now fully compatible with diart.stream --- src/diart/console/serve.py | 1 - src/diart/console/stream.py | 20 ++- src/diart/inference.py | 23 +--- src/diart/operators.py | 193 +-------------------------- src/diart/pipelines/base.py | 17 +-- src/diart/pipelines/diarization.py | 22 +-- src/diart/pipelines/transcription.py | 17 +-- src/diart/pipelines/voice.py | 22 +-- src/diart/sinks.py | 114 +++++++++++++--- 9 files changed, 161 insertions(+), 268 deletions(-) diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index 4e7ca1ce..5c7c68b0 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -51,7 +51,6 @@ def run(): audio_source, batch_size=1, do_profile=False, - do_plot=False, show_progress=True, ) diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index e0c670c5..3563dd2a 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -5,7 +5,7 @@ from diart import sources as src from diart import utils from diart.inference import StreamingInference -from diart.sinks import RTTMWriter +from diart.pipelines import StreamingPipeline, StreamingConfig def run(): @@ -13,6 +13,10 @@ def run(): parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'") parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") + parser.add_argument("--whisper", default="small", type=str, + help=f"Whisper model for transcription pipeline. Defaults to 'small'") + parser.add_argument("--language", default="en", type=str, + help=f"Transcribe in this language. Defaults to 'en' (English)") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. 
Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -36,8 +40,8 @@ def run(): # Resolve pipeline pipeline_class = utils.get_pipeline_class(args.pipeline) - config = pipeline_class.get_config_class().from_dict(vars(args)) - pipeline = pipeline_class(config) + config: StreamingConfig = pipeline_class.get_config_class().from_dict(vars(args)) + pipeline: StreamingPipeline = pipeline_class(config) # Manage audio source block_size = config.optimal_block_size() @@ -59,10 +63,16 @@ def run(): audio_source, batch_size=1, do_profile=True, - do_plot=not args.no_plot, show_progress=True, ) - inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")) + + # Attach observers for required side effects + observers = [pipeline.suggest_writer(audio_source.uri, args.output)] + if not args.no_plot: + observers.append(pipeline.suggest_display()) + inference.attach_observers(*observers) + + # Run pipeline inference() diff --git a/src/diart/inference.py b/src/diart/inference.py index 258f773a..b9f2a789 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -20,7 +20,7 @@ from .metrics import Metric from .pipelines import StreamingPipeline, StreamingConfig from .progress import ProgressBar, RichProgressBar, TQDMProgressBar -from .sinks import StreamingPlot, WindowClosedException +from .sinks import WindowClosedException class StreamingInference: @@ -40,9 +40,6 @@ class StreamingInference: do_profile: bool If True, compute and report the processing time of the pipeline. Defaults to True. - do_plot: bool - If True, draw predictions in a moving plot. - Defaults to False. show_progress: bool If True, show a progress bar. Defaults to True. @@ -57,7 +54,6 @@ def __init__( source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, - do_plot: bool = False, show_progress: bool = True, progress_bar: Optional[ProgressBar] = None, ): @@ -65,7 +61,6 @@ def __init__( self.source = source self.batch_size = batch_size self.do_profile = do_profile - self.do_plot = do_plot self.show_progress = show_progress self.unit = "chunk" if self.batch_size == 1 else "batch" self._observers = [] @@ -192,20 +187,7 @@ def __call__(self) -> List[Any]: """ if self.show_progress: self._pbar.start() - config = self.pipeline.config - observable = self.stream - if self.do_plot: - # Buffering is needed for the real-time plot, so we do this at the very end - observable = self.stream.pipe( - dops.buffer_output( - duration=config.duration, - step=config.step, - latency=config.latency, - sample_rate=config.sample_rate, - ), - ops.do(StreamingPlot(config.duration, config.latency)), - ) - observable.subscribe( + self.stream.subscribe( on_error=self._handle_error, on_completed=self._handle_completion, ) @@ -324,7 +306,6 @@ def run_single( source, self.batch_size, do_profile=False, - do_plot=False, show_progress=self.show_progress, progress_bar=progress_bar, ) diff --git a/src/diart/operators.py b/src/diart/operators.py index 6d73fc9d..c67bf99d 100644 --- a/src/diart/operators.py +++ b/src/diart/operators.py @@ -1,9 +1,9 @@ from dataclasses import dataclass -from typing import Callable, Optional, List, Any, Tuple +from typing import Callable, Optional, List, Any import numpy as np import rx -from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature, Segment +from pyannote.core import SlidingWindow, SlidingWindowFeature from rx import operators as ops from rx.core import Observable @@ -97,192 +97,3 @@ def 
accumulate(state: List[Any], value: Any) -> List[Any]: return new_state[1:] return new_state return rx.pipe(ops.scan(accumulate, [])) - - -@dataclass -class PredictionWithAudio: - prediction: Annotation - waveform: Optional[SlidingWindowFeature] = None - - @property - def has_audio(self) -> bool: - return self.waveform is not None - - -@dataclass -class OutputAccumulationState: - annotation: Optional[Annotation] - waveform: Optional[SlidingWindowFeature] - real_time: float - next_sample: Optional[int] - - @staticmethod - def initial() -> 'OutputAccumulationState': - return OutputAccumulationState(None, None, 0, 0) - - @property - def cropped_waveform(self) -> SlidingWindowFeature: - return SlidingWindowFeature( - self.waveform[:self.next_sample], - self.waveform.sliding_window, - ) - - def to_tuple(self) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]: - return self.annotation, self.cropped_waveform, self.real_time - - -def accumulate_output( - duration: float, - step: float, - patch_collar: float = 0.05, -) -> Operator: - """Accumulate predictions and audio to infinity: O(N) space complexity. - Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations. - - Parameters - ---------- - duration: float - Buffer duration in seconds. - step: float - Duration of the chunks at each event in seconds. - The first chunk may be bigger given the latency. - patch_collar: float, optional - Collar to merge speaker turns of the same speaker, in seconds. - Defaults to 0.05 (i.e. 50ms). - Returns - ------- - A reactive x operator implementing this behavior. - """ - def accumulate( - state: OutputAccumulationState, - value: Tuple[Annotation, Optional[SlidingWindowFeature]] - ) -> OutputAccumulationState: - value = PredictionWithAudio(*value) - annotation, waveform = None, None - - # Determine the real time of the stream - real_time = duration if state.annotation is None else state.real_time + step - - # Update total annotation with current predictions - if state.annotation is None: - annotation = value.prediction - else: - annotation = state.annotation.update(value.prediction).support(patch_collar) - - # Update total waveform if there's audio in the input - new_next_sample = 0 - if value.has_audio: - num_new_samples = value.waveform.data.shape[0] - new_next_sample = state.next_sample + num_new_samples - sw_holder = state - if state.waveform is None: - # Initialize the audio buffer with 10 times the size of the first chunk - waveform, sw_holder = np.zeros((10 * num_new_samples, 1)), value - elif new_next_sample < state.waveform.data.shape[0]: - # The buffer still has enough space to accommodate the chunk - waveform = state.waveform.data - else: - # The buffer is full, double its size - waveform = np.concatenate( - (state.waveform.data, np.zeros_like(state.waveform.data)), axis=0 - ) - # Copy chunk into buffer - waveform[state.next_sample:new_next_sample] = value.waveform.data - waveform = SlidingWindowFeature(waveform, sw_holder.waveform.sliding_window) - - return OutputAccumulationState(annotation, waveform, real_time, new_next_sample) - - return rx.pipe( - ops.scan(accumulate, OutputAccumulationState.initial()), - ops.map(OutputAccumulationState.to_tuple), - ) - - -def buffer_output( - duration: float, - step: float, - latency: float, - sample_rate: int, - patch_collar: float = 0.05, -) -> Operator: - """Store last predictions and audio inside a fixed buffer. - Provides the best time/space complexity trade-off if the past data is not needed. 
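With output buffering and plotting removed from diart.operators, a live view now comes from the pipeline itself via suggest_display(), attached to StreamingInference like any other observer (see the updated diart.stream above). A rough script-level sketch of that flow; the microphone source, output directory and default config are illustrative assumptions, not part of this patch:

import diart
from diart import sources as src
from diart.inference import StreamingInference

config = diart.SpeakerDiarizationConfig()             # assumes default pyannote models
pipeline = diart.SpeakerDiarization(config)
mic = src.MicrophoneAudioSource(config.sample_rate)   # any AudioSource would do

inference = StreamingInference(pipeline, mic, batch_size=1, do_profile=False)
inference.attach_observers(
    pipeline.suggest_writer(mic.uri, "outputs"),      # RTTM writer for diarization
    pipeline.suggest_display(),                       # StreamingPlot (or RichScreen for text)
)
prediction = inference()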
- - Parameters - ---------- - duration: float - Buffer duration in seconds. - step: float - Duration of the chunks at each event in seconds. - The first chunk may be bigger given the latency. - latency: float - Latency of the system in seconds. - sample_rate: int - Sample rate of the audio source. - patch_collar: float, optional - Collar to merge speaker turns of the same speaker, in seconds. - Defaults to 0.05 (i.e. 50ms). - - Returns - ------- - A reactive x operator implementing this behavior. - """ - # Define some useful constants - num_samples = int(round(duration * sample_rate)) - num_step_samples = int(round(step * sample_rate)) - resolution = 1 / sample_rate - - def accumulate( - state: OutputAccumulationState, - value: Tuple[Annotation, Optional[SlidingWindowFeature]] - ) -> OutputAccumulationState: - value = PredictionWithAudio(*value) - annotation, waveform = None, None - - # Determine the real time of the stream and the start time of the buffer - real_time = duration if state.annotation is None else state.real_time + step - start_time = max(0., real_time - latency - duration) - - # Update annotation and constrain its bounds to the buffer - if state.annotation is None: - annotation = value.prediction - else: - annotation = state.annotation.update(value.prediction).support(patch_collar) - if start_time > 0: - annotation = annotation.extrude(Segment(0, start_time)) - - # Update the audio buffer if there's audio in the input - new_next_sample = state.next_sample + num_step_samples - if value.has_audio: - if state.waveform is None: - # Determine the size of the first chunk - expected_duration = duration + step - latency - expected_samples = int(round(expected_duration * sample_rate)) - # Shift indicator to start copying new audio in the buffer - new_next_sample = state.next_sample + expected_samples - # Buffer size is duration + step - waveform = np.zeros((num_samples + num_step_samples, 1)) - # Copy first chunk into buffer (slicing because of rounding errors) - waveform[:expected_samples] = value.waveform.data[:expected_samples] - elif state.next_sample <= num_samples: - # The buffer isn't full, copy into next free buffer chunk - waveform = state.waveform.data - waveform[state.next_sample:new_next_sample] = value.waveform.data - else: - # The buffer is full, shift values to the left and copy into last buffer chunk - waveform = np.roll(state.waveform.data, -num_step_samples, axis=0) - # If running on a file, the online prediction may be shorter depending on the latency - # The remaining audio at the end is appended, so value.waveform may be longer than num_step_samples - # In that case, we simply ignore the appended samples. 
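The rolling-buffer update being deleted here shifts the stored audio left by one step (np.roll) and then writes the newest chunk into the freed tail. A tiny NumPy sketch with made-up sizes (eight samples per buffer, two per step), just to make that bookkeeping concrete:

import numpy as np

num_samples, num_step_samples = 8, 2                           # illustrative sizes only
buffer = np.arange(num_samples, dtype=float).reshape(-1, 1)    # pretend the buffer is full
new_chunk = np.array([[100.0], [101.0]])                       # newest step of audio

buffer = np.roll(buffer, -num_step_samples, axis=0)            # discard the oldest step
buffer[-num_step_samples:] = new_chunk                         # append the newest step
print(buffer.ravel())   # oldest samples dropped: 2..7 followed by 100, 101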
- waveform[-num_step_samples:] = value.waveform.data[:num_step_samples] - - # Wrap waveform in a sliding window feature to include timestamps - window = SlidingWindow(start=start_time, duration=resolution, step=resolution) - waveform = SlidingWindowFeature(waveform, window) - - return OutputAccumulationState(annotation, waveform, real_time, new_next_sample) - - return rx.pipe( - ops.scan(accumulate, OutputAccumulationState.initial()), - ops.map(OutputAccumulationState.to_tuple), - ) diff --git a/src/diart/pipelines/base.py b/src/diart/pipelines/base.py index d47115d5..1a220203 100644 --- a/src/diart/pipelines/base.py +++ b/src/diart/pipelines/base.py @@ -47,14 +47,6 @@ class StreamingPipeline: def get_config_class() -> type: raise NotImplementedError - @staticmethod - def suggest_metric() -> Metric: - raise NotImplementedError - - @staticmethod - def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer: - raise NotImplementedError - @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: raise NotImplementedError @@ -75,6 +67,15 @@ def join_predictions(self, predictions: List[Any]) -> Any: def write_prediction(self, uri: Text, prediction: Any, dir_path: Union[Text, Path]): raise NotImplementedError + def suggest_metric(self) -> Metric: + raise NotImplementedError + + def suggest_display(self) -> Observer: + raise NotImplementedError + + def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer: + raise NotImplementedError + def __call__( self, waveforms: Sequence[SlidingWindowFeature], diff --git a/src/diart/pipelines/diarization.py b/src/diart/pipelines/diarization.py index dd9ef070..57cb7739 100644 --- a/src/diart/pipelines/diarization.py +++ b/src/diart/pipelines/diarization.py @@ -166,14 +166,6 @@ def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): def get_config_class() -> type: return SpeakerDiarizationConfig - @staticmethod - def suggest_metric() -> Metric: - return DiarizationErrorRate(collar=0, skip_overlap=False) - - @staticmethod - def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer: - return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm") - @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: return [TauActive, RhoUpdate, DeltaNew] @@ -195,6 +187,20 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file: prediction.write_rttm(out_file) + def suggest_metric(self) -> Metric: + return DiarizationErrorRate(collar=0, skip_overlap=False) + + def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer: + return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm") + + def suggest_display(self) -> Observer: + return sinks.StreamingPlot( + self.config.duration, + self.config.step, + self.config.latency, + self.config.sample_rate + ) + def reset(self): self.set_timestamp_shift(0) self.clustering = blocks.IncrementalSpeakerClustering( diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py index 6b3249c9..cb5f40e7 100644 --- a/src/diart/pipelines/transcription.py +++ b/src/diart/pipelines/transcription.py @@ -108,14 +108,6 @@ def __init__(self, config: Optional[TranscriptionConfig] = None): def get_config_class() -> type: return TranscriptionConfig - @staticmethod - def suggest_metric() -> Metric: - return WordErrorRate() - - @staticmethod - def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer: - return 
sinks.TextWriter(Path(output_dir) / f"{uri}.txt") - @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: return [TauActive] @@ -139,6 +131,15 @@ def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Pa with open(Path(dir_path) / f"{uri}.txt", "w") as out_file: out_file.write(prediction) + def suggest_metric(self) -> Metric: + return WordErrorRate() + + def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer: + return sinks.TextWriter(Path(output_dir) / f"{uri}.txt") + + def suggest_display(self) -> Observer: + return sinks.RichScreen() + def __call__( self, waveforms: Sequence[SlidingWindowFeature], diff --git a/src/diart/pipelines/voice.py b/src/diart/pipelines/voice.py index f050a972..b22806ab 100644 --- a/src/diart/pipelines/voice.py +++ b/src/diart/pipelines/voice.py @@ -130,14 +130,6 @@ def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): def get_config_class() -> type: return VoiceActivityDetectionConfig - @staticmethod - def suggest_metric() -> Metric: - return DetectionErrorRate(collar=0, skip_overlap=False) - - @staticmethod - def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer: - return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm") - @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: return [TauActive] @@ -163,6 +155,20 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file: prediction.write_rttm(out_file) + def suggest_metric(self) -> Metric: + return DetectionErrorRate(collar=0, skip_overlap=False) + + def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer: + return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm") + + def suggest_display(self) -> Observer: + return sinks.StreamingPlot( + self.config.duration, + self.config.step, + self.config.latency, + self.config.sample_rate + ) + def __call__( self, waveforms: Sequence[SlidingWindowFeature], diff --git a/src/diart/sinks.py b/src/diart/sinks.py index 5d77e09e..be461fff 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -1,12 +1,14 @@ +import re from pathlib import Path -from typing import Union, Text, Optional, Tuple, Any +from typing import Union, Text, Optional, Tuple, Any, List import matplotlib.pyplot as plt -from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook +import numpy as np +import rich +from pyannote.core import Annotation, Segment, SlidingWindowFeature, SlidingWindow, notebook from pyannote.database.util import load_rttm from pyannote.metrics.diarization import DiarizationErrorRate from rx.core import Observer -from typing_extensions import Literal class WindowClosedException(Exception): @@ -99,25 +101,59 @@ def on_completed(self): self.patch() +class RichScreen(Observer): + def __init__(self, speaker_colors: Optional[List[Text]] = None): + super().__init__() + self.colors = speaker_colors + if self.colors is None: + self.colors = [ + "bright_red", "bright_blue", "bright_green", "orange3", "deep_pink1", + "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2" + ] + self.num_colors = len(self.colors) + + def on_next(self, value: Union[Tuple, Text]): + prediction = _extract_prediction(value) + # Extract speakers + speakers = sorted(re.findall(r'\[.*?]', prediction)) + # Colorize based on speakers + colorized = prediction + for i, speaker in enumerate(speakers): + colorized = colorized.replace(speaker, f"[{self.colors[i % 
self.num_colors]}]") + # Print result + rich.print(colorized) + + class StreamingPlot(Observer): def __init__( self, duration: float, + step: float, latency: float, - visualization: Literal["slide", "accumulate"] = "slide", + sample_rate: float, reference: Optional[Union[Path, Text]] = None, + patch_collar: float = 0.05, ): super().__init__() - assert visualization in ["slide", "accumulate"] - self.visualization = visualization self.reference = reference if self.reference is not None: self.reference = list(load_rttm(reference).values())[0] self.window_duration = duration + self.window_step = step self.latency = latency + self.sample_rate = sample_rate + self.patch_collar = patch_collar + + self.num_window_samples = int(np.rint(self.window_duration * self.sample_rate)) + self.num_step_samples = int(np.rint(self.window_step * self.sample_rate)) + self.audio_resolution = 1 / self.sample_rate + self.figure, self.axs, self.num_axs = None, None, -1 # This flag allows to catch the matplotlib window closed event and make the next call stop iterating self.window_closed = False + self.real_time = 0 + self.pred_buffer, self.audio_buffer = None, None + self.next_sample = 0 def _on_window_closed(self, event): self.window_closed = True @@ -139,21 +175,63 @@ def _clear_axs(self): for i in range(self.num_axs): self.axs[i].clear() - def get_plot_bounds(self, real_time: float) -> Segment: - start_time = 0 - end_time = real_time - self.latency - if self.visualization == "slide": - start_time = max(0., end_time - self.window_duration) + def get_plot_bounds(self) -> Segment: + end_time = self.real_time - self.latency + start_time = max(0., end_time - self.window_duration) return Segment(start_time, end_time) def on_next( self, - values: Tuple[Annotation, SlidingWindowFeature, float] + values: Tuple[Annotation, SlidingWindowFeature] ): if self.window_closed: raise WindowClosedException - prediction, waveform, real_time = values + prediction, waveform = values + + # TODO break this aggregation code into methods + + # Determine the real time of the stream and the start time of the buffer + self.real_time = waveform.extent.end + start_time = max(0., self.real_time - self.latency - self.window_duration) + + # Update prediction buffer and constrain its bounds + if self.pred_buffer is None: + self.pred_buffer = prediction + else: + self.pred_buffer = self.pred_buffer.update(prediction) + self.pred_buffer = self.pred_buffer.support(self.patch_collar) + if start_time > 0: + self.pred_buffer = self.pred_buffer.extrude(Segment(0, start_time)) + + # Update the audio buffer if there's audio in the input + new_next_sample = self.next_sample + self.num_step_samples + if self.audio_buffer is None: + # Determine the size of the first chunk + expected_duration = self.window_duration + self.window_step - self.latency + expected_samples = int(np.rint(expected_duration * self.sample_rate)) + # Shift indicator to start copying new audio in the buffer + new_next_sample = self.next_sample + expected_samples + # Buffer size is duration + step + new_buffer = np.zeros((self.num_window_samples + self.num_step_samples, 1)) + # Copy first chunk into buffer (slicing because of rounding errors) + new_buffer[:expected_samples] = waveform.data[:expected_samples] + elif self.next_sample <= self.num_window_samples: + # The buffer isn't full, copy into next free buffer chunk + new_buffer = self.audio_buffer.data + new_buffer[self.next_sample:new_next_sample] = waveform.data + else: + # The buffer is full, shift values to the left and copy into 
last buffer chunk + new_buffer = np.roll(self.audio_buffer.data, -self.num_step_samples, axis=0) + # If running on a file, the online prediction may be shorter depending on the latency + # The remaining audio at the end is appended, so 'waveform' may be longer than 'num_step_samples' + # In that case, we simply ignore the appended samples. + new_buffer[-self.num_step_samples:] = waveform.data[:self.num_step_samples] + + # Wrap waveform in a sliding window feature to include timestamps + window = SlidingWindow(start=start_time, duration=self.audio_resolution, step=self.audio_resolution) + self.audio_buffer = SlidingWindowFeature(new_buffer, window) + self.next_sample = new_next_sample # Initialize figure if first call if self.figure is None: @@ -161,20 +239,20 @@ def on_next( # Clear previous plots self._clear_axs() # Set plot bounds - notebook.crop = self.get_plot_bounds(real_time) + notebook.crop = self.get_plot_bounds() # Align prediction and reference if possible if self.reference is not None: metric = DiarizationErrorRate() - mapping = metric.optimal_mapping(self.reference, prediction) - prediction.rename_labels(mapping=mapping, copy=False) + mapping = metric.optimal_mapping(self.reference, self.pred_buffer) + self.pred_buffer.rename_labels(mapping=mapping, copy=False) # Plot prediction - notebook.plot_annotation(prediction, self.axs[0]) + notebook.plot_annotation(self.pred_buffer, self.axs[0]) self.axs[0].set_title("Output") # Plot waveform - notebook.plot_feature(waveform, self.axs[1]) + notebook.plot_feature(self.audio_buffer, self.axs[1]) self.axs[1].set_title("Audio") # Plot reference if available From babf49d9cec81a89ce593b86ea7113fb63a9bbe9 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Mon, 24 Apr 2023 11:22:50 +0200 Subject: [PATCH 15/23] Make transcription pipeline compatible with diart.benchmark and diart.tune. Fix major bug in Optimizer --- src/diart/__init__.py | 4 +- src/diart/console/benchmark.py | 6 ++ src/diart/console/client.py | 3 +- src/diart/console/serve.py | 6 +- src/diart/console/stream.py | 11 ++- src/diart/console/tune.py | 12 ++- src/diart/inference.py | 114 +++++++++++++-------------- src/diart/optim.py | 11 ++- src/diart/pipelines/__init__.py | 2 +- src/diart/pipelines/base.py | 15 ++-- src/diart/pipelines/diarization.py | 11 +-- src/diart/pipelines/transcription.py | 19 +++-- src/diart/pipelines/voice.py | 11 +-- 13 files changed, 125 insertions(+), 100 deletions(-) diff --git a/src/diart/__init__.py b/src/diart/__init__.py index 0d67c9c5..842ba267 100644 --- a/src/diart/__init__.py +++ b/src/diart/__init__.py @@ -1,6 +1,6 @@ from .pipelines import ( - StreamingPipeline, - StreamingConfig, + Pipeline, + PipelineConfig, SpeakerDiarization, SpeakerDiarizationConfig, VoiceActivityDetection, diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index 27d524c5..d8f04183 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -12,12 +12,18 @@ def run(): parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") + parser.add_argument("--whisper", default="small", type=str, + help=f"Whisper model for transcription pipeline. Defaults to 'small'") + parser.add_argument("--language", default="en", type=str, + help=f"Transcribe in this language. 
Defaults to 'en' (English)") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") parser.add_argument("--reference", type=Path, help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files") + parser.add_argument("--duration", default=5, type=float, + help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") diff --git a/src/diart/console/client.py b/src/diart/console/client.py index db4915fa..816c7e0f 100644 --- a/src/diart/console/client.py +++ b/src/diart/console/client.py @@ -45,7 +45,8 @@ def run(): parser.add_argument("--host", required=True, type=str, help="Server host") parser.add_argument("--port", required=True, type=int, help="Server port") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") - parser.add_argument("-sr", "--sample-rate", default=16000, type=int, help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000") + parser.add_argument("-sr", "--sample-rate", default=16000, type=int, + help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000") parser.add_argument("-o", "--output-file", type=Path, help="Output RTTM file. Defaults to no writing") args = parser.parse_args() diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index 5c7c68b0..0698ede0 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -5,7 +5,7 @@ from diart import sources as src from diart import utils from diart.inference import StreamingInference -from diart.pipelines import StreamingPipeline +from diart.pipelines import Pipeline def run(): @@ -22,6 +22,8 @@ def run(): help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") + parser.add_argument("--duration", type=float, + help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. 
Defaults to 0.5") @@ -40,7 +42,7 @@ def run(): # Resolve pipeline pipeline_class = utils.get_pipeline_class(args.pipeline) config = pipeline_class.get_config_class().from_dict(vars(args)) - pipeline: StreamingPipeline = pipeline_class(config) + pipeline: Pipeline = pipeline_class(config) # Create websocket audio source audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port) diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index 3563dd2a..1436eb8a 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -5,7 +5,7 @@ from diart import sources as src from diart import utils from diart.inference import StreamingInference -from diart.pipelines import StreamingPipeline, StreamingConfig +from diart.pipelines import Pipeline, PipelineConfig def run(): @@ -21,6 +21,8 @@ def run(): help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") + parser.add_argument("--duration", default=5, type=float, + help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") @@ -33,15 +35,16 @@ def run(): parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") parser.add_argument("--output", type=str, - help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file") + help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' " + f"or parent directory if SOURCE is a file") parser.add_argument("--hf-token", default="true", type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() # Resolve pipeline pipeline_class = utils.get_pipeline_class(args.pipeline) - config: StreamingConfig = pipeline_class.get_config_class().from_dict(vars(args)) - pipeline: StreamingPipeline = pipeline_class(config) + config: PipelineConfig = pipeline_class.get_config_class().from_dict(vars(args)) + pipeline: Pipeline = pipeline_class(config) # Manage audio source block_size = config.optimal_block_size() diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index 111d97c2..f492c704 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -16,10 +16,16 @@ def run(): help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files") parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") + parser.add_argument("--whisper", default="small", type=str, + help=f"Whisper model for transcription pipeline. Defaults to 'small'") + parser.add_argument("--language", default="en", type=str, + help=f"Transcribe in this language. Defaults to 'en' (English)") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, help=f"{argdoc.EMBEDDING}. 
Defaults to pyannote/embedding") + parser.add_argument("--duration", default=5, type=float, + help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") @@ -32,10 +38,12 @@ def run(): parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"), - help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new") + help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. " + "Defaults to tau_active, rho_update and delta_new") parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials") parser.add_argument("--storage", type=str, - help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name") + help="Optuna storage string. If provided, continue a previous study instead of creating one. " + "The database name must match the study name") parser.add_argument("--output", type=str, help="Working directory") parser.add_argument("--hf-token", default="true", type=str, help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") diff --git a/src/diart/inference.py b/src/diart/inference.py index b9f2a789..061065fb 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -18,22 +18,21 @@ from . import sources as src from . import utils from .metrics import Metric -from .pipelines import StreamingPipeline, StreamingConfig +from .pipelines import Pipeline, PipelineConfig from .progress import ProgressBar, RichProgressBar, TQDMProgressBar from .sinks import WindowClosedException class StreamingInference: - """Performs inference in real time given a pipeline and an audio source. - Streams an audio source to an online speaker diarization pipeline. - It allows users to attach a chain of operations in the form of hooks. + """Performs streaming inference given a pipeline and an audio source. + Side-effect hooks and observers can also be attached for customized behavior. Parameters ---------- - pipeline: StreamingPipeline - Configured speaker diarization pipeline. + pipeline: Pipeline + A pipeline. source: AudioSource - Audio source to be read and streamed. + Audio source to read and stream. batch_size: int Number of inputs to send to the pipeline at once. Defaults to 1. 
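For context, a minimal usage sketch of StreamingInference with the renamed classes; the MicrophoneAudioSource constructor argument and the attach_observers call are assumptions based on the existing API, and the output path is a placeholder:

    from diart import SpeakerDiarization
    from diart.inference import StreamingInference
    from diart.sinks import RTTMWriter
    from diart.sources import MicrophoneAudioSource

    pipeline = SpeakerDiarization()  # default configuration
    mic = MicrophoneAudioSource(pipeline.config.sample_rate)  # constructor argument assumed
    inference = StreamingInference(pipeline, mic, batch_size=1, show_progress=True)
    # Write RTTM output as predictions arrive
    inference.attach_observers(RTTMWriter(mic.uri, "output/live.rttm"))
    prediction = inference()  # runs until the audio source is closed
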
@@ -50,7 +49,7 @@ class StreamingInference: """ def __init__( self, - pipeline: StreamingPipeline, + pipeline: Pipeline, source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, @@ -62,7 +61,6 @@ def __init__( self.batch_size = batch_size self.do_profile = do_profile self.show_progress = show_progress - self.unit = "chunk" if self.batch_size == 1 else "batch" self._observers = [] self._predictions = [] @@ -84,11 +82,11 @@ def __init__( self._pbar.create( total=self.num_chunks, description=f"Streaming {self.source.uri}", - unit=self.unit + unit="chunk" ) # Initialize chronometer for profiling - self._chrono = utils.Chronometer(self.unit, self._pbar) + self._chrono = utils.Chronometer("batch", self._pbar) self.stream = self.source.stream @@ -198,21 +196,21 @@ def __call__(self) -> List[Any]: class Benchmark: """ - Run an online speaker diarization pipeline on a set of audio files in batches. + Run a pipeline on a set of audio files in batches. Write predictions to a given output directory. - If the reference is given, calculate the average diarization error rate. + If the reference is given, compute the average performance metric. Parameters ---------- speech_path: Text or Path Directory with audio files. reference_path: Text, Path or None - Directory with reference RTTM files (same names as audio files). - If None, performance will not be calculated. + Directory with reference files (same names as audio files with different extension). + If None, performance will not be computed. Defaults to None. - output_path: Text, Path or None - Output directory to store predictions in RTTM format. + output_path: Optional[Text | Path] + Output directory to store predictions. If None, predictions will not be written to disk. Defaults to None. show_progress: bool @@ -222,12 +220,7 @@ class Benchmark: Whether to print a performance report to stdout. Defaults to True. batch_size: int - Inference batch size. - If < 2, then it will run in real time. - If >= 2, then it will pre-calculate segmentation and - embeddings, running the rest in real time. - The performance between this two modes does not differ. - Defaults to 32. + Inference batch size. Defaults to 32. """ def __init__( self, @@ -252,7 +245,7 @@ def __init__( self.output_path = output_path if self.output_path is not None: - self.output_path = Path(output_path).expanduser() + self.output_path: Path = Path(output_path).expanduser() self.output_path.mkdir(parents=True, exist_ok=True) self.show_progress = show_progress @@ -271,27 +264,29 @@ def get_file_paths(self) -> List[Path]: def run_single( self, - pipeline: StreamingPipeline, + pipeline: Pipeline, filepath: Path, progress_bar: ProgressBar, ) -> Tuple[Text, Any]: """Run a given pipeline on a given file. - Note that this method does NOT reset the + This method does NOT reset the state of the pipeline before execution. Parameters ---------- - pipeline: StreamingPipeline - Speaker diarization pipeline to run. + pipeline: Pipeline + A pipeline. filepath: Path Path to the target file. progress_bar: diart.progress.ProgressBar - An object to manage the progress of this run. + Object to display the progress of this run. Returns ------- - prediction: Annotation - Pipeline prediction for the given file. + uri: Text + File URI. + prediction: Any + Aggregated pipeline prediction for the given file. 
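For reference, a minimal end-to-end sketch of the Benchmark workflow described above; paths and configuration values are placeholders:

    from diart import SpeakerDiarization, SpeakerDiarizationConfig
    from diart.inference import Benchmark

    benchmark = Benchmark("audio/", reference_path="rttms/", output_path="predictions/")
    config = SpeakerDiarizationConfig(step=0.5, latency=0.5)
    # Returns a DataFrame with per-file and average metrics when references are given,
    # otherwise a dict mapping each URI to its prediction
    report = benchmark(SpeakerDiarization, config)
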
""" padding = pipeline.config.get_file_padding(filepath) source = src.FileAudioSource( @@ -325,7 +320,7 @@ def evaluate( metric: Metric, ) -> Union[pd.DataFrame, Dict[Text, Any]]: """If a reference path was provided, - compute the diarization error rate of a list of predictions. + compute the performance of a list of predictions. Parameters ---------- @@ -369,30 +364,29 @@ def evaluate( def __call__( self, pipeline_class: type, - config: StreamingConfig, + config: PipelineConfig, metric: Optional[Metric] = None, ) -> Union[pd.DataFrame, Dict[Text, Any]]: - """Run a given pipeline on a set of audio files. + """Run a pipeline on a set of audio files. The internal state of the pipeline is reset before benchmarking. Parameters ---------- pipeline_class: class - Class from the StreamingPipeline hierarchy. + Class from the `Pipeline` hierarchy. A pipeline from this class will be instantiated by each worker. - config: StreamingConfig - Streaming pipeline configuration. + config: PipelineConfig + Pipeline configuration. metric: Optional[Metric] Evaluation metric. Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- - performance: pandas.DataFrame or List[Annotation] + performance: pandas.DataFrame or Dict[Text, Any] If reference annotations are given, a DataFrame with detailed performance on each file as well as average performance. - - If no reference annotations, a list of predictions. + If no reference annotations, a dict of uris with predictions. """ audio_file_paths = self.get_file_paths() num_audio_files = len(audio_file_paths) @@ -412,15 +406,14 @@ def __call__( class Parallelize: """Wrapper to parallelize the execution of a `Benchmark` instance. - Note that models will be copied in each worker instead of being reused. + Models will be copied in each worker instead of being reused. Parameters ---------- benchmark: Benchmark - Benchmark instance to execute in parallel. + Benchmark instance to run in parallel. num_workers: int - Number of parallel workers. - Defaults to 0 (no parallelism). + Number of parallel workers. Defaults to 4. """ def __init__( self, @@ -433,20 +426,20 @@ def __init__( def run_single_job( self, pipeline_class: type, - config: StreamingConfig, + config: PipelineConfig, filepath: Path, description: Text, ) -> Tuple[Text, Any]: - """Build and run a pipeline on a single file. + """Instantiate and run a pipeline on a given file. Configure execution to show progress alongside parallel runs. Parameters ---------- pipeline_class: class - Class from the StreamingPipeline hierarchy. + Class from the Pipeline hierarchy. A pipeline from this class will be instantiated. - config: StreamingConfig - Streaming pipeline configuration. + config: PipelineConfig + Pipeline configuration. filepath: Path Path to the target file. description: Text @@ -454,8 +447,10 @@ def run_single_job( Returns ------- - prediction: Annotation - Pipeline prediction for the given file. + uri: Text + File URI. + prediction: Any + Aggregated pipeline prediction for the given file. """ # The process ID inside the pool determines the position of the progress bar idx_process = int(current_process().name.split('-')[1]) - 1 @@ -470,30 +465,29 @@ def run_single_job( def __call__( self, pipeline_class: type, - config: StreamingConfig, + config: PipelineConfig, metric: Optional[Metric] = None, ) -> Union[pd.DataFrame, Dict[Text, Any]]: - """Run a given pipeline on a set of audio files in parallel. - Each worker will build and run the pipeline on a different file. 
+ """Run a pipeline on a set of audio files in parallel. + Each worker instantiates and runs the pipeline on a different file. Parameters ---------- pipeline_class: class - Class from the StreamingPipeline hierarchy. + Class from the Pipeline hierarchy. A pipeline from this class will be instantiated by each worker. - config: StreamingConfig - Streaming pipeline configuration. + config: PipelineConfig + Pipeline configuration. metric: Optional[Metric] Evaluation metric. Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- - performance: pandas.DataFrame or List[Annotation] + performance: pandas.DataFrame or Dict[Text, Any] If reference annotations are given, a DataFrame with detailed performance on each file as well as average performance. - - If no reference annotations, a list of predictions. + If no reference annotations, a dict of uris with predictions. """ audio_file_paths = self.benchmark.get_file_paths() num_audio_files = len(audio_file_paths) diff --git a/src/diart/optim.py b/src/diart/optim.py index 0ea27910..be371a71 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -11,7 +11,7 @@ from .audio import FilePath from .inference import Benchmark -from .pipelines import StreamingConfig +from .pipelines import PipelineConfig from .pipelines.hparams import HyperParameter @@ -24,7 +24,7 @@ def __init__( study_or_path: Union[FilePath, Study], batch_size: int = 32, hparams: Optional[Sequence[HyperParameter]] = None, - base_config: Optional[StreamingConfig] = None, + base_config: Optional[PipelineConfig] = None, do_kickstart_hparams: bool = True, metric: Optional[BaseMetric] = None, direction: Literal["minimize", "maximize"] = "minimize", @@ -97,6 +97,9 @@ def _callback(self, study: Study, trial: FrozenTrial): def objective(self, trial: Trial) -> float: # Set suggested values for optimized hyper-parameters trial_config = vars(self.base_config) + trial_config["duration"] = self.base_config.duration + trial_config["step"] = self.base_config.step + trial_config["latency"] = self.base_config.latency for hparam in self.hparams: trial_config[hparam.name] = trial.suggest_uniform( hparam.name, hparam.low, hparam.high @@ -118,12 +121,12 @@ def objective(self, trial: Trial) -> float: report = self.benchmark(self.pipeline_class, config, metric) # Extract target metric from report - return report.loc["TOTAL", metric.name]["%"] + return report.loc["TOTAL", metric.name].item() def __call__(self, num_iter: int, show_progress: bool = True): self._progress = None if show_progress: - self._progress = trange(num_iter) + self._progress = trange(num_iter, unit="trial") last_trial = -1 if self.study.trials: last_trial = self.study.trials[-1].number diff --git a/src/diart/pipelines/__init__.py b/src/diart/pipelines/__init__.py index f430e4f7..11fe5e82 100644 --- a/src/diart/pipelines/__init__.py +++ b/src/diart/pipelines/__init__.py @@ -1,4 +1,4 @@ -from .base import StreamingPipeline, StreamingConfig +from .base import Pipeline, PipelineConfig from .diarization import SpeakerDiarization, SpeakerDiarizationConfig from .transcription import Transcription, TranscriptionConfig from .voice import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/pipelines/base.py b/src/diart/pipelines/base.py index 1a220203..de0582a7 100644 --- a/src/diart/pipelines/base.py +++ b/src/diart/pipelines/base.py @@ -11,7 +11,7 @@ from ..metrics import Metric -class StreamingConfig: +class PipelineConfig: @property def duration(self) -> float: raise 
NotImplementedError @@ -29,7 +29,7 @@ def sample_rate(self) -> int: raise NotImplementedError @staticmethod - def from_dict(data: Any) -> 'StreamingConfig': + def from_dict(data: Any) -> 'PipelineConfig': raise NotImplementedError def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: @@ -42,17 +42,21 @@ def optimal_block_size(self) -> int: return int(np.rint(self.step * self.sample_rate)) -class StreamingPipeline: +class Pipeline: @staticmethod def get_config_class() -> type: raise NotImplementedError + @staticmethod + def suggest_metric() -> Metric: + raise NotImplementedError + @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: raise NotImplementedError @property - def config(self) -> StreamingConfig: + def config(self) -> PipelineConfig: raise NotImplementedError def reset(self): @@ -67,9 +71,6 @@ def join_predictions(self, predictions: List[Any]) -> Any: def write_prediction(self, uri: Text, prediction: Any, dir_path: Union[Text, Path]): raise NotImplementedError - def suggest_metric(self) -> Metric: - raise NotImplementedError - def suggest_display(self) -> Observer: raise NotImplementedError diff --git a/src/diart/pipelines/diarization.py b/src/diart/pipelines/diarization.py index 57cb7739..114a4223 100644 --- a/src/diart/pipelines/diarization.py +++ b/src/diart/pipelines/diarization.py @@ -16,7 +16,7 @@ from ..metrics import Metric, DiarizationErrorRate -class SpeakerDiarizationConfig(base.StreamingConfig): +class SpeakerDiarizationConfig(base.PipelineConfig): def __init__( self, segmentation: Optional[m.SegmentationModel] = None, @@ -131,7 +131,7 @@ def sample_rate(self) -> int: return self._sample_rate -class SpeakerDiarization(base.StreamingPipeline): +class SpeakerDiarization(base.Pipeline): def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): self._config = SpeakerDiarizationConfig() if config is None else config @@ -166,6 +166,10 @@ def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): def get_config_class() -> type: return SpeakerDiarizationConfig + @staticmethod + def suggest_metric() -> Metric: + return DiarizationErrorRate(collar=0, skip_overlap=False) + @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: return [TauActive, RhoUpdate, DeltaNew] @@ -187,9 +191,6 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file: prediction.write_rttm(out_file) - def suggest_metric(self) -> Metric: - return DiarizationErrorRate(collar=0, skip_overlap=False) - def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer: return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm") diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py index cb5f40e7..3616222e 100644 --- a/src/diart/pipelines/transcription.py +++ b/src/diart/pipelines/transcription.py @@ -15,7 +15,7 @@ from ..metrics import Metric, WordErrorRate -class TranscriptionConfig(base.StreamingConfig): +class TranscriptionConfig(base.PipelineConfig): def __init__( self, asr: Optional[m.SpeechRecognitionModel] = None, @@ -25,7 +25,11 @@ def __init__( language: Optional[Text] = None, beam_size: int = None, device: Optional[torch.device] = None, + **kwargs, ): + self.language = language + self.beam_size = beam_size + self.device = device if self.device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -34,8 +38,8 @@ def __init__( self.asr = asr if self.asr is None: self.asr = 
m.SpeechRecognitionModel.from_whisper("small") - self.asr.set_language(language) - self.asr.set_beam_size(beam_size) + self.asr.set_language(self.language) + self.asr.set_beam_size(self.beam_size) self.segmentation = segmentation self.tau_active = tau_active @@ -96,7 +100,7 @@ def from_dict(data: Any) -> 'TranscriptionConfig': ) -class Transcription(base.StreamingPipeline): +class Transcription(base.Pipeline): def __init__(self, config: Optional[TranscriptionConfig] = None): self._config = TranscriptionConfig() if config is None else config self.asr = blocks.SpeechRecognition(self.config.asr, self.config.device) @@ -108,6 +112,10 @@ def __init__(self, config: Optional[TranscriptionConfig] = None): def get_config_class() -> type: return TranscriptionConfig + @staticmethod + def suggest_metric() -> Metric: + return WordErrorRate() + @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: return [TauActive] @@ -131,9 +139,6 @@ def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Pa with open(Path(dir_path) / f"{uri}.txt", "w") as out_file: out_file.write(prediction) - def suggest_metric(self) -> Metric: - return WordErrorRate() - def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer: return sinks.TextWriter(Path(output_dir) / f"{uri}.txt") diff --git a/src/diart/pipelines/voice.py b/src/diart/pipelines/voice.py index b22806ab..05eaa216 100644 --- a/src/diart/pipelines/voice.py +++ b/src/diart/pipelines/voice.py @@ -16,7 +16,7 @@ from ..metrics import Metric, DetectionErrorRate -class VoiceActivityDetectionConfig(base.StreamingConfig): +class VoiceActivityDetectionConfig(base.PipelineConfig): def __init__( self, segmentation: Optional[m.SegmentationModel] = None, @@ -100,7 +100,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': ) -class VoiceActivityDetection(base.StreamingPipeline): +class VoiceActivityDetection(base.Pipeline): def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): self._config = VoiceActivityDetectionConfig() if config is None else config @@ -130,6 +130,10 @@ def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): def get_config_class() -> type: return VoiceActivityDetectionConfig + @staticmethod + def suggest_metric() -> Metric: + return DetectionErrorRate(collar=0, skip_overlap=False) + @staticmethod def hyper_parameters() -> Sequence[HyperParameter]: return [TauActive] @@ -155,9 +159,6 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file: prediction.write_rttm(out_file) - def suggest_metric(self) -> Metric: - return DetectionErrorRate(collar=0, skip_overlap=False) - def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer: return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm") From 6609e3ca2cb2f92e5839ab9cf1ec4b647211f7ca Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Mon, 24 Apr 2023 11:25:51 +0200 Subject: [PATCH 16/23] Rename base pipeline and config objects --- src/diart/__init__.py | 4 ++-- src/diart/blocks/__init__.py | 2 +- src/diart/blocks/base.py | 8 ++++---- src/diart/blocks/diarization.py | 4 ++-- src/diart/blocks/vad.py | 6 +++--- src/diart/inference.py | 10 +++++----- src/diart/optim.py | 2 +- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/diart/__init__.py b/src/diart/__init__.py index e29287a0..4bd51327 100644 --- a/src/diart/__init__.py +++ b/src/diart/__init__.py @@ -1,8 +1,8 @@ from .blocks import ( 
SpeakerDiarization, - StreamingPipeline, + Pipeline, SpeakerDiarizationConfig, - StreamingConfig, + PipelineConfig, VoiceActivityDetection, VoiceActivityDetectionConfig, ) diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index e6e8c479..15cf81d9 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -14,6 +14,6 @@ ) from .segmentation import SpeakerSegmentation from .diarization import SpeakerDiarization, SpeakerDiarizationConfig -from .base import StreamingConfig, StreamingPipeline +from .base import PipelineConfig, Pipeline from .utils import Binarize, Resample, AdjustVolume from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index 28f313eb..11ef961d 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -31,7 +31,7 @@ def from_name(name: Text) -> 'HyperParameter': DeltaNew = HyperParameter("delta_new", low=0, high=2) -class StreamingConfig: +class PipelineConfig: @property def duration(self) -> float: raise NotImplementedError @@ -49,7 +49,7 @@ def sample_rate(self) -> int: raise NotImplementedError @staticmethod - def from_dict(data: Any) -> 'StreamingConfig': + def from_dict(data: Any) -> 'PipelineConfig': raise NotImplementedError def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: @@ -62,7 +62,7 @@ def optimal_block_size(self) -> int: return int(np.rint(self.step * self.sample_rate)) -class StreamingPipeline: +class Pipeline: @staticmethod def get_config_class() -> type: raise NotImplementedError @@ -76,7 +76,7 @@ def hyper_parameters() -> Sequence[HyperParameter]: raise NotImplementedError @property - def config(self) -> StreamingConfig: + def config(self) -> PipelineConfig: raise NotImplementedError def reset(self): diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index f2a25119..06658cfc 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -17,7 +17,7 @@ from .. import utils -class SpeakerDiarizationConfig(base.StreamingConfig): +class SpeakerDiarizationConfig(base.PipelineConfig): def __init__( self, segmentation: Optional[m.SegmentationModel] = None, @@ -129,7 +129,7 @@ def sample_rate(self) -> int: return self._sample_rate -class SpeakerDiarization(base.StreamingPipeline): +class SpeakerDiarization(base.Pipeline): def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): self._config = SpeakerDiarizationConfig() if config is None else config diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py index def833b6..e519a9cf 100644 --- a/src/diart/blocks/vad.py +++ b/src/diart/blocks/vad.py @@ -15,7 +15,7 @@ from .. 
import utils -class VoiceActivityDetectionConfig(base.StreamingConfig): +class VoiceActivityDetectionConfig(base.PipelineConfig): def __init__( self, segmentation: Optional[m.SegmentationModel] = None, @@ -96,7 +96,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': ) -class VoiceActivityDetection(base.StreamingPipeline): +class VoiceActivityDetection(base.Pipeline): def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): self._config = VoiceActivityDetectionConfig() if config is None else config @@ -135,7 +135,7 @@ def hyper_parameters() -> Sequence[base.HyperParameter]: return [base.TauActive] @property - def config(self) -> base.StreamingConfig: + def config(self) -> base.PipelineConfig: return self._config def reset(self): diff --git a/src/diart/inference.py b/src/diart/inference.py index 6afda89e..f562fdd9 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -53,7 +53,7 @@ class StreamingInference: """ def __init__( self, - pipeline: blocks.StreamingPipeline, + pipeline: blocks.Pipeline, source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, @@ -289,7 +289,7 @@ def get_file_paths(self) -> List[Path]: def run_single( self, - pipeline: blocks.StreamingPipeline, + pipeline: blocks.Pipeline, filepath: Path, progress_bar: ProgressBar, ) -> Annotation: @@ -374,7 +374,7 @@ def evaluate( def __call__( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: blocks.PipelineConfig, metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files. @@ -437,7 +437,7 @@ def __init__( def run_single_job( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: blocks.PipelineConfig, filepath: Path, description: Text, ) -> Annotation: @@ -474,7 +474,7 @@ def run_single_job( def __call__( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: blocks.PipelineConfig, metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files in parallel. diff --git a/src/diart/optim.py b/src/diart/optim.py index f7a96a6e..86492627 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -23,7 +23,7 @@ def __init__( study_or_path: Union[FilePath, Study], batch_size: int = 32, hparams: Optional[Sequence[blocks.base.HyperParameter]] = None, - base_config: Optional[blocks.StreamingConfig] = None, + base_config: Optional[blocks.PipelineConfig] = None, do_kickstart_hparams: bool = True, metric: Optional[BaseMetric] = None, direction: Literal["minimize", "maximize"] = "minimize", From d19b04464355c5b0b282aee38327a540d8d2163d Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Apr 2023 17:41:41 +0200 Subject: [PATCH 17/23] New feature: streaming voice activity detection. 
Pipeline name changes --- src/diart/__init__.py | 10 +- src/diart/blocks/__init__.py | 5 +- src/diart/blocks/base.py | 92 ++++++++++++++ src/diart/blocks/config.py | 153 ----------------------- src/diart/blocks/diarization.py | 145 ++++++++++++++++++---- src/diart/blocks/vad.py | 208 ++++++++++++++++++++++++++++++++ src/diart/console/benchmark.py | 12 +- src/diart/console/client.py | 6 +- src/diart/console/serve.py | 19 +-- src/diart/console/stream.py | 19 +-- src/diart/console/tune.py | 26 +++- src/diart/inference.py | 86 +++++++------ src/diart/optim.py | 56 ++++----- src/diart/sinks.py | 47 +++++--- src/diart/sources.py | 2 +- src/diart/utils.py | 16 ++- 16 files changed, 605 insertions(+), 297 deletions(-) create mode 100644 src/diart/blocks/base.py delete mode 100644 src/diart/blocks/config.py create mode 100644 src/diart/blocks/vad.py diff --git a/src/diart/__init__.py b/src/diart/__init__.py index c9692638..e29287a0 100644 --- a/src/diart/__init__.py +++ b/src/diart/__init__.py @@ -1,6 +1,8 @@ from .blocks import ( - OnlineSpeakerDiarization, - BasePipeline, - PipelineConfig, - BasePipelineConfig, + SpeakerDiarization, + StreamingPipeline, + SpeakerDiarizationConfig, + StreamingConfig, + VoiceActivityDetection, + VoiceActivityDetectionConfig, ) diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index 59a6ef36..e6e8c479 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -13,6 +13,7 @@ OverlapAwareSpeakerEmbedding, ) from .segmentation import SpeakerSegmentation -from .diarization import OnlineSpeakerDiarization, BasePipeline -from .config import BasePipelineConfig, PipelineConfig +from .diarization import SpeakerDiarization, SpeakerDiarizationConfig +from .base import StreamingConfig, StreamingPipeline from .utils import Binarize, Resample, AdjustVolume +from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py new file mode 100644 index 00000000..28f313eb --- /dev/null +++ b/src/diart/blocks/base.py @@ -0,0 +1,92 @@ +from typing import Any, Tuple, Sequence, Text +from dataclasses import dataclass + +import numpy as np +from pyannote.core import SlidingWindowFeature +from pyannote.metrics.base import BaseMetric + +from .. 
import utils +from ..audio import FilePath, AudioLoader + + +@dataclass +class HyperParameter: + name: Text + low: float + high: float + + @staticmethod + def from_name(name: Text) -> 'HyperParameter': + if name == "tau_active": + return TauActive + if name == "rho_update": + return RhoUpdate + if name == "delta_new": + return DeltaNew + raise ValueError(f"Hyper-parameter '{name}' not recognized") + + +TauActive = HyperParameter("tau_active", low=0, high=1) +RhoUpdate = HyperParameter("rho_update", low=0, high=1) +DeltaNew = HyperParameter("delta_new", low=0, high=2) + + +class StreamingConfig: + @property + def duration(self) -> float: + raise NotImplementedError + + @property + def step(self) -> float: + raise NotImplementedError + + @property + def latency(self) -> float: + raise NotImplementedError + + @property + def sample_rate(self) -> int: + raise NotImplementedError + + @staticmethod + def from_dict(data: Any) -> 'StreamingConfig': + raise NotImplementedError + + def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: + file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) + right = utils.get_padding_right(self.latency, self.step) + left = utils.get_padding_left(file_duration + right, self.duration) + return left, right + + def optimal_block_size(self) -> int: + return int(np.rint(self.step * self.sample_rate)) + + +class StreamingPipeline: + @staticmethod + def get_config_class() -> type: + raise NotImplementedError + + @staticmethod + def suggest_metric() -> BaseMetric: + raise NotImplementedError + + @staticmethod + def hyper_parameters() -> Sequence[HyperParameter]: + raise NotImplementedError + + @property + def config(self) -> StreamingConfig: + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def set_timestamp_shift(self, shift: float): + raise NotImplementedError + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature] + ) -> Sequence[Tuple[Any, SlidingWindowFeature]]: + raise NotImplementedError diff --git a/src/diart/blocks/config.py b/src/diart/blocks/config.py deleted file mode 100644 index d8e2a656..00000000 --- a/src/diart/blocks/config.py +++ /dev/null @@ -1,153 +0,0 @@ -from typing import Any, Optional, Union, Tuple - -import numpy as np -import torch -from typing_extensions import Literal - -from .. import models as m -from .. 
import utils -from ..audio import FilePath, AudioLoader - - -class BasePipelineConfig: - @property - def duration(self) -> float: - raise NotImplementedError - - @property - def step(self) -> float: - raise NotImplementedError - - @property - def latency(self) -> float: - raise NotImplementedError - - @property - def sample_rate(self) -> int: - raise NotImplementedError - - @staticmethod - def from_dict(data: Any) -> 'BasePipelineConfig': - raise NotImplementedError - - def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: - file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) - right = utils.get_padding_right(self.latency, self.step) - left = utils.get_padding_left(file_duration + right, self.duration) - return left, right - - def optimal_block_size(self) -> int: - return int(np.rint(self.step * self.sample_rate)) - - -class PipelineConfig(BasePipelineConfig): - def __init__( - self, - segmentation: Optional[m.SegmentationModel] = None, - embedding: Optional[m.EmbeddingModel] = None, - duration: Optional[float] = None, - step: float = 0.5, - latency: Optional[Union[float, Literal["max", "min"]]] = None, - tau_active: float = 0.6, - rho_update: float = 0.3, - delta_new: float = 1, - gamma: float = 3, - beta: float = 10, - max_speakers: int = 20, - device: Optional[torch.device] = None, - **kwargs, - ): - # Default segmentation model is pyannote/segmentation - self.segmentation = segmentation - if self.segmentation is None: - self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") - - # Default duration is the one given by the segmentation model - self._duration = duration - - # Expected sample rate is given by the segmentation model - self._sample_rate: Optional[int] = None - - # Default embedding model is pyannote/embedding - self.embedding = embedding - if self.embedding is None: - self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") - - # Latency defaults to the step duration - self._step = step - self._latency = latency - if self._latency is None or self._latency == "min": - self._latency = self._step - elif self._latency == "max": - self._latency = self._duration - - self.tau_active = tau_active - self.rho_update = rho_update - self.delta_new = delta_new - self.gamma = gamma - self.beta = beta - self.max_speakers = max_speakers - - self.device = device - if self.device is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - @staticmethod - def from_dict(data: Any) -> 'PipelineConfig': - # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None - device = utils.get(data, "device", None) - if device is None: - device = torch.device("cpu") if utils.get(data, "cpu", False) else None - - # Instantiate models - hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) - segmentation = utils.get(data, "segmentation", "pyannote/segmentation") - segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) - embedding = utils.get(data, "embedding", "pyannote/embedding") - embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token) - - # Hyper-parameters and their aliases - tau = utils.get(data, "tau_active", None) - if tau is None: - tau = utils.get(data, "tau", 0.6) - rho = utils.get(data, "rho_update", None) - if rho is None: - rho = utils.get(data, "rho", 0.3) - delta = utils.get(data, "delta_new", None) - if delta is None: - delta = utils.get(data, "delta", 1) - - return PipelineConfig( - segmentation=segmentation, - 
embedding=embedding, - duration=utils.get(data, "duration", None), - step=utils.get(data, "step", 0.5), - latency=utils.get(data, "latency", None), - tau_active=tau, - rho_update=rho, - delta_new=delta, - gamma=utils.get(data, "gamma", 3), - beta=utils.get(data, "beta", 10), - max_speakers=utils.get(data, "max_speakers", 20), - device=device, - ) - - @property - def duration(self) -> float: - if self._duration is None: - self._duration = self.segmentation.duration - return self._duration - - @property - def step(self) -> float: - return self._step - - @property - def latency(self) -> float: - return self._latency - - @property - def sample_rate(self) -> int: - if self._sample_rate is None: - self._sample_rate = self.segmentation.sample_rate - return self._sample_rate diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index 7f0e162c..f2a25119 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -1,42 +1,137 @@ -from typing import Optional, Tuple, Sequence +from typing import Optional, Tuple, Sequence, Union, Any import numpy as np import torch from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment +from pyannote.metrics.base import BaseMetric +from pyannote.metrics.diarization import DiarizationErrorRate +from typing_extensions import Literal from .aggregation import DelayedAggregation +from . import base from .clustering import OnlineSpeakerClustering from .embedding import OverlapAwareSpeakerEmbedding from .segmentation import SpeakerSegmentation from .utils import Binarize -from .config import BasePipelineConfig, PipelineConfig +from .. import models as m +from .. import utils -class BasePipeline: +class SpeakerDiarizationConfig(base.StreamingConfig): + def __init__( + self, + segmentation: Optional[m.SegmentationModel] = None, + embedding: Optional[m.EmbeddingModel] = None, + duration: Optional[float] = None, + step: float = 0.5, + latency: Optional[Union[float, Literal["max", "min"]]] = None, + tau_active: float = 0.6, + rho_update: float = 0.3, + delta_new: float = 1, + gamma: float = 3, + beta: float = 10, + max_speakers: int = 20, + device: Optional[torch.device] = None, + **kwargs, + ): + # Default segmentation model is pyannote/segmentation + self.segmentation = segmentation + if self.segmentation is None: + self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + + self._duration = duration + self._sample_rate: Optional[int] = None + + # Default embedding model is pyannote/embedding + self.embedding = embedding + if self.embedding is None: + self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") + + # Latency defaults to the step duration + self._step = step + self._latency = latency + if self._latency is None or self._latency == "min": + self._latency = self._step + elif self._latency == "max": + self._latency = self._duration + + self.tau_active = tau_active + self.rho_update = rho_update + self.delta_new = delta_new + self.gamma = gamma + self.beta = beta + self.max_speakers = max_speakers + + self.device = device + if self.device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + @staticmethod - def get_config_class() -> type: - raise NotImplementedError + def from_dict(data: Any) -> 'SpeakerDiarizationConfig': + # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None + device = utils.get(data, "device", None) + if device is None: + device = torch.device("cpu") if utils.get(data, "cpu", 
False) else None + + # Instantiate models + hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) + segmentation = utils.get(data, "segmentation", "pyannote/segmentation") + segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) + embedding = utils.get(data, "embedding", "pyannote/embedding") + embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token) + + # Hyper-parameters and their aliases + tau = utils.get(data, "tau_active", None) + if tau is None: + tau = utils.get(data, "tau", 0.6) + rho = utils.get(data, "rho_update", None) + if rho is None: + rho = utils.get(data, "rho", 0.3) + delta = utils.get(data, "delta_new", None) + if delta is None: + delta = utils.get(data, "delta", 1) + + return SpeakerDiarizationConfig( + segmentation=segmentation, + embedding=embedding, + duration=utils.get(data, "duration", None), + step=utils.get(data, "step", 0.5), + latency=utils.get(data, "latency", None), + tau_active=tau, + rho_update=rho, + delta_new=delta, + gamma=utils.get(data, "gamma", 3), + beta=utils.get(data, "beta", 10), + max_speakers=utils.get(data, "max_speakers", 20), + device=device, + ) @property - def config(self) -> BasePipelineConfig: - raise NotImplementedError + def duration(self) -> float: + # Default duration is the one given by the segmentation model + if self._duration is None: + self._duration = self.segmentation.duration + return self._duration - def reset(self): - raise NotImplementedError + @property + def step(self) -> float: + return self._step - def set_timestamp_shift(self, shift: float): - raise NotImplementedError + @property + def latency(self) -> float: + return self._latency - def __call__( - self, - waveforms: Sequence[SlidingWindowFeature] - ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: - raise NotImplementedError + @property + def sample_rate(self) -> int: + # Expected sample rate is given by the segmentation model + if self._sample_rate is None: + self._sample_rate = self.segmentation.sample_rate + return self._sample_rate -class OnlineSpeakerDiarization(BasePipeline): - def __init__(self, config: Optional[PipelineConfig] = None): - self._config = PipelineConfig() if config is None else config +class SpeakerDiarization(base.StreamingPipeline): + def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): + self._config = SpeakerDiarizationConfig() if config is None else config msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" assert self._config.step <= self._config.latency <= self._config.duration, msg @@ -67,10 +162,18 @@ def __init__(self, config: Optional[PipelineConfig] = None): @staticmethod def get_config_class() -> type: - return PipelineConfig + return SpeakerDiarizationConfig + + @staticmethod + def suggest_metric() -> BaseMetric: + return DiarizationErrorRate(collar=0, skip_overlap=False) + + @staticmethod + def hyper_parameters() -> Sequence[base.HyperParameter]: + return [base.TauActive, base.RhoUpdate, base.DeltaNew] @property - def config(self) -> PipelineConfig: + def config(self) -> SpeakerDiarizationConfig: return self._config def set_timestamp_shift(self, shift: float): diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py new file mode 100644 index 00000000..def833b6 --- /dev/null +++ b/src/diart/blocks/vad.py @@ -0,0 +1,208 @@ +from typing import Any, Optional, Union, Sequence, Tuple + +import numpy as np +import torch +from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment +from 
pyannote.metrics.base import BaseMetric +from pyannote.metrics.detection import DetectionErrorRate +from typing_extensions import Literal + +from .aggregation import DelayedAggregation +from . import base +from .segmentation import SpeakerSegmentation +from .utils import Binarize +from .. import models as m +from .. import utils + + +class VoiceActivityDetectionConfig(base.StreamingConfig): + def __init__( + self, + segmentation: Optional[m.SegmentationModel] = None, + duration: Optional[float] = None, + step: float = 0.5, + latency: Optional[Union[float, Literal["max", "min"]]] = None, + tau_active: float = 0.6, + device: Optional[torch.device] = None, + **kwargs, + ): + # Default segmentation model is pyannote/segmentation + self.segmentation = segmentation + if self.segmentation is None: + self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + + self._duration = duration + self._step = step + self._sample_rate: Optional[int] = None + + # Latency defaults to the step duration + self._latency = latency + if self._latency is None or self._latency == "min": + self._latency = self._step + elif self._latency == "max": + self._latency = self._duration + + self.tau_active = tau_active + self.device = device + if self.device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + @property + def duration(self) -> float: + # Default duration is the one given by the segmentation model + if self._duration is None: + self._duration = self.segmentation.duration + return self._duration + + @property + def step(self) -> float: + return self._step + + @property + def latency(self) -> float: + return self._latency + + @property + def sample_rate(self) -> int: + # Expected sample rate is given by the segmentation model + if self._sample_rate is None: + self._sample_rate = self.segmentation.sample_rate + return self._sample_rate + + @staticmethod + def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': + # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None + device = utils.get(data, "device", None) + if device is None: + device = torch.device("cpu") if utils.get(data, "cpu", False) else None + + # Instantiate segmentation model + hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True)) + segmentation = utils.get(data, "segmentation", "pyannote/segmentation") + segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token) + + # Tau active and its alias + tau = utils.get(data, "tau_active", None) + if tau is None: + tau = utils.get(data, "tau", 0.6) + + return VoiceActivityDetectionConfig( + segmentation=segmentation, + duration=utils.get(data, "duration", None), + step=utils.get(data, "step", 0.5), + latency=utils.get(data, "latency", None), + tau_active=tau, + device=device, + ) + + +class VoiceActivityDetection(base.StreamingPipeline): + def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): + self._config = VoiceActivityDetectionConfig() if config is None else config + + msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" + assert self._config.step <= self._config.latency <= self._config.duration, msg + + self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device) + self.pred_aggregation = DelayedAggregation( + self._config.step, + self._config.latency, + strategy="hamming", + cropping_mode="loose", + ) + self.audio_aggregation = DelayedAggregation( + self._config.step, + self._config.latency, + 
strategy="first", + cropping_mode="center", + ) + self.binarize = Binarize(self._config.tau_active) + + # Internal state, handle with care + self.timestamp_shift = 0 + self.chunk_buffer, self.pred_buffer = [], [] + + @staticmethod + def get_config_class() -> type: + return VoiceActivityDetectionConfig + + @staticmethod + def suggest_metric() -> BaseMetric: + return DetectionErrorRate(collar=0, skip_overlap=False) + + @staticmethod + def hyper_parameters() -> Sequence[base.HyperParameter]: + return [base.TauActive] + + @property + def config(self) -> base.StreamingConfig: + return self._config + + def reset(self): + self.set_timestamp_shift(0) + self.chunk_buffer, self.pred_buffer = [], [] + + def set_timestamp_shift(self, shift: float): + self.timestamp_shift = shift + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature], + ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]: + batch_size = len(waveforms) + msg = "Pipeline expected at least 1 input" + assert batch_size >= 1, msg + + # Create batch from chunk sequence, shape (batch, samples, channels) + batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) + + expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate)) + msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" + assert batch.shape[1] == expected_num_samples, msg + + # Extract segmentation + segmentations = self.segmentation(batch) # shape (batch, frames, speakers) + voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[0] # shape (batch, frames, 1) + + seg_resolution = waveforms[0].extent.duration / segmentations.shape[1] + + outputs = [] + for wav, vad in zip(waveforms, voice_detection): + # Add timestamps to segmentation + sw = SlidingWindow( + start=wav.extent.start, + duration=seg_resolution, + step=seg_resolution, + ) + vad = SlidingWindowFeature(vad.cpu().numpy(), sw) + + # Update sliding buffer + self.chunk_buffer.append(wav) + self.pred_buffer.append(vad) + + # Aggregate buffer outputs for this time step + agg_waveform = self.audio_aggregation(self.chunk_buffer) + agg_prediction = self.pred_aggregation(self.pred_buffer) + agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False) + + # Shift prediction timestamps if required + if self.timestamp_shift != 0: + shifted_agg_prediction = Timeline(uri=agg_prediction.uri) + for segment in agg_prediction: + new_segment = Segment( + segment.start + self.timestamp_shift, + segment.end + self.timestamp_shift, + ) + shifted_agg_prediction.add(new_segment) + agg_prediction = shifted_agg_prediction + + # Convert timeline into annotation with single speaker "speech" + agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech")) + outputs.append((agg_prediction, agg_waveform)) + + # Make place for new chunks in buffer if required + if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows: + self.chunk_buffer = self.chunk_buffer[1:] + self.pred_buffer = self.pred_buffer[1:] + + return outputs diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index b6a3f9ff..27d524c5 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -1,15 +1,17 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc import pandas as pd -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig +from diart import argdoc +from diart import utils from diart.inference import Benchmark, Parallelize def run(): parser = argparse.ArgumentParser() 
parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -34,6 +36,8 @@ def run(): help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() + pipeline_class = utils.get_pipeline_class(args.pipeline) + benchmark = Benchmark( args.root, args.reference, @@ -43,11 +47,11 @@ def run(): batch_size=args.batch_size, ) - config = PipelineConfig.from_dict(vars(args)) + config = pipeline_class.get_config_class().from_dict(vars(args)) if args.num_workers > 0: benchmark = Parallelize(benchmark, args.num_workers) - report = benchmark(OnlineSpeakerDiarization, config) + report = benchmark(pipeline_class, config) if args.output is not None and isinstance(report, pd.DataFrame): report.to_csv(args.output / "benchmark_report.csv") diff --git a/src/diart/console/client.py b/src/diart/console/client.py index 084dbc13..db4915fa 100644 --- a/src/diart/console/client.py +++ b/src/diart/console/client.py @@ -3,11 +3,11 @@ from threading import Thread from typing import Text, Optional -import diart.argdoc as argdoc -import diart.sources as src -import diart.utils as utils import numpy as np import rx.operators as ops +from diart import argdoc +from diart import sources as src +from diart import utils from websocket import WebSocket diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index 2f632d57..46bb9328 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -1,10 +1,10 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc -import diart.sources as src -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig -from diart.inference import RealTimeInference +from diart import argdoc +from diart import sources as src +from diart import utils +from diart.inference import StreamingInference from diart.sinks import RTTMWriter @@ -12,6 +12,8 @@ def run(): parser = argparse.ArgumentParser() parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host") parser.add_argument("--port", default=7007, type=int, help="Server port") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -31,15 +33,16 @@ def run(): help=f"{argdoc.HF_TOKEN}. 
Defaults to 'true' (required by pyannote)") args = parser.parse_args() - # Define online speaker diarization pipeline - config = PipelineConfig.from_dict(vars(args)) - pipeline = OnlineSpeakerDiarization(config) + # Resolve pipeline + pipeline_class = utils.get_pipeline_class(args.pipeline) + config = pipeline_class.get_config_class().from_dict(vars(args)) + pipeline = pipeline_class(config) # Create websocket audio source audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port) # Run online inference - inference = RealTimeInference( + inference = StreamingInference( pipeline, audio_source, batch_size=1, diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index d7218f07..e0c670c5 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -1,16 +1,18 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc -import diart.sources as src -from diart.blocks import OnlineSpeakerDiarization, PipelineConfig -from diart.inference import RealTimeInference +from diart import argdoc +from diart import sources as src +from diart import utils +from diart.inference import StreamingInference from diart.sinks import RTTMWriter def run(): parser = argparse.ArgumentParser() parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -32,9 +34,10 @@ def run(): help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() - # Define online speaker diarization pipeline - config = PipelineConfig.from_dict(vars(args)) - pipeline = OnlineSpeakerDiarization(config) + # Resolve pipeline + pipeline_class = utils.get_pipeline_class(args.pipeline) + config = pipeline_class.get_config_class().from_dict(vars(args)) + pipeline = pipeline_class(config) # Manage audio source block_size = config.optimal_block_size() @@ -51,7 +54,7 @@ def run(): audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device) # Run online inference - inference = RealTimeInference( + inference = StreamingInference( pipeline, audio_source, batch_size=1, diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index 4ad8852a..a1f1b63a 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -1,10 +1,11 @@ import argparse from pathlib import Path -import diart.argdoc as argdoc import optuna -from diart.blocks import PipelineConfig, OnlineSpeakerDiarization -from diart.optim import Optimizer, HyperParameter +from diart import argdoc +from diart import utils +from diart.blocks.base import HyperParameter +from diart.optim import Optimizer from optuna.samplers import TPESampler @@ -13,6 +14,8 @@ def run(): parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") parser.add_argument("--reference", required=True, type=str, help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files") + parser.add_argument("--pipeline", default="SpeakerDiarization", type=str, + help="Class of the pipeline to optimize. 
Defaults to 'SpeakerDiarization'") parser.add_argument("--segmentation", default="pyannote/segmentation", type=str, help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation") parser.add_argument("--embedding", default="pyannote/embedding", type=str, @@ -38,17 +41,28 @@ def run(): help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)") args = parser.parse_args() + # Retrieve pipeline class + pipeline_class = utils.get_pipeline_class(args.pipeline) + # Create the base configuration for each trial - base_config = PipelineConfig.from_dict(vars(args)) + base_config = pipeline_class.get_config_class().from_dict(vars(args)) # Create hyper-parameters to optimize + possible_hparams = pipeline_class.hyper_parameters() hparams = [HyperParameter.from_name(name) for name in args.hparams] + hparams = [hp for hp in hparams if hp in possible_hparams] + if not hparams: + print( + f"No hyper-parameters to optimize. " + f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}" + ) + exit(1) # Use a custom storage if given if args.output is not None: msg = "Both `output` and `storage` were set, but only one was expected" assert args.storage is None, msg - args.output = Path(args.output) + args.output = Path(args.output).expanduser() args.output.mkdir(parents=True, exist_ok=True) study_or_path = args.output elif args.storage is not None: @@ -60,11 +74,11 @@ def run(): # Run optimization Optimizer( + pipeline_class=pipeline_class, speech_path=args.root, reference_path=args.reference, study_or_path=study_or_path, batch_size=args.batch_size, - pipeline_class=OnlineSpeakerDiarization, hparams=hparams, base_config=base_config, )(num_iter=args.num_iter, show_progress=True) diff --git a/src/diart/inference.py b/src/diart/inference.py index f4b65f5f..6afda89e 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -4,32 +4,33 @@ from traceback import print_exc from typing import Union, Text, Optional, Callable, Tuple, List -import diart.operators as dops -import diart.sources as src import numpy as np import pandas as pd import rx import rx.operators as ops import torch -from diart import utils -from diart.blocks import BasePipeline, Resample, BasePipelineConfig -from diart.progress import ProgressBar, RichProgressBar, TQDMProgressBar -from diart.sinks import DiarizationPredictionAccumulator, RealTimePlot, WindowClosedException from pyannote.core import Annotation, SlidingWindowFeature from pyannote.database.util import load_rttm -from pyannote.metrics.diarization import DiarizationErrorRate +from pyannote.metrics.base import BaseMetric from rx.core import Observer from tqdm import tqdm +from . import blocks +from . import operators as dops +from . import sources as src +from . import utils +from .progress import ProgressBar, RichProgressBar, TQDMProgressBar +from .sinks import PredictionAccumulator, StreamingPlot, WindowClosedException -class RealTimeInference: + +class StreamingInference: """Performs inference in real time given a pipeline and an audio source. Streams an audio source to an online speaker diarization pipeline. It allows users to attach a chain of operations in the form of hooks. Parameters ---------- - pipeline: BasePipeline + pipeline: StreamingPipeline Configured speaker diarization pipeline. source: AudioSource Audio source to be read and streamed. 
@@ -52,7 +53,7 @@ class RealTimeInference: """ def __init__( self, - pipeline: BasePipeline, + pipeline: blocks.StreamingPipeline, source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, @@ -66,7 +67,7 @@ def __init__( self.do_profile = do_profile self.do_plot = do_plot self.show_progress = show_progress - self.accumulator = DiarizationPredictionAccumulator(self.source.uri) + self.accumulator = PredictionAccumulator(self.source.uri) self.unit = "chunk" if self.batch_size == 1 else "batch" self._observers = [] @@ -102,7 +103,7 @@ def __init__( f"but pipeline's is {sample_rate}. Will resample." logging.warning(msg) self.stream = self.stream.pipe( - ops.map(Resample(self.source.sample_rate, sample_rate)) + ops.map(blocks.Resample(self.source.sample_rate, sample_rate)) ) # Add rx operators to manage the inputs and outputs of the pipeline @@ -202,7 +203,7 @@ def __call__(self) -> Annotation: latency=config.latency, sample_rate=config.sample_rate, ), - ops.do(RealTimePlot(config.duration, config.latency)), + ops.do(StreamingPlot(config.duration, config.latency)), ) observable.subscribe( on_error=self._handle_error, @@ -288,7 +289,7 @@ def get_file_paths(self) -> List[Path]: def run_single( self, - pipeline: BasePipeline, + pipeline: blocks.StreamingPipeline, filepath: Path, progress_bar: ProgressBar, ) -> Annotation: @@ -298,7 +299,7 @@ def run_single( Parameters ---------- - pipeline: BasePipeline + pipeline: StreamingPipeline Speaker diarization pipeline to run. filepath: Path Path to the target file. @@ -318,7 +319,7 @@ def run_single( pipeline.config.optimal_block_size(), ) pipeline.set_timestamp_shift(-padding[0]) - inference = RealTimeInference( + inference = StreamingInference( pipeline, source, self.batch_size, @@ -337,7 +338,11 @@ def run_single( return pred - def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[Annotation]]: + def evaluate( + self, + predictions: List[Annotation], + metric: BaseMetric, + ) -> Union[pd.DataFrame, List[Annotation]]: """If a reference path was provided, compute the diarization error rate of a list of predictions. @@ -345,6 +350,8 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An ---------- predictions: List[Annotation] Predictions to evaluate. + metric: BaseMetric + Evaluation metric from pyannote.metrics. Returns ------- @@ -353,8 +360,7 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An reference path was given. Otherwise return the same predictions. """ if self.reference_path is not None: - metric = DiarizationErrorRate(collar=0, skip_overlap=False) - progress_bar = TQDMProgressBar("Computing DER", leave=False) + progress_bar = TQDMProgressBar(f"Computing {metric.name}", leave=False) progress_bar.create(total=len(predictions), unit="file") progress_bar.start() for hyp in predictions: @@ -368,18 +374,22 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An def __call__( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.StreamingConfig, + metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files. - Notice that the internal state of the pipeline is reset before benchmarking. + The internal state of the pipeline is reset before benchmarking. Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. 
A pipeline from this class will be instantiated by each worker. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. + metric: Optional[BaseMetric] + Evaluation metric from pyannote.metrics. + Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- @@ -400,7 +410,8 @@ def __call__( progress = TQDMProgressBar(desc, leave=False, do_close=True) predictions.append(self.run_single(pipeline, filepath, progress)) - return self.evaluate(predictions) + metric = pipeline.suggest_metric() if metric is None else metric + return self.evaluate(predictions, metric) class Parallelize: @@ -426,20 +437,20 @@ def __init__( def run_single_job( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.StreamingConfig, filepath: Path, description: Text, - ): + ) -> Annotation: """Build and run a pipeline on a single file. Configure execution to show progress alongside parallel runs. Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. A pipeline from this class will be instantiated. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. filepath: Path Path to the target file. description: Text @@ -463,7 +474,8 @@ def run_single_job( def __call__( self, pipeline_class: type, - config: BasePipelineConfig, + config: blocks.StreamingConfig, + metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files in parallel. Each worker will build and run the pipeline on a different file. @@ -471,10 +483,13 @@ def __call__( Parameters ---------- pipeline_class: class - Class from the BasePipeline hierarchy. + Class from the StreamingPipeline hierarchy. A pipeline from this class will be instantiated by each worker. - config: BasePipelineConfig - Diarization pipeline configuration. + config: StreamingConfig + Streaming pipeline configuration. + metric: Optional[BaseMetric] + Evaluation metric from pyannote.metrics. + Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`) Returns ------- @@ -512,4 +527,5 @@ def __call__( predictions = [job.get() for job in jobs] # Evaluate results - return self.benchmark.evaluate(predictions) + metric = pipeline_class.suggest_metric() if metric is None else metric + return self.benchmark.evaluate(predictions, metric) diff --git a/src/diart/optim.py b/src/diart/optim.py index 05800a05..f7a96a6e 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -1,51 +1,32 @@ from collections import OrderedDict -from dataclasses import dataclass from pathlib import Path from typing import Sequence, Text, Optional, Union from optuna import TrialPruned, Study, create_study from optuna.samplers import TPESampler from optuna.trial import Trial, FrozenTrial +from pyannote.metrics.base import BaseMetric from tqdm import trange, tqdm +from typing_extensions import Literal +from . 
import blocks from .audio import FilePath -from .blocks import BasePipelineConfig, PipelineConfig, OnlineSpeakerDiarization from .inference import Benchmark -@dataclass -class HyperParameter: - name: Text - low: float - high: float - - @staticmethod - def from_name(name: Text) -> 'HyperParameter': - if name == "tau_active": - return TauActive - if name == "rho_update": - return RhoUpdate - if name == "delta_new": - return DeltaNew - raise ValueError(f"Hyper-parameter '{name}' not recognized") - - -TauActive = HyperParameter("tau_active", low=0, high=1) -RhoUpdate = HyperParameter("rho_update", low=0, high=1) -DeltaNew = HyperParameter("delta_new", low=0, high=2) - - class Optimizer: def __init__( self, + pipeline_class: type, speech_path: Union[Text, Path], reference_path: Union[Text, Path], study_or_path: Union[FilePath, Study], batch_size: int = 32, - pipeline_class: type = OnlineSpeakerDiarization, - hparams: Optional[Sequence[HyperParameter]] = None, - base_config: Optional[BasePipelineConfig] = None, + hparams: Optional[Sequence[blocks.base.HyperParameter]] = None, + base_config: Optional[blocks.StreamingConfig] = None, do_kickstart_hparams: bool = True, + metric: Optional[BaseMetric] = None, + direction: Literal["minimize", "maximize"] = "minimize", ): self.pipeline_class = pipeline_class # FIXME can we run this benchmark in parallel? @@ -58,15 +39,17 @@ def __init__( batch_size=batch_size, ) + self.metric = metric + self.direction = direction self.base_config = base_config self.do_kickstart_hparams = do_kickstart_hparams if self.base_config is None: - self.base_config = PipelineConfig() + self.base_config = self.pipeline_class.get_config_class()() self.do_kickstart_hparams = False self.hparams = hparams if self.hparams is None: - self.hparams = [TauActive, RhoUpdate, DeltaNew] + self.hparams = self.pipeline_class.hyper_parameters() # Make sure hyper-parameters exist in the configuration class given possible_hparams = vars(self.base_config) @@ -85,7 +68,7 @@ def __init__( storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"), sampler=TPESampler(), study_name=study_or_path.stem, - direction="minimize", + direction=self.direction, load_if_exists=True, ) else: @@ -105,7 +88,7 @@ def _callback(self, study: Study, trial: FrozenTrial): return self._progress.update(1) self._progress.set_description(f"Trial {trial.number + 1}") - values = {"best_der": study.best_value} + values = {"best_perf": study.best_value} for name, value in study.best_params.items(): values[f"best_{name}"] = value self._progress.set_postfix(OrderedDict(values)) @@ -125,11 +108,16 @@ def objective(self, trial: Trial) -> float: # Instantiate the new configuration for the trial config = self.base_config.__class__(**trial_config) + # Determine the evaluation metric + metric = self.metric + if metric is None: + metric = self.pipeline_class.suggest_metric() + # Run pipeline over the dataset - report = self.benchmark(self.pipeline_class, config) + report = self.benchmark(self.pipeline_class, config, metric) - # Extract DER from report - return report.loc["TOTAL", "diarization error rate"]["%"] + # Extract target metric from report + return report.loc["TOTAL", metric.name]["%"] def __call__(self, num_iter: int, show_progress: bool = True): self._progress = None diff --git a/src/diart/sinks.py b/src/diart/sinks.py index cf480bed..63c170d0 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -8,12 +8,14 @@ from rx.core import Observer from typing_extensions import Literal +from . 
import utils + class WindowClosedException(Exception): pass -def _extract_annotation(value: Union[Tuple, Annotation]) -> Annotation: +def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation: if isinstance(value, tuple): return value[0] if isinstance(value, Annotation): @@ -43,10 +45,11 @@ def patch(self): annotation.support(self.patch_collar).write_rttm(file) def on_next(self, value: Union[Tuple, Annotation]): - annotation = _extract_annotation(value) - annotation.uri = self.uri + prediction = _extract_prediction(value) + # Write prediction in RTTM format + prediction.uri = self.uri with open(self.path, 'a') as file: - annotation.write_rttm(file) + prediction.write_rttm(file) def on_error(self, error: Exception): self.patch() @@ -55,30 +58,30 @@ def on_completed(self): self.patch() -class DiarizationPredictionAccumulator(Observer): +class PredictionAccumulator(Observer): def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05): super().__init__() self.uri = uri self.patch_collar = patch_collar - self._annotation = None + self._prediction: Optional[Annotation] = None def patch(self): """Stitch same-speaker turns that are close to each other""" - if self._annotation is not None: - self._annotation = self._annotation.support(self.patch_collar) + if self._prediction is not None: + self._prediction = self._prediction.support(self.patch_collar) def get_prediction(self) -> Annotation: # Patch again in case this is called before on_completed self.patch() - return self._annotation + return self._prediction def on_next(self, value: Union[Tuple, Annotation]): - annotation = _extract_annotation(value) - annotation.uri = self.uri - if self._annotation is None: - self._annotation = annotation + prediction = _extract_prediction(value) + prediction.uri = self.uri + if self._prediction is None: + self._prediction = prediction else: - self._annotation.update(annotation) + self._prediction.update(prediction) def on_error(self, error: Exception): self.patch() @@ -87,7 +90,7 @@ def on_completed(self): self.patch() -class RealTimePlot(Observer): +class StreamingPlot(Observer): def __init__( self, duration: float, @@ -134,11 +137,15 @@ def get_plot_bounds(self, real_time: float) -> Segment: start_time = max(0., end_time - self.window_duration) return Segment(start_time, end_time) - def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): + def on_next( + self, + values: Tuple[Annotation, SlidingWindowFeature, float] + ): if self.window_closed: raise WindowClosedException prediction, waveform, real_time = values + # Initialize figure if first call if self.figure is None: self._init_figure() @@ -147,15 +154,21 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): # Set plot bounds notebook.crop = self.get_plot_bounds(real_time) - # Plot current values + # Align prediction and reference if possible if self.reference is not None: metric = DiarizationErrorRate() mapping = metric.optimal_mapping(self.reference, prediction) prediction.rename_labels(mapping=mapping, copy=False) + + # Plot prediction notebook.plot_annotation(prediction, self.axs[0]) self.axs[0].set_title("Output") + + # Plot waveform notebook.plot_feature(waveform, self.axs[1]) self.axs[1].set_title("Audio") + + # Plot reference if available if self.num_axs == 3: notebook.plot_annotation(self.reference, self.axs[2]) self.axs[2].set_title("Reference") diff --git a/src/diart/sources.py b/src/diart/sources.py index 0f5dedf7..b34d5cf3 100644 --- a/src/diart/sources.py +++ 
b/src/diart/sources.py @@ -5,12 +5,12 @@ import numpy as np import sounddevice as sd import torch -from diart import utils from einops import rearrange from rx.subject import Subject from torchaudio.io import StreamReader from websocket_server import WebsocketServer +from . import utils from .audio import FilePath, AudioLoader diff --git a/src/diart/utils.py b/src/diart/utils.py index e90861c7..e825ef29 100644 --- a/src/diart/utils.py +++ b/src/diart/utils.py @@ -4,9 +4,11 @@ import matplotlib.pyplot as plt import numpy as np -from diart.progress import ProgressBar from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook +from .progress import ProgressBar +from . import blocks + class Chronometer: def __init__(self, unit: Text, progress_bar: Optional[ProgressBar] = None): @@ -74,6 +76,18 @@ def get_padding_left(stream_duration: float, chunk_duration: float) -> float: return 0 +def repeat_label(label: Text): + while True: + yield label + + +def get_pipeline_class(class_name: Text) -> type: + pipeline_class = getattr(blocks, class_name, None) + msg = f"Pipeline '{class_name}' doesn't exist" + assert pipeline_class is not None, msg + return pipeline_class + + def get_padding_right(latency: float, step: float) -> float: return latency - step From 6caa4a4ab9b2e8c2ab7bc5049b1112f325a08d5b Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Apr 2023 17:43:51 +0200 Subject: [PATCH 18/23] Update link in setup.cfg --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 594c876e..e67e4426 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,11 +2,11 @@ name=diart version=0.7.0 author=Juan Manuel Coria -description=Speaker diarization in real time +description=Streaming speaker diarization in real-time long_description=file: README.md long_description_content_type=text/markdown keywords=speaker diarization, streaming, online, real time, rxpy -url=https://github.com/juanmc2005/StreamingSpeakerDiarization +url=https://github.com/juanmc2005/diart license=MIT classifiers= Development Status :: 4 - Beta From 0993fe85411d09f3b0bb5db709b7b02a7a56f0be Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Apr 2023 17:51:41 +0200 Subject: [PATCH 19/23] Update code snippets in README --- README.md | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index ef533946..57ca293a 100644 --- a/README.md +++ b/README.md @@ -110,17 +110,17 @@ See `diart.stream -h` for more options. 
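The console tools updated in this patch series accept a `--pipeline` argument that is resolved to a class exported by `diart.blocks` through `utils.get_pipeline_class`. A minimal sketch of that lookup, assuming only the names introduced by these patches:

```python
from diart import utils

# Same resolution the console scripts perform for --pipeline
pipeline_class = utils.get_pipeline_class("VoiceActivityDetection")
config_class = pipeline_class.get_config_class()

print(pipeline_class.__name__, config_class.__name__)
# Each pipeline also advertises which hyper-parameters it exposes for tuning
print([hp.name for hp in pipeline_class.hyper_parameters()])
```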
### From python -Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk: +Use `StreamingInference` to run a pipeline on an audio source and write the results to disk: ```python -from diart import OnlineSpeakerDiarization +from diart import SpeakerDiarization from diart.sources import MicrophoneAudioSource -from diart.inference import RealTimeInference +from diart.inference import StreamingInference from diart.sinks import RTTMWriter -pipeline = OnlineSpeakerDiarization() +pipeline = SpeakerDiarization() mic = MicrophoneAudioSource(pipeline.config.sample_rate) -inference = RealTimeInference(pipeline, mic, do_plot=True) +inference = StreamingInference(pipeline, mic, do_plot=True) inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm")) prediction = inference() ``` @@ -129,13 +129,13 @@ For inference and evaluation on a dataset we recommend to use `Benchmark` (see n ## 🤖 Custom models -Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses): +Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`): ```python -from diart import OnlineSpeakerDiarization, PipelineConfig +from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart.models import EmbeddingModel, SegmentationModel from diart.sources import MicrophoneAudioSource -from diart.inference import RealTimeInference +from diart.inference import StreamingInference def model_loader(): @@ -168,19 +168,19 @@ class MyEmbeddingModel(EmbeddingModel): return self.model(waveform, weights) -config = PipelineConfig( +config = SpeakerDiarizationConfig( segmentation=MySegmentationModel(), embedding=MyEmbeddingModel() ) -pipeline = OnlineSpeakerDiarization(config) +pipeline = SpeakerDiarization(config) mic = MicrophoneAudioSource(config.sample_rate) -inference = RealTimeInference(pipeline, mic) +inference = StreamingInference(pipeline, mic) prediction = inference() ``` ## 📈 Tune hyper-parameters -Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset. +Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs. ### From the command line @@ -281,7 +281,7 @@ diart.serve --host 0.0.0.0 --port 7007 diart.client microphone --host --port 7007 ``` -**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`. +**Note:** make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`. See `-h` for more options. @@ -290,13 +290,13 @@ See `-h` for more options. 
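Both ends must agree on `step` and `sample_rate` because the number of samples streamed per block is derived from them, mirroring the `optimal_block_size()` helper added in this patch series. A quick sanity check, assuming the 16 kHz rate expected by `pyannote/segmentation`:

```python
import numpy as np

step = 0.5           # seconds between consecutive outputs (default)
sample_rate = 16000  # assumed: sample rate expected by pyannote/segmentation

# Mirrors the config's optimal_block_size(): samples sent per streamed block
block_size = int(np.rint(step * sample_rate))
print(block_size)  # 8000 samples, i.e. 500 ms of audio per block
```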
For customized solutions, a server can also be created in python using the `WebSocketAudioSource`: ```python -from diart import OnlineSpeakerDiarization +from diart import SpeakerDiarization from diart.sources import WebSocketAudioSource -from diart.inference import RealTimeInference +from diart.inference import StreamingInference -pipeline = OnlineSpeakerDiarization() +pipeline = SpeakerDiarization() source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007) -inference = RealTimeInference(pipeline, source) +inference = StreamingInference(pipeline, source) inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm())) prediction = inference() ``` @@ -354,14 +354,14 @@ or using the inference API: ```python from diart.inference import Benchmark, Parallelize -from diart import OnlineSpeakerDiarization, PipelineConfig +from diart import SpeakerDiarization, SpeakerDiarizationConfig from diart.models import SegmentationModel benchmark = Benchmark("/wav/dir", "/rttm/dir") name = "pyannote/segmentation@Interspeech2021" segmentation = SegmentationModel.from_pyannote(name) -config = PipelineConfig( +config = SpeakerDiarizationConfig( # Set the model used in the paper segmentation=segmentation, step=0.5, @@ -370,12 +370,12 @@ config = PipelineConfig( rho_update=0.422, delta_new=1.517 ) -benchmark(OnlineSpeakerDiarization, config) +benchmark(SpeakerDiarization, config) # Run the same benchmark in parallel p_benchmark = Parallelize(benchmark, num_workers=4) if __name__ == "__main__": # Needed for multiprocessing - p_benchmark(OnlineSpeakerDiarization, config) + p_benchmark(SpeakerDiarization, config) ``` This pre-calculates model outputs in batches, so it runs a lot faster. From 95d4fae66dea06e1cbb12ac591e5f323687cd02f Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 19 Apr 2023 21:18:36 +0200 Subject: [PATCH 20/23] Add minor README modifications --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 57ca293a..ae13059f 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ | - 🤖 Custom models + 🤖 Add your model | @@ -127,7 +127,7 @@ prediction = inference() For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)). 
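`Benchmark` now also accepts an optional `metric`, falling back to the pipeline's `suggest_metric()` when omitted. A short sketch with an explicit metric, assuming hypothetical `/wav/dir` and `/rttm/dir` directories:

```python
from pyannote.metrics.diarization import DiarizationErrorRate

from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import Benchmark

benchmark = Benchmark("/wav/dir", "/rttm/dir")
config = SpeakerDiarizationConfig(step=0.5, latency=0.5)
# Override the suggested metric (DER with collar=0) with a 250 ms collar
report = benchmark(SpeakerDiarization, config, metric=DiarizationErrorRate(collar=0.25))
```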
-## 🤖 Custom models +## 🤖 Add your model Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`): From 569c68fa5648c9c940dae215ed557582f17a513f Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Mon, 24 Apr 2023 11:25:51 +0200 Subject: [PATCH 21/23] Rename base pipeline and config objects --- src/diart/__init__.py | 4 ++-- src/diart/blocks/__init__.py | 2 +- src/diart/blocks/base.py | 8 ++++---- src/diart/blocks/diarization.py | 4 ++-- src/diart/blocks/vad.py | 6 +++--- src/diart/inference.py | 10 +++++----- src/diart/optim.py | 2 +- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/diart/__init__.py b/src/diart/__init__.py index e29287a0..4bd51327 100644 --- a/src/diart/__init__.py +++ b/src/diart/__init__.py @@ -1,8 +1,8 @@ from .blocks import ( SpeakerDiarization, - StreamingPipeline, + Pipeline, SpeakerDiarizationConfig, - StreamingConfig, + PipelineConfig, VoiceActivityDetection, VoiceActivityDetectionConfig, ) diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py index e6e8c479..15cf81d9 100644 --- a/src/diart/blocks/__init__.py +++ b/src/diart/blocks/__init__.py @@ -14,6 +14,6 @@ ) from .segmentation import SpeakerSegmentation from .diarization import SpeakerDiarization, SpeakerDiarizationConfig -from .base import StreamingConfig, StreamingPipeline +from .base import PipelineConfig, Pipeline from .utils import Binarize, Resample, AdjustVolume from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py index 28f313eb..11ef961d 100644 --- a/src/diart/blocks/base.py +++ b/src/diart/blocks/base.py @@ -31,7 +31,7 @@ def from_name(name: Text) -> 'HyperParameter': DeltaNew = HyperParameter("delta_new", low=0, high=2) -class StreamingConfig: +class PipelineConfig: @property def duration(self) -> float: raise NotImplementedError @@ -49,7 +49,7 @@ def sample_rate(self) -> int: raise NotImplementedError @staticmethod - def from_dict(data: Any) -> 'StreamingConfig': + def from_dict(data: Any) -> 'PipelineConfig': raise NotImplementedError def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: @@ -62,7 +62,7 @@ def optimal_block_size(self) -> int: return int(np.rint(self.step * self.sample_rate)) -class StreamingPipeline: +class Pipeline: @staticmethod def get_config_class() -> type: raise NotImplementedError @@ -76,7 +76,7 @@ def hyper_parameters() -> Sequence[HyperParameter]: raise NotImplementedError @property - def config(self) -> StreamingConfig: + def config(self) -> PipelineConfig: raise NotImplementedError def reset(self): diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py index f2a25119..06658cfc 100644 --- a/src/diart/blocks/diarization.py +++ b/src/diart/blocks/diarization.py @@ -17,7 +17,7 @@ from .. 
import utils -class SpeakerDiarizationConfig(base.StreamingConfig): +class SpeakerDiarizationConfig(base.PipelineConfig): def __init__( self, segmentation: Optional[m.SegmentationModel] = None, @@ -129,7 +129,7 @@ def sample_rate(self) -> int: return self._sample_rate -class SpeakerDiarization(base.StreamingPipeline): +class SpeakerDiarization(base.Pipeline): def __init__(self, config: Optional[SpeakerDiarizationConfig] = None): self._config = SpeakerDiarizationConfig() if config is None else config diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py index def833b6..e519a9cf 100644 --- a/src/diart/blocks/vad.py +++ b/src/diart/blocks/vad.py @@ -15,7 +15,7 @@ from .. import utils -class VoiceActivityDetectionConfig(base.StreamingConfig): +class VoiceActivityDetectionConfig(base.PipelineConfig): def __init__( self, segmentation: Optional[m.SegmentationModel] = None, @@ -96,7 +96,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig': ) -class VoiceActivityDetection(base.StreamingPipeline): +class VoiceActivityDetection(base.Pipeline): def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None): self._config = VoiceActivityDetectionConfig() if config is None else config @@ -135,7 +135,7 @@ def hyper_parameters() -> Sequence[base.HyperParameter]: return [base.TauActive] @property - def config(self) -> base.StreamingConfig: + def config(self) -> base.PipelineConfig: return self._config def reset(self): diff --git a/src/diart/inference.py b/src/diart/inference.py index 6afda89e..f562fdd9 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -53,7 +53,7 @@ class StreamingInference: """ def __init__( self, - pipeline: blocks.StreamingPipeline, + pipeline: blocks.Pipeline, source: src.AudioSource, batch_size: int = 1, do_profile: bool = True, @@ -289,7 +289,7 @@ def get_file_paths(self) -> List[Path]: def run_single( self, - pipeline: blocks.StreamingPipeline, + pipeline: blocks.Pipeline, filepath: Path, progress_bar: ProgressBar, ) -> Annotation: @@ -374,7 +374,7 @@ def evaluate( def __call__( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: blocks.PipelineConfig, metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files. @@ -437,7 +437,7 @@ def __init__( def run_single_job( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: blocks.PipelineConfig, filepath: Path, description: Text, ) -> Annotation: @@ -474,7 +474,7 @@ def run_single_job( def __call__( self, pipeline_class: type, - config: blocks.StreamingConfig, + config: blocks.PipelineConfig, metric: Optional[BaseMetric] = None, ) -> Union[pd.DataFrame, List[Annotation]]: """Run a given pipeline on a set of audio files in parallel. 
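# The Optimizer changes below thread the same pipeline_class/metric plumbing through
# hyper-parameter tuning, with an explicit optimization direction. A minimal usage
# sketch (the /wav, /rttm and /studies paths are hypothetical):
from pathlib import Path

from diart import VoiceActivityDetection
from diart.optim import Optimizer

optimizer = Optimizer(
    pipeline_class=VoiceActivityDetection,
    speech_path="/wav/dir",              # hypothetical audio directory
    reference_path="/rttm/dir",          # hypothetical RTTM directory
    study_or_path=Path("/studies/vad"),  # an optuna SQLite study is stored here
    batch_size=32,
)
# hparams defaults to VoiceActivityDetection.hyper_parameters() (tau_active only) and
# the metric defaults to the pipeline's suggest_metric(), i.e. detection error rate
optimizer(num_iter=100, show_progress=True)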
diff --git a/src/diart/optim.py b/src/diart/optim.py index f7a96a6e..86492627 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -23,7 +23,7 @@ def __init__( study_or_path: Union[FilePath, Study], batch_size: int = 32, hparams: Optional[Sequence[blocks.base.HyperParameter]] = None, - base_config: Optional[blocks.StreamingConfig] = None, + base_config: Optional[blocks.PipelineConfig] = None, do_kickstart_hparams: bool = True, metric: Optional[BaseMetric] = None, direction: Literal["minimize", "maximize"] = "minimize", From a16bb5c40c0c0850f3ab0bf197ea0a86ef214484 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Mon, 24 Apr 2023 18:30:33 +0200 Subject: [PATCH 22/23] Add initial implementation of SpeakerAwareTranscription --- src/diart/inference.py | 2 +- src/diart/pipelines/__init__.py | 1 + src/diart/pipelines/speaker_transcription.py | 316 +++++++++++++++++++ src/diart/pipelines/transcription.py | 28 -- src/diart/sinks.py | 21 +- 5 files changed, 337 insertions(+), 31 deletions(-) create mode 100644 src/diart/pipelines/speaker_transcription.py diff --git a/src/diart/inference.py b/src/diart/inference.py index eb2de85e..1589feed 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -116,7 +116,7 @@ def __init__( self.stream = self.stream.pipe( ops.flat_map(lambda results: rx.from_iterable(results)), - ops.do_action(lambda pred_wav: self._predictions.append(pred_wav[0])), + ops.do_action(lambda res: self._predictions.append(res[0] if isinstance(res, tuple) else res)), ) if show_progress: diff --git a/src/diart/pipelines/__init__.py b/src/diart/pipelines/__init__.py index 11fe5e82..55676f66 100644 --- a/src/diart/pipelines/__init__.py +++ b/src/diart/pipelines/__init__.py @@ -1,4 +1,5 @@ from .base import Pipeline, PipelineConfig from .diarization import SpeakerDiarization, SpeakerDiarizationConfig +from .speaker_transcription import SpeakerAwareTranscription, SpeakerAwareTranscriptionConfig from .transcription import Transcription, TranscriptionConfig from .voice import VoiceActivityDetection, VoiceActivityDetectionConfig diff --git a/src/diart/pipelines/speaker_transcription.py b/src/diart/pipelines/speaker_transcription.py new file mode 100644 index 00000000..87d7d192 --- /dev/null +++ b/src/diart/pipelines/speaker_transcription.py @@ -0,0 +1,316 @@ +from pathlib import Path +from typing import Any, Optional, Union, Sequence, Tuple, Text, List + +import numpy as np +import torch +from diart.metrics import Metric +from pyannote.core import SlidingWindowFeature, SlidingWindow, Annotation, Segment +from rx.core import Observer +from typing_extensions import Literal + +from .base import Pipeline, PipelineConfig +from .diarization import SpeakerDiarization, SpeakerDiarizationConfig +from .hparams import HyperParameter, TauActive, RhoUpdate, DeltaNew +from .. import models as m +from .. import sinks +from .. import blocks +from .. 
import utils +from ..metrics import WordErrorRate + + +class SpeakerAwareTranscriptionConfig(PipelineConfig): + def __init__( + self, + asr: Optional[m.SpeechRecognitionModel] = None, + segmentation: Optional[m.SegmentationModel] = None, + embedding: Optional[m.EmbeddingModel] = None, + duration: Optional[float] = None, + asr_duration: float = 3, + step: float = 0.5, + latency: Optional[Union[float, Literal["max", "min"]]] = None, + tau_active: float = 0.5, + rho_update: float = 0.3, + delta_new: float = 1, + language: Optional[Text] = None, + beam_size: Optional[int] = None, + gamma: float = 3, + beta: float = 10, + max_speakers: int = 20, + merge_collar: float = 0.05, + diarization_device: Optional[torch.device] = None, + asr_device: Optional[torch.device] = None, + **kwargs, + ): + # Default segmentation model is pyannote/segmentation + self.segmentation = segmentation + if self.segmentation is None: + self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + + self._duration = duration + self._sample_rate: Optional[int] = None + + # Default embedding model is pyannote/embedding + self.embedding = embedding + if self.embedding is None: + self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") + + # Latency defaults to the step duration + self._step = step + self._latency = latency + if self._latency is None or self._latency == "min": + self._latency = self._step + elif self._latency == "max": + self._latency = self.duration + + self.tau_active = tau_active + self.rho_update = rho_update + self.delta_new = delta_new + self.gamma = gamma + self.beta = beta + self.max_speakers = max_speakers + self.merge_collar = merge_collar + self.asr_duration = asr_duration + + self.diarization_device = diarization_device + if self.diarization_device is None: + self.diarization_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self.language = language + self.beam_size = beam_size + + self.asr_device = asr_device + if self.asr_device is None: + self.asr_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Default ASR model is Whisper small (244M parameters) + self.asr = asr + if self.asr is None: + self.asr = m.SpeechRecognitionModel.from_whisper("small") + self.asr.set_language(self.language) + self.asr.set_beam_size(self.beam_size) + + def to_diarization_config(self) -> SpeakerDiarizationConfig: + return SpeakerDiarizationConfig( + segmentation=self.segmentation, + embedding=self.embedding, + duration=self.duration, + step=self.step, + latency=self.latency, + tau_active=self.tau_active, + rho_update=self.rho_update, + delta_new=self.delta_new, + gamma=self.gamma, + beta=self.beta, + max_speakers=self.max_speakers, + merge_collar=self.merge_collar, + device=self.diarization_device, + ) + + @property + def duration(self) -> float: + # Default duration is the one given by the segmentation model + if self._duration is None: + self._duration = self.segmentation.duration + return self._duration + + @property + def step(self) -> float: + return self._step + + @property + def latency(self) -> float: + return self._latency + + @property + def sample_rate(self) -> int: + if self._sample_rate is None: + dia_sample_rate = self.segmentation.sample_rate + asr_sample_rate = self.asr.sample_rate + msg = "Sample rates for speech recognition and speaker segmentation models must match" + assert dia_sample_rate == asr_sample_rate, msg + self._sample_rate = dia_sample_rate + return self._sample_rate + + @staticmethod + def from_dict(data: 
Any) -> 'SpeakerAwareTranscriptionConfig': + # Resolve arguments exactly like diarization + dia_config = SpeakerDiarizationConfig.from_dict(data) + + # Default ASR model is Whisper small (244M parameters) + whisper_size = utils.get(data, "whisper", "small") + asr = m.SpeechRecognitionModel.from_whisper(whisper_size) + + return SpeakerAwareTranscriptionConfig( + asr=asr, + segmentation=dia_config.segmentation, + embedding=dia_config.embedding, + duration=dia_config.duration, + asr_duration=utils.get(data, "asr_duration", 3), + step=dia_config.step, + latency=dia_config.latency, + tau_active=dia_config.tau_active, + rho_update=dia_config.rho_update, + delta_new=dia_config.delta_new, + language=utils.get(data, "language", None), + beam_size=utils.get(data, "beam_size", None), + gamma=dia_config.gamma, + beta=dia_config.beta, + max_speakers=dia_config.max_speakers, + merge_collar=dia_config.merge_collar, + diarization_device=dia_config.device, + # TODO handle different devices + asr_device=dia_config.device, + ) + + +class SpeakerAwareTranscription(Pipeline): + def __init__(self, config: Optional[SpeakerAwareTranscriptionConfig] = None): + self._config = SpeakerAwareTranscriptionConfig() if config is None else config + self.diarization = SpeakerDiarization(self.config.to_diarization_config()) + self.asr = blocks.SpeechRecognition(self.config.asr, self.config.asr_device) + + # Internal state, handle with care + self.audio_buffer, self.dia_buffer = None, None + + @staticmethod + def get_config_class() -> type: + return SpeakerAwareTranscriptionConfig + + @staticmethod + def suggest_metric() -> Metric: + # TODO per-speaker WER? + return WordErrorRate() + + @staticmethod + def hyper_parameters() -> Sequence[HyperParameter]: + return [TauActive, RhoUpdate, DeltaNew] + + @property + def config(self) -> SpeakerAwareTranscriptionConfig: + return self._config + + def reset(self): + self.diarization.reset() + self.audio_buffer, self.dia_buffer = None, None + + def set_timestamp_shift(self, shift: float): + self.diarization.set_timestamp_shift(shift) + + def join_predictions(self, predictions: List[Text]) -> Text: + return "\n".join(predictions) + + def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]): + with open(Path(dir_path) / f"{uri}.txt", "w") as out_file: + out_file.write(prediction) + + def suggest_display(self) -> Observer: + return sinks.RichScreen() + + def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer: + return sinks.TextWriter(Path(output_dir) / f"{uri}.txt") + + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature], + ) -> Sequence[Text]: + # Compute diarization output + diarization_output = self.diarization(waveforms) + + first_chunk = diarization_output[0][1] + output_start = first_chunk.extent.start + resolution = first_chunk.sliding_window.duration + diarization, chunk_data = Annotation(), [] + for dia, chunk in diarization_output: + diarization = diarization.update(dia) + chunk_data.append(chunk.data) + + # Update diarization output buffer + if self.dia_buffer is None: + self.dia_buffer = diarization + else: + self.dia_buffer = self.dia_buffer.update(diarization) + self.dia_buffer = self.dia_buffer.support(self.config.merge_collar) + + # Update audio buffer + if self.audio_buffer is None: + window = SlidingWindow(resolution, resolution, output_start) + self.audio_buffer = SlidingWindowFeature(np.concatenate(chunk_data, axis=0), window) + else: + chunk_data.insert(0, self.audio_buffer.data) + self.audio_buffer = 
SlidingWindowFeature( + np.concatenate(chunk_data, axis=0), + self.audio_buffer.sliding_window + ) + + # Extract audio to transcribe from the buffer + asr_duration = self.config.asr_duration + buffer_duration = self.audio_buffer.extent.duration + asr_batch_size = int(buffer_duration / asr_duration) + + if asr_batch_size == 0: + return ["" for _ in waveforms] + + buffer_start = self.audio_buffer.extent.start + asr_inputs, input_dia, last_end_time = [], [], None + for i in range(asr_batch_size): + start = buffer_start + i * asr_duration + last_end_time = start + asr_duration + region = Segment(start, last_end_time) + chunk = self.audio_buffer.crop(region, fixed=asr_duration) + window = SlidingWindow(resolution, resolution, start) + asr_inputs.append(SlidingWindowFeature(chunk, window)) + input_dia.append(self.dia_buffer.crop(region)) + + # Create ASR batch, shape (batch, samples, channels) + batch = torch.stack([torch.from_numpy(w.data) for w in asr_inputs]) + + # Remove transcribed chunks from buffer + new_buffer_bounds = Segment(last_end_time, self.audio_buffer.extent.end) + new_buffer = self.audio_buffer.crop(new_buffer_bounds, fixed=new_buffer_bounds.duration) + window = SlidingWindow(resolution, resolution, last_end_time) + self.audio_buffer = SlidingWindowFeature(new_buffer, window) + self.dia_buffer = self.dia_buffer.extrude(Segment(0, last_end_time)) + + # Filter out non-speech chunks + has_voice = torch.tensor([dia.get_timeline().duration() > 0 for dia in input_dia]) + has_voice = torch.where(has_voice)[0] + + # Return empty list if no speech in the entire batch + if len(has_voice) == 0: + return ["" for _ in waveforms] + + # Transcribe batch + outputs = self.asr(batch[has_voice]) + + # Align transcription with diarization to determine speakers + full_transcription = [] + for i, waveform in enumerate(asr_inputs): + if i not in has_voice: + continue + buffer_shift = waveform.sliding_window.start + for text, timestamp in zip(outputs[i].chunks, outputs[i].timestamps): + if not text.strip(): + continue + target_region = Segment( + buffer_shift + timestamp.start, + buffer_shift + timestamp.end, + ) + dia = input_dia[i].crop(target_region) + speakers = dia.labels() + num_speakers = len(speakers) + if num_speakers == 0: + # Include transcription but don't assign a speaker + full_transcription.append(text) + elif num_speakers == 1: + # Typical case, annotate text with the only speaker + full_transcription.append(f"[{speakers[0]}]{text}") + else: + # Multiple speakers for the same text block, choose the most active one + max_spk = np.argmax([dia.label_duration(spk) for spk in speakers]) + full_transcription.append(f"[{speakers[max_spk]}]{text}") + + batch_size = len(waveforms) + output = [" ".join(full_transcription).strip()] + if batch_size > 1: + output += [""] * (batch_size - 1) + return output diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py index 3616222e..49b60175 100644 --- a/src/diart/pipelines/transcription.py +++ b/src/diart/pipelines/transcription.py @@ -182,31 +182,3 @@ def __call__( (outputs[mapping[i]].text if i in has_voice else "", waveforms[i]) for i in range(batch_size) ] - - # TODO align text with speakers if diarization is not None - - # diarization = diarization[0] - # - # # Align transcription with diarization to determine speakers - # full_transcription = [] - # buffer_shift = waveform.sliding_window.start - # for text, timestamp in zip(outputs.chunks, outputs.timestamps): - # target_region = Segment( - # buffer_shift + 
timestamp.start, - # buffer_shift + timestamp.end - # ) - # dia = diarization.crop(target_region) - # speakers = dia.labels() - # num_speakers = len(speakers) - # if num_speakers == 0: - # # Include transcription but don't assign a speaker - # full_transcription.append(text) - # elif num_speakers == 1: - # # Typical case, annotate text with the only speaker - # full_transcription.append(f"[{speakers[0]}]{text}") - # else: - # # Multiple speakers for the same text block, choose the most active one - # max_spk = np.argmax([dia.label_duration(spk) for spk in speakers]) - # full_transcription.append(f"[{speakers[max_spk]}]{text}") - # - # return [(" ".join(full_transcription).strip(), waveform)] \ No newline at end of file diff --git a/src/diart/sinks.py b/src/diart/sinks.py index be461fff..fbc46904 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -62,6 +62,9 @@ def __init__(self, path: Union[Path, Text]): if self.path.exists(): self.path.unlink() + def on_error(self, error: Exception): + pass + def on_next(self, value: Union[Tuple, Text]): # Write transcription to file prediction = _extract_prediction(value) @@ -111,15 +114,25 @@ def __init__(self, speaker_colors: Optional[List[Text]] = None): "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2" ] self.num_colors = len(self.colors) + self._speaker_to_color = {} + + def on_error(self, error: Exception): + pass def on_next(self, value: Union[Tuple, Text]): prediction = _extract_prediction(value) + if not prediction.strip(): + return # Extract speakers speakers = sorted(re.findall(r'\[.*?]', prediction)) # Colorize based on speakers colorized = prediction - for i, speaker in enumerate(speakers): - colorized = colorized.replace(speaker, f"[{self.colors[i % self.num_colors]}]") + for spk in speakers: + name = spk[1:-1] + if name not in self._speaker_to_color: + next_color_idx = len(self._speaker_to_color) % self.num_colors + self._speaker_to_color[name] = self.colors[next_color_idx] + colorized = colorized.replace(spk, f"[{self._speaker_to_color[name]}]") # Print result rich.print(colorized) @@ -180,6 +193,10 @@ def get_plot_bounds(self) -> Segment: start_time = max(0., end_time - self.window_duration) return Segment(start_time, end_time) + def on_error(self, error: Exception): + # Do nothing on error + pass + def on_next( self, values: Tuple[Annotation, SlidingWindowFeature] From c7bbcc43aadd65358da6d95065f7a250a1f3ad81 Mon Sep 17 00:00:00 2001 From: juanmc2005 Date: Wed, 26 Apr 2023 14:37:34 +0200 Subject: [PATCH 23/23] Refactor SpeakerAwareTranscription --- src/diart/console/benchmark.py | 2 + src/diart/console/serve.py | 4 +- src/diart/console/stream.py | 2 + src/diart/console/tune.py | 2 + src/diart/pipelines/speaker_transcription.py | 120 +++++++++++-------- src/diart/utils.py | 6 +- 6 files changed, 86 insertions(+), 50 deletions(-) diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py index d8f04183..3c87edab 100644 --- a/src/diart/console/benchmark.py +++ b/src/diart/console/benchmark.py @@ -24,6 +24,8 @@ def run(): help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files") parser.add_argument("--duration", default=5, type=float, help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline") + parser.add_argument("--asr-duration", default=3, type=float, + help=f"Duration of the transcription window (in seconds). Defaults to 3") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. 
Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index 0698ede0..fe668c5b 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -24,6 +24,8 @@ def run(): help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") parser.add_argument("--duration", type=float, help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline") + parser.add_argument("--asr-duration", default=3, type=float, + help=f"Duration of the transcription window (in seconds). Defaults to 3") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") @@ -61,7 +63,7 @@ def run(): inference.attach_observers(pipeline.suggest_writer(audio_source.uri, args.output)) # Send back responses as text - inference.attach_hooks(lambda pred_wav: audio_source.send(utils.serialize_prediction(pred_wav[0]))) + inference.attach_hooks(lambda result: audio_source.send(utils.serialize_prediction(result))) # Run server and pipeline inference() diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py index 1436eb8a..af8e2cf8 100644 --- a/src/diart/console/stream.py +++ b/src/diart/console/stream.py @@ -23,6 +23,8 @@ def run(): help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") parser.add_argument("--duration", default=5, type=float, help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline") + parser.add_argument("--asr-duration", default=3, type=float, + help=f"Duration of the transcription window (in seconds). Defaults to 3") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py index f492c704..ea34ed97 100644 --- a/src/diart/console/tune.py +++ b/src/diart/console/tune.py @@ -26,6 +26,8 @@ def run(): help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding") parser.add_argument("--duration", default=5, type=float, help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline") + parser.add_argument("--asr-duration", default=3, type=float, + help=f"Duration of the transcription window (in seconds). Defaults to 3") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. 
Defaults to 0.5") diff --git a/src/diart/pipelines/speaker_transcription.py b/src/diart/pipelines/speaker_transcription.py index 87d7d192..ad2303fa 100644 --- a/src/diart/pipelines/speaker_transcription.py +++ b/src/diart/pipelines/speaker_transcription.py @@ -209,13 +209,8 @@ def suggest_display(self) -> Observer: def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer: return sinks.TextWriter(Path(output_dir) / f"{uri}.txt") - def __call__( - self, - waveforms: Sequence[SlidingWindowFeature], - ) -> Sequence[Text]: - # Compute diarization output - diarization_output = self.diarization(waveforms) - + def _update_buffers(self, diarization_output: Sequence[Tuple[Annotation, SlidingWindowFeature]]): + # Separate diarization and aligned audio chunks first_chunk = diarization_output[0][1] output_start = first_chunk.extent.start resolution = first_chunk.sliding_window.duration @@ -239,78 +234,109 @@ def __call__( chunk_data.insert(0, self.audio_buffer.data) self.audio_buffer = SlidingWindowFeature( np.concatenate(chunk_data, axis=0), - self.audio_buffer.sliding_window + self.audio_buffer.sliding_window, ) - # Extract audio to transcribe from the buffer - asr_duration = self.config.asr_duration + def _extract_asr_inputs(self) -> Tuple[List[SlidingWindowFeature], List[Annotation]]: + chunk_duration = self.config.asr_duration buffer_duration = self.audio_buffer.extent.duration - asr_batch_size = int(buffer_duration / asr_duration) - - if asr_batch_size == 0: - return ["" for _ in waveforms] - + batch_size = int(buffer_duration / chunk_duration) buffer_start = self.audio_buffer.extent.start + resolution = self.audio_buffer.sliding_window.duration + + # Extract audio chunks with their diarization asr_inputs, input_dia, last_end_time = [], [], None - for i in range(asr_batch_size): - start = buffer_start + i * asr_duration - last_end_time = start + asr_duration + for i in range(batch_size): + start = buffer_start + i * chunk_duration + last_end_time = start + chunk_duration region = Segment(start, last_end_time) - chunk = self.audio_buffer.crop(region, fixed=asr_duration) + chunk = self.audio_buffer.crop(region, fixed=chunk_duration) window = SlidingWindow(resolution, resolution, start) asr_inputs.append(SlidingWindowFeature(chunk, window)) input_dia.append(self.dia_buffer.crop(region)) - # Create ASR batch, shape (batch, samples, channels) - batch = torch.stack([torch.from_numpy(w.data) for w in asr_inputs]) - - # Remove transcribed chunks from buffer - new_buffer_bounds = Segment(last_end_time, self.audio_buffer.extent.end) - new_buffer = self.audio_buffer.crop(new_buffer_bounds, fixed=new_buffer_bounds.duration) - window = SlidingWindow(resolution, resolution, last_end_time) - self.audio_buffer = SlidingWindowFeature(new_buffer, window) - self.dia_buffer = self.dia_buffer.extrude(Segment(0, last_end_time)) + # Remove extracted chunks from buffers + if asr_inputs: + new_buffer_bounds = Segment(last_end_time, self.audio_buffer.extent.end) + new_buffer = self.audio_buffer.crop(new_buffer_bounds, fixed=new_buffer_bounds.duration) + window = SlidingWindow(resolution, resolution, last_end_time) + self.audio_buffer = SlidingWindowFeature(new_buffer, window) + self.dia_buffer = self.dia_buffer.extrude(Segment(0, last_end_time)) - # Filter out non-speech chunks - has_voice = torch.tensor([dia.get_timeline().duration() > 0 for dia in input_dia]) - has_voice = torch.where(has_voice)[0] + return asr_inputs, input_dia - # Return empty list if no speech in the entire batch - if 
len(has_voice) == 0: - return ["" for _ in waveforms] - - # Transcribe batch - outputs = self.asr(batch[has_voice]) - - # Align transcription with diarization to determine speakers - full_transcription = [] + def _get_speaker_transcriptions( + self, + input_diarization: List[Annotation], + asr_inputs: List[SlidingWindowFeature], + asr_outputs: List[m.TranscriptionResult], + ) -> Text: + transcriptions = [] for i, waveform in enumerate(asr_inputs): - if i not in has_voice: + if waveform is None: continue buffer_shift = waveform.sliding_window.start - for text, timestamp in zip(outputs[i].chunks, outputs[i].timestamps): + for text, timestamp in zip(asr_outputs[i].chunks, asr_outputs[i].timestamps): if not text.strip(): continue target_region = Segment( buffer_shift + timestamp.start, buffer_shift + timestamp.end, ) - dia = input_dia[i].crop(target_region) + dia = input_diarization[i].crop(target_region) speakers = dia.labels() num_speakers = len(speakers) if num_speakers == 0: # Include transcription but don't assign a speaker - full_transcription.append(text) + transcriptions.append(text) elif num_speakers == 1: # Typical case, annotate text with the only speaker - full_transcription.append(f"[{speakers[0]}]{text}") + transcriptions.append(f"[{speakers[0]}]{text}") else: # Multiple speakers for the same text block, choose the most active one max_spk = np.argmax([dia.label_duration(spk) for spk in speakers]) - full_transcription.append(f"[{speakers[max_spk]}]{text}") + transcriptions.append(f"[{speakers[max_spk]}]{text}") + return " ".join(transcriptions).strip() + def __call__( + self, + waveforms: Sequence[SlidingWindowFeature], + ) -> Sequence[Text]: + # Compute diarization output + diarization_output = self.diarization(waveforms) + self._update_buffers(diarization_output) + + # Extract audio to transcribe from the buffer + asr_inputs, asr_input_dia = self._extract_asr_inputs() + if not asr_inputs: + return ["" for _ in waveforms] + + # Detect non-speech chunks + has_voice = torch.tensor([dia.get_timeline().duration() > 0 for dia in asr_input_dia]) + has_voice = torch.where(has_voice)[0] + # Return empty strings if no speech in the entire batch + if len(has_voice) == 0: + return ["" for _ in waveforms] + + # Create ASR batch, shape (batch, samples, channels) + batch = torch.stack([torch.from_numpy(w.data) for w in asr_inputs]) + + # Transcribe batch + asr_outputs = self.asr(batch[has_voice]) + asr_outputs = [ + asr_outputs[i] if i in has_voice else None + for i in range(batch.shape[0]) + ] + + # Attach speaker labels to ASR output and concatenate + transcription = self._get_speaker_transcriptions( + asr_input_dia, asr_inputs, asr_outputs + ) + + # Fill output sequence with empty strings batch_size = len(waveforms) - output = [" ".join(full_transcription).strip()] + output = [transcription] if batch_size > 1: output += [""] * (batch_size - 1) + return output diff --git a/src/diart/utils.py b/src/diart/utils.py index 018bc02c..725d9bf4 100644 --- a/src/diart/utils.py +++ b/src/diart/utils.py @@ -1,6 +1,6 @@ import base64 import time -from typing import Optional, Text, Union, Any, Dict +from typing import Optional, Text, Union, Any, Dict, Tuple import matplotlib.pyplot as plt import numpy as np @@ -92,7 +92,9 @@ def get_padding_right(latency: float, step: float) -> float: return latency - step -def serialize_prediction(value: Union[Annotation, Text]) -> Text: +def serialize_prediction(value: Union[Tuple, Annotation, Text]) -> Text: + if isinstance(value, tuple): + value = value[0] if 
isinstance(value, Annotation): return value.to_rttm() return value
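
For orientation, a minimal usage sketch of the pipeline refactored in PATCH 23/23, kept outside the diff hunks above. The pipeline-side names and parameters (SpeakerAwareTranscription, SpeakerAwareTranscriptionConfig, asr_duration, suggest_display, suggest_writer, the source's uri) are taken from the patch itself; the inference wrapper and microphone source class names and constructor signatures are assumptions about the surrounding diart API at this point in the series, not something this patch defines.

# Hypothetical usage sketch -- not part of the patch. It wires the refactored
# SpeakerAwareTranscription pipeline into diart's streaming inference loop.
# StreamingInference and MicrophoneAudioSource (and their signatures) are
# assumptions about the rest of the library; the pipeline-side API comes
# from the diff above.
from diart.inference import StreamingInference    # assumed wrapper name
from diart.sources import MicrophoneAudioSource   # assumed source; signature may differ
from diart.pipelines.speaker_transcription import (
    SpeakerAwareTranscription,
    SpeakerAwareTranscriptionConfig,
)

# Model defaults resolve to pyannote/segmentation, pyannote/embedding and
# Whisper small, as set up in SpeakerAwareTranscriptionConfig.__init__.
config = SpeakerAwareTranscriptionConfig(
    asr_duration=3,   # transcription window, mirrors the new --asr-duration flag
    step=0.5,
    latency=0.5,
    tau_active=0.5,
    rho_update=0.3,
    delta_new=1,
)
pipeline = SpeakerAwareTranscription(config)

mic = MicrophoneAudioSource(config.sample_rate)   # assumed constructor arguments
inference = StreamingInference(pipeline, mic)

# Live speaker-colorized text plus a .txt file per stream,
# mirroring what console/serve.py does in the diff.
inference.attach_observers(pipeline.suggest_display())
inference.attach_observers(pipeline.suggest_writer(mic.uri, "outputs"))
inference()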