Skip to content

Commit

Permalink
Refactoring of diart.models to ease custom model usage
Browse files Browse the repository at this point in the history
  • Loading branch information
juanmc2005 committed Mar 10, 2023
1 parent abe5f66 commit 4b744ed
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 90 deletions.
45 changes: 31 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ from diart.sinks import RTTMWriter

pipeline = OnlineSpeakerDiarization()
mic = MicrophoneAudioSource(pipeline.config.sample_rate)
inference = RealTimeInference(pipeline, mic, do_plot=True)
inference = RealTimeInference(pipeline, mic)
inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
prediction = inference()
```
Expand All @@ -127,28 +127,45 @@ For inference and evaluation on a dataset we recommend to use `Benchmark` (see n

## Custom models

Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:
Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):

```python
import torch
from typing import Optional
from diart import OnlineSpeakerDiarization, PipelineConfig
from diart.models import EmbeddingModel
from diart.models import EmbeddingModel, SegmentationModel
from diart.sources import MicrophoneAudioSource
from diart.inference import RealTimeInference


def model_loader():
    # Factory passed to SegmentationModel/EmbeddingModel below; diart invokes
    # it lazily the first time the wrapped model is actually needed.
    # `load_pretrained_model` is a placeholder for the user's own loading code.
    return load_pretrained_model("my_model.ckpt")

class MySegmentationModel(SegmentationModel):
    """Example custom segmentation model.

    Passes a loader callable to the base class so the underlying model is
    only constructed on first use (see the lazy-loading note in ``forward``).
    """

    def __init__(self):
        # The base class stores the loader; no model is built here.
        super().__init__(model_loader)

    @property
    def sample_rate(self) -> int:
        # Sample rate (Hz) this model expects its input audio to have.
        return 16000

    @property
    def duration(self) -> float:
        # Length of the audio chunk (in seconds) consumed per forward pass.
        return 2  # seconds

    def forward(self, waveform):
        # self.model is created lazily
        return self.model(waveform)


class MyEmbeddingModel(EmbeddingModel):
def __init__(self):
super().__init__()
self.my_pretrained_model = load("my_model.ckpt")
super().__init__(model_loader)

