Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Test Suite #237

Merged
merged 8 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Runs the pytest suite on every pull request that targets main or develop.
name: Pytest

on:
  pull_request:
    branches:
      - main
      - develop

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          # Quoted so YAML does not parse the version as the float 3.1.
          python-version: '3.10'

      # System packages: ffmpeg from the savoury1 PPA and a pinned libportaudio2.
      - name: Install apt dependencies
        run: |
          sudo add-apt-repository ppa:savoury1/ffmpeg4
          sudo apt-get update
          sudo apt-get -y install ffmpeg libportaudio2=19.6.0-1.1

      - name: Install pip dependencies
        run: |
          python -m pip install --upgrade pip
          # Quoted so the shell never treats the brackets as a glob pattern.
          pip install ".[tests]"

      - name: Run tests
        run: |
          pytest
7 changes: 4 additions & 3 deletions .github/workflows/quick-runs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .
pip install onnxruntime==1.18.0
- name: Crop audio and rttm
run: |
sox audio/ES2002a_long.wav audio/ES2002a.wav trim 00:40 00:30
Expand All @@ -50,10 +51,10 @@ jobs:
rm rttms/ES2002b_long.rttm
- name: Run stream
run: |
diart.stream audio/ES2002a.wav --output trash --no-plot --hf-token ${{ secrets.HUGGINGFACE }}
diart.stream audio/ES2002a.wav --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx --output trash --no-plot
- name: Run benchmark
run: |
diart.benchmark audio --reference rttms --batch-size 4 --hf-token ${{ secrets.HUGGINGFACE }}
diart.benchmark audio --reference rttms --batch-size 4 --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
- name: Run tuning
run: |
diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --hf-token ${{ secrets.HUGGINGFACE }}
diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
Binary file added assets/models/embedding_uint8.onnx
Binary file not shown.
Binary file added assets/models/segmentation_uint8.onnx
Binary file not shown.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy>=1.20.2
matplotlib>=3.3.3
matplotlib>=3.3.3,<3.6.0
rx>=3.2.0
scipy>=1.6.0
sounddevice>=0.4.2
Expand Down
7 changes: 6 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package_dir=
packages=find:
install_requires=
numpy>=1.20.2
matplotlib>=3.3.3
matplotlib>=3.3.3,<3.6.0
rx>=3.2.0
scipy>=1.6.0
sounddevice>=0.4.2
Expand All @@ -41,6 +41,11 @@ install_requires=
websocket-client>=0.58.0
rich>=12.5.1

[options.extras_require]
tests=
pytest>=7.4.0,<8.0.0
onnxruntime==1.18.0

[options.packages.find]
where=src

Expand Down
48 changes: 48 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import random

import pytest
import torch

from diart.models import SegmentationModel, EmbeddingModel


