
Custom models

Juan Coria edited this page Sep 2, 2022 · 2 revisions

Diart also works with models that were not trained with pyannote.audio. To use such a third-party model, you need to wrap it so that it satisfies the interface diart expects. In exchange, diart can run any model without knowing how it works internally.

Segmentation

A segmentation model must take a batch of audio chunks and return the corresponding per-speaker activity probabilities over time. It must also report the sample rate and duration it expects for its inputs (through get_sample_rate() and get_duration()) so that the pipeline knows how to split the audio stream into chunks.
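Together, the sample rate and duration fix how many samples each input chunk contains. For example, 16000 Hz × 2 s = 32000 samples. The sketch below slices a mono signal into overlapping windows of that size. The `sliding_chunks` name and its `step` parameter are assumptions for illustration only; this is not diart's actual streaming code.

```python
from typing import Iterator, List, Sequence


def sliding_chunks(
    samples: Sequence[float],
    sample_rate: int,
    duration: float,
    step: float,
) -> Iterator[List[float]]:
    """Yield windows of `duration` seconds, starting every `step` seconds.

    Hypothetical helper for illustration; diart does its own chunking.
    """
    chunk_size = int(round(sample_rate * duration))  # samples per chunk
    hop = int(round(sample_rate * step))              # samples between chunk starts
    if chunk_size <= 0 or hop <= 0:
        raise ValueError("duration and step must be positive")
    # Only emit complete chunks, because the model expects a fixed input length
    for start in range(0, len(samples) - chunk_size + 1, hop):
        yield list(samples[start:start + chunk_size])
```

For example, 3 seconds of 16 kHz audio with a 0.5 s step yields three 32000-sample chunks. Any samples that do not fill a whole window are left out.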

Example

```python
import torch
from diart.models import SegmentationModel

class MySegmentationModel(SegmentationModel):
    def __init__(self):
        super().__init__()
        # `load` is a placeholder for however your framework restores its checkpoint
        self.my_pretrained_model = load("my_segmentation.ckpt")

    def get_sample_rate(self) -> int:
        return 16000  # Hz

    def get_duration(self) -> float:
        return 2.0  # seconds

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        # waveform has shape (batch, channels, samples)
        # ... operations to adapt the input to this specific model (e.g. converting to TensorFlow)
        output = self.my_pretrained_model(waveform)
        # ... operations to adapt the output to a torch.Tensor of shape (batch, frames, speakers)
        return output
```
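The output adaptation step is often where most of the wrapping work lies. Suppose, hypothetically, that the third-party model returns raw logits laid out as (batch, speakers, frames). The adapter then has to apply a sigmoid and swap the last two axes. In PyTorch this is `torch.sigmoid(logits).transpose(1, 2)`. The dependency-free sketch below performs the same transformation on nested lists so that each step is visible. Both the `logits_to_activity` name and the assumed input layout are illustrative.

```python
import math
from typing import List

Tensor3D = List[List[List[float]]]


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow for large |x|)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logits_to_activity(logits: Tensor3D) -> Tensor3D:
    """Convert (batch, speakers, frames) logits into (batch, frames, speakers)
    activity probabilities. Assumes the wrapped model emits raw logits."""
    activity = []
    for speaker_rows in logits:  # one entry per chunk in the batch
        num_frames = len(speaker_rows[0]) if speaker_rows else 0
        # Transpose speakers x frames -> frames x speakers, squashing to [0, 1]
        activity.append([
            [_sigmoid(row[t]) for row in speaker_rows]
            for t in range(num_frames)
        ])
    return activity
```

If your model uses a different output layout or activation, adapt the transpose and the squashing function to match it.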

Speaker Embedding

A speaker embedding model must take a batch of audio chunks and return one speaker embedding per chunk. It can also receive optional per-frame weights that tell it which parts of each chunk to focus on (for example, the frames where a given speaker is active).
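One common way for a model to use such weights is weighted mean pooling. Frame-level features are averaged using the per-frame weights, so frames where the target speaker is active dominate the resulting embedding. This technique is named here only for illustration; it is not necessarily what your model or diart's default embedding model does. The sketch pools a (frames, dim) feature matrix for a single chunk. The `weighted_mean_pool` name is hypothetical.

```python
from typing import List, Optional, Sequence


def weighted_mean_pool(
    features: Sequence[Sequence[float]],
    weights: Optional[Sequence[float]] = None,
) -> List[float]:
    """Pool frame-level features (frames, dim) into one embedding (dim,).

    Frames with larger weights contribute more; weights=None means uniform.
    """
    if not features:
        raise ValueError("need at least one frame")
    if weights is None:
        weights = [1.0] * len(features)
    if len(weights) != len(features):
        raise ValueError("one weight per frame is required")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    dim = len(features[0])
    # Weighted average of each feature dimension across frames
    return [
        sum(w * frame[d] for w, frame in zip(weights, features)) / total
        for d in range(dim)
    ]
```

With features [[1, 0], [0, 1]], weights [3, 1] give [0.75, 0.25]. Omitting the weights gives the plain average [0.5, 0.5].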

Example

```python
from typing import Optional

import torch
from diart.models import EmbeddingModel

class MyEmbeddingModel(EmbeddingModel):
    def __init__(self):
        super().__init__()
        # `load` is a placeholder for however your framework restores its checkpoint
        self.my_pretrained_model = load("my_embedding_model.ckpt")

    def __call__(
        self,
        waveform: torch.Tensor,
        weights: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # waveform has shape (batch, channels, samples)
        # weights, if given, have shape (batch, frames)
        # The output must have shape (batch, embedding_dim)
        return self.my_pretrained_model(waveform, weights)
```
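An adapter may need to handle two cases before passing weights to a third-party model. First, `weights` can be None. Second, its number of frames (which typically follows the segmentation model's resolution) may differ from the number of frames the wrapped model computes internally. The sketch below handles both for a single chunk: it uses uniform weights when none are given and nearest-neighbour resampling otherwise. The `prepare_weights` name and the choice of resampling method are assumptions for illustration.

```python
from typing import List, Optional, Sequence


def prepare_weights(
    weights: Optional[Sequence[float]],
    target_frames: int,
) -> List[float]:
    """Return exactly one weight per model frame.

    weights=None -> uniform weights (every frame counts equally).
    Otherwise    -> nearest-neighbour resampling to target_frames.
    """
    if target_frames <= 0:
        raise ValueError("target_frames must be positive")
    if weights is None:
        return [1.0] * target_frames
    source_frames = len(weights)
    if source_frames == 0:
        raise ValueError("weights must not be empty")
    # Map the centre of each target frame onto a source frame index
    return [
        weights[min(source_frames - 1, int((t + 0.5) * source_frames / target_frames))]
        for t in range(target_frames)
    ]
```

For example, prepare_weights([0.0, 1.0], 4) returns [0.0, 0.0, 1.0, 1.0]. When the frame counts already match, the weights are returned unchanged.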

Replacing default models

To use your custom models instead of the defaults, pass them to the PipelineConfig configuration object:

```python
from diart.pipelines import PipelineConfig, OnlineSpeakerDiarization

config = PipelineConfig(segmentation=MySegmentationModel(), embedding=MyEmbeddingModel())
diarization = OnlineSpeakerDiarization(config)

...
```