def __call__(
self,
waveform: torch.Tensor,
weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
return self.my_pretrained_model(waveform, weights)
def forward(self, waveform, weights):
# self.model is created lazily
return self.model(waveform, weights)


config = PipelineConfig(embedding=MyEmbeddingModel())
pipeline = OnlineSpeakerDiarization(config)
mic = MicrophoneAudioSource(config.sample_rate)
Expand Down Expand Up @@ -225,7 +242,7 @@ from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding

segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
sample_rate = segmentation.model.get_sample_rate()
sample_rate = segmentation.model.sample_rate
mic = MicrophoneAudioSource(sample_rate)

stream = mic.stream.pipe(
Expand Down
4 changes: 2 additions & 2 deletions src/diart/blocks/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def from_dict(data: Any) -> 'PipelineConfig':
@property
def duration(self) -> float:
if self._duration is None:
self._duration = self.segmentation.get_duration()
self._duration = self.segmentation.duration
return self._duration

@property
Expand All @@ -138,5 +138,5 @@ def latency(self) -> float:
@property
def sample_rate(self) -> int:
if self._sample_rate is None:
self._sample_rate = self.segmentation.get_sample_rate()
self._sample_rate = self.segmentation.sample_rate
return self._sample_rate
134 changes: 60 additions & 74 deletions src/diart/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Text, Union
from typing import Optional, Text, Union, Callable

import torch
import torch.nn as nn
Expand All @@ -10,77 +10,43 @@
_has_pyannote = False


class LazyModel(nn.Module):
@property
def model(self) -> Optional[nn.Module]:
raise NotImplementedError

def load(self):
"""Load model to memory"""
raise NotImplementedError

def is_in_memory(self) -> bool:
"""Return whether the model has been loaded into memory"""
return self.model is not None

def to(self, *args, **kwargs) -> nn.Module:
if not self.in_memory():
self.load_model()
return super().to(*args, **kwargs)

def __call__(self, *args, **kwargs):
if not self.is_in_memory():
self.load()
return super().__call__(*args, **kwargs)


class PyannoteModel(LazyModel):
class PyannoteLoader:
def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
super().__init__()
self.model_info = model_info
self.hf_token = hf_token
self._model: Optional[nn.Module] = None

@property
def model(self) -> Optional[nn.Module]:
return self._model
def __call__(self) -> nn.Module:
return pyannote_loader.get_model(self.model_info, self.hf_token)


class LazyModel(nn.Module):
def __init__(self, loader: Callable[[], nn.Module]):
super().__init__()
self.get_model = loader
self.model: Optional[nn.Module] = None

def is_in_memory(self) -> bool:
"""Return whether the model has been loaded into memory"""
return self.model is not None

def load(self):
"""Load model to memory"""
if not self.is_in_memory():
self._model = pyannote_loader.get_model(self.model_info, self.hf_token)

self.model = self.get_model()

class PyannoteSegmentationModel(PyannoteModel):
def get_sample_rate(self) -> int:
def to(self, *args, **kwargs) -> nn.Module:
self.load()
return self.model.audio.sample_rate
return super().to(*args, **kwargs)

def get_duration(self) -> float:
def __call__(self, *args, **kwargs):
self.load()
return self.model.specifications.duration

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
return self.model(waveform)


class PyannoteEmbeddingModel(PyannoteModel):
def forward(
self,
waveform: torch.Tensor,
weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(waveform, weights=weights)
return super().__call__(*args, **kwargs)


class SegmentationModel(nn.Module):
class SegmentationModel(LazyModel):
"""
Minimal interface for a segmentation model.
"""
def __init__(self, model: LazyModel):
super().__init__()
self.model = model

@staticmethod
def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'SegmentationModel':
"""
Expand All @@ -100,13 +66,15 @@ def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'Segme
wrapper: SegmentationModel
"""
assert _has_pyannote, "No pyannote.audio installation found"
return SegmentationModel(PyannoteSegmentationModel(model, use_hf_token))
return PyannoteSegmentationModel(model, use_hf_token)

def get_sample_rate(self) -> int:
return self.model.get_sample_rate()
@property
def sample_rate(self) -> int:
raise NotImplementedError

def get_duration(self) -> float:
return self.model.get_duration()
@property
def duration(self) -> float:
raise NotImplementedError

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -120,19 +88,29 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-------
speaker_segmentation: torch.Tensor, shape (batch, frames, speakers)
"""
return self.model(waveform)
raise NotImplementedError

def to(self, *args, **kwargs) -> nn.Module:
self.model.load()
return super().to(*args, **kwargs)

class PyannoteSegmentationModel(SegmentationModel):
def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
super().__init__(PyannoteLoader(model_info, hf_token))

class EmbeddingModel(nn.Module):
"""Minimal interface for an embedding model."""
def __init__(self, model: LazyModel):
super().__init__()
self.model = model
@property
def sample_rate(self) -> int:
self.load()
return self.model.audio.sample_rate

@property
def duration(self) -> float:
self.load()
return self.model.specifications.duration

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
return self.model(waveform)


class EmbeddingModel(LazyModel):
"""Minimal interface for an embedding model."""
@staticmethod
def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'EmbeddingModel':
"""
Expand All @@ -152,7 +130,7 @@ def from_pyannote(model, use_hf_token: Union[Text, bool, None] = True) -> 'Embed
wrapper: EmbeddingModel
"""
assert _has_pyannote, "No pyannote.audio installation found"
return EmbeddingModel(PyannoteEmbeddingModel(model, use_hf_token))
return PyannoteEmbeddingModel(model, use_hf_token)

def forward(
self,
Expand All @@ -172,8 +150,16 @@ def forward(
-------
speaker_embeddings: torch.Tensor, shape (batch, embedding_dim)
"""
return self.model(waveform, weights)
raise NotImplementedError

def to(self, *args, **kwargs) -> nn.Module:
self.model.load()
return super().to(*args, **kwargs)

class PyannoteEmbeddingModel(EmbeddingModel):
    """Embedding model backed by a pyannote.audio checkpoint, loaded lazily."""

    def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
        # model_info: pyannote model identifier (e.g. a HuggingFace repo id).
        # hf_token: HuggingFace auth token, or True to use the locally saved one.
        super().__init__(PyannoteLoader(model_info, hf_token))

    def forward(
        self,
        waveform: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # self.model is populated by the lazy loader before this runs;
        # optional per-frame weights are forwarded to the pyannote model.
        return self.model(waveform, weights=weights)

0 comments on commit 4b744ed

Please sign in to comment.