class DummySegmentationModel:
    """Stand-in segmentation model that emits random scores.

    Calling an instance checks that the input is a 3-dimensional waveform
    batch and returns a random tensor whose first dimension equals the
    batch size. The number of frames and speakers is drawn at random on
    every call.
    """

    def to(self, device) -> "DummySegmentationModel":
        """Ignore ``device`` and return ``self``.

        Returning ``self`` follows the ``torch.nn.Module.to`` convention so
        that ``model = model.to(device)`` does not rebind the model to None.
        """
        return self

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return random scores of shape (batch, frames, speakers).

        Args:
            waveform: 3-dimensional tensor; only its first dimension
                (the batch size) influences the output.

        Returns:
            A ``torch.rand`` tensor with 250-500 frames and 3-5 speakers.

        Raises:
            AssertionError: if ``waveform`` is not 3-dimensional.
        """
        assert waveform.ndim == 3

        batch_size = waveform.shape[0]
        # Frames are drawn before speakers to keep the random stream order.
        num_frames = random.randint(250, 500)
        num_speakers = random.randint(3, 5)

        return torch.rand(batch_size, num_frames, num_speakers)


class DummyEmbeddingModel:
    """Stand-in embedding model that emits random embeddings.

    Calling an instance validates the ranks and batch sizes of its inputs
    and returns one random embedding per batch item. The embedding
    dimension is drawn at random on every call.
    """

    def to(self, device) -> "DummyEmbeddingModel":
        """Ignore ``device`` and return ``self``.

        Returning ``self`` follows the ``torch.nn.Module.to`` convention so
        that ``model = model.to(device)`` does not rebind the model to None.
        """
        return self

    def __call__(self, waveform: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """Return random embeddings of shape (batch, embedding_dim).

        Args:
            waveform: 3-dimensional tensor whose first dimension is the
                batch size.
            weights: 2-dimensional tensor whose first dimension must match
                the batch size of ``waveform``.

        Returns:
            A ``torch.randn`` tensor with an embedding dimension in 128-512.

        Raises:
            AssertionError: if either rank is wrong or the batch sizes differ.
        """
        assert waveform.ndim == 3
        assert weights.ndim == 2

        batch_size = waveform.shape[0]
        batch_size_weights = weights.shape[0]

        assert batch_size == batch_size_weights

        embedding_dim = random.randint(128, 512)

        return torch.randn(batch_size, embedding_dim)


@pytest.fixture(scope="session")
def segmentation_model() -> SegmentationModel:
    """Session-scoped ``SegmentationModel`` built around the dummy model.

    The dummy class itself, not an instance, is handed to the wrapper.
    """
    # NOTE(review): presumably the wrapper calls this to create the model
    # lazily; confirm against SegmentationModel.
    dummy_loader = DummySegmentationModel
    return SegmentationModel(dummy_loader)


@pytest.fixture(scope="session")
def embedding_model() -> EmbeddingModel:
    """Session-scoped ``EmbeddingModel`` built around the dummy model.

    The dummy class itself, not an instance, is handed to the wrapper.
    """
    # NOTE(review): presumably the wrapper calls this to create the model
    # lazily; confirm against EmbeddingModel.
    dummy_loader = DummyEmbeddingModel
    return EmbeddingModel(dummy_loader)
Binary file added tests/data/audio/sample.wav
Binary file not shown.
13 changes: 13 additions & 0 deletions tests/data/rttm/latency_0.5.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
SPEAKER sample 1 6.675 0.533 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 1.883 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 9.508 1.000 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 10.508 0.567 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.325 3.733 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 18.058 3.450 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.325 0.183 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.508 0.017 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.775 0.233 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 22.008 6.633 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 28.508 1.500 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 29.958 0.050 <NA> <NA> speaker0 <NA> <NA>
13 changes: 13 additions & 0 deletions tests/data/rttm/latency_1.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
SPEAKER sample 1 6.708 0.450 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 1.383 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 9.008 1.500 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 10.008 1.067 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.592 4.200 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.308 3.700 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 18.042 3.250 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.508 0.033 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.108 0.383 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.508 0.033 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 21.775 6.817 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 28.008 2.000 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 29.975 0.033 <NA> <NA> speaker0 <NA> <NA>
10 changes: 10 additions & 0 deletions tests/data/rttm/latency_2.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.592 0.817 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 8.475 1.617 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 9.892 1.150 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 18.008 3.533 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.225 0.283 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.758 6.867 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 27.875 2.133 <NA> <NA> speaker1 <NA> <NA>
10 changes: 10 additions & 0 deletions tests/data/rttm/latency_3.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 0.467 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 8.008 2.050 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.592 4.167 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.192 0.367 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.758 6.833 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 27.825 2.183 <NA> <NA> speaker1 <NA> <NA>
10 changes: 10 additions & 0 deletions tests/data/rttm/latency_4.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SPEAKER sample 1 6.742 0.400 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 0.650 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 8.092 1.950 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.575 4.183 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.208 0.333 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA>
10 changes: 10 additions & 0 deletions tests/data/rttm/latency_5.rttm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SPEAKER sample 1 6.742 0.383 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 7.625 0.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 8.092 1.967 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 10.558 4.200 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
SPEAKER sample 1 18.208 0.317 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA>
SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA>
54 changes: 54 additions & 0 deletions tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
import pytest
from pyannote.core import SlidingWindow, SlidingWindowFeature

from diart.blocks.aggregation import (
AggregationStrategy,
HammingWeightedAverageStrategy,
FirstOnlyStrategy,
AverageStrategy,
DelayedAggregation,
)


def test_strategy_build():
    """Check that each known strategy name builds the matching class.

    An unknown name must make ``AggregationStrategy.build`` raise.
    """
    expected_types = (
        ("mean", AverageStrategy),
        ("hamming", HammingWeightedAverageStrategy),
        ("first", FirstOnlyStrategy),
    )

    for strategy_name, strategy_type in expected_types:
        built = AggregationStrategy.build(strategy_name)
        assert isinstance(built, strategy_type)

    with pytest.raises(Exception):
        AggregationStrategy.build("invalid")


def test_aggregation():
    """Check window count and output shape for every aggregation strategy.

    Three ``DelayedAggregation`` blocks (mean, hamming, first) are built with
    a 0.5 step and a latency of 2. Each must report 4 overlapping windows and
    turn 4 random buffers into data of shape (51, 2).
    """
    chunk_duration = 5
    num_frames = 500
    hop = 0.5
    num_speakers = 2
    first_chunk = 10
    frame_resolution = chunk_duration / num_frames

    aggregators = [
        DelayedAggregation(step=hop, latency=2, strategy=strategy_name)
        for strategy_name in ("mean", "hamming", "first")
    ]

    for aggregator in aggregators:
        assert aggregator.num_overlapping_windows == 4

    # One random buffer per overlapping window, each starting one hop later.
    buffers = []
    for index in range(aggregators[0].num_overlapping_windows):
        window = SlidingWindow(
            start=(index + first_chunk) * hop,
            duration=frame_resolution,
            step=frame_resolution,
        )
        scores = np.random.rand(num_frames, num_speakers)
        buffers.append(SlidingWindowFeature(scores, window))

    # NOTE(review): 51 presumably equals hop / frame_resolution + 1; confirm.
    for aggregator in aggregators:
        aggregated = aggregator(buffers)
        assert aggregated.data.shape == (51, 2)
Loading
Loading