Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add workflow: annotate DNSMOS P.835 #1406

Merged
merged 5 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions lhotse/bin/modes/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,82 @@ def activity_detection(
supervisions.to_file(str(sups_path))

print("Results saved to:", str(sups_path), sep="\n")


@workflows.command()
@click.argument("out_cuts", type=click.Path(allow_dash=True))
@click.option(
    "-m",
    "--recordings-manifest",
    type=click.Path(exists=True, dir_okay=False, allow_dash=True),
    help="Path to an existing recording manifest.",
)
@click.option(
    "-r",
    "--recordings-dir",
    type=click.Path(exists=True, file_okay=False),
    help="Directory with recordings. We will create a RecordingSet for it automatically.",
)
@click.option(
    "-c",
    "--cuts-manifest",
    type=click.Path(exists=True, dir_okay=False, allow_dash=True),
    help="Path to an existing cuts manifest.",
)
@click.option(
    "-e",
    "--extension",
    default="wav",
    help="Audio file extension to search for. Used with RECORDINGS_DIR.",
)
@click.option(
    "-p",
    "--is-personalized-mos",
    # Spelled out explicitly: click already inferred a boolean *value* option
    # (e.g. ``-p true``) from ``default=False``; this keeps that behavior.
    type=bool,
    default=False,
    help="Flag to indicate if personalized MOS score is needed or regular.",
)
@click.option("-j", "--jobs", default=1, help="Number of jobs for audio scanning.")
def annotate_dnsmos(
    out_cuts: str,
    recordings_manifest: Optional[str],
    recordings_dir: Optional[str],
    cuts_manifest: Optional[str],
    extension: str,
    is_personalized_mos: bool,
    jobs: int,
):
    """
    Use Microsoft DNSMOS P.835 prediction model to annotate either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST.
    It will predict DNSMOS P.835 score including SIG, BAK, and OVRL.

    See the original repo for more details: https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS

    RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive. The predicted
    scores are stored in the ``custom`` field of every cut written to OUT_CUTS.
    """
    # Imported under an alias because this command shadows the library function's name.
    from lhotse import annotate_dnsmos as annotate_dnsmos_

    # Raise a proper CLI error instead of asserting: asserts vanish under
    # ``python -O`` and would let an invalid option combination through.
    if not exactly_one_not_null(recordings_manifest, recordings_dir, cuts_manifest):
        raise click.UsageError(
            "Options RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive "
            "and at least one is required."
        )

    if recordings_manifest is not None:
        manifest = RecordingSet.from_file(recordings_manifest)
    elif recordings_dir is not None:
        manifest = RecordingSet.from_dir(
            recordings_dir, pattern=f"*.{extension}", num_jobs=jobs
        )
    else:
        # Made eager so that ``len(manifest)`` below is available for the progress bar.
        manifest = CutSet.from_file(cuts_manifest).to_eager()

    with CutSet.open_writer(out_cuts) as writer:
        for cut in tqdm(
            annotate_dnsmos_(
                manifest,
                is_personalized_mos=is_personalized_mos,
            ),
            total=len(manifest),
            desc="Annotating with DNSMOS P.835 prediction model",
        ):
            # Flush after every cut so partial results are on disk if the run is interrupted.
            writer.write(cut, flush=True)
1 change: 1 addition & 0 deletions lhotse/workflows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .activity_detection import *
from .dnsmos import annotate_dnsmos
from .forced_alignment import align_with_torchaudio
from .meeting_simulation import *
from .whisper import annotate_with_whisper
213 changes: 213 additions & 0 deletions lhotse/workflows/dnsmos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import logging
import os
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Generator, List, Optional, Union

import numpy as np
from tqdm import tqdm

from lhotse import CutSet, MonoCut, RecordingSet, SupervisionSegment
from lhotse.utils import fastcopy, is_module_available, resumable_download


class ComputeScore:
    """
    Callable wrapper around a DNSMOS P.835 ONNX model.

    Calling an instance with an audio-bearing manifest item returns that item
    together with its averaged ``OVRL``, ``SIG`` and ``BAK`` scores.
    """

    def __init__(self, primary_model_path) -> None:
        """
        :param primary_model_path: path to the ``sig_bak_ovr.onnx`` model file.
        """
        # Imported lazily so importing this module does not require onnxruntime.
        import onnxruntime as ort

        self.onnx_sess = ort.InferenceSession(primary_model_path)
        # Audio is resampled to this rate (Hz) before being fed to the model.
        self.SAMPLING_RATE = 16000
        # Length, in seconds, of each window passed to the model.
        self.INPUT_LENGTH = 9.01

    def audio_melspec(
        self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True
    ):
        """
        Compute a mel spectrogram of ``audio`` with librosa and return it transposed.

        When ``to_db`` is true, the power spectrogram is converted to dB relative
        to its maximum and rescaled as ``(dB + 40) / 40``.
        Not used by :meth:`__call__` in this class.
        """
        import librosa

        mel_spec = librosa.feature.melspectrogram(
            y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels
        )
        if to_db:
            mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40
        return mel_spec.T

    def get_polyfit_val(self, sig, bak, ovr, is_personalized_mos):
        """
        Map raw model outputs through fixed polynomials.

        :param sig: raw SIG output of the model.
        :param bak: raw BAK output of the model.
        :param ovr: raw OVRL output of the model.
        :param is_personalized_mos: selects the cubic (personalized) or the
            quadratic (regular) coefficient set.
        :return: tuple ``(sig_poly, bak_poly, ovr_poly)``.
        """
        # NOTE(review): coefficients presumably mirror the upstream DNSMOS
        # reference implementation -- confirm before changing them.
        if is_personalized_mos:
            p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046])
            p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726])
            p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132])
        else:
            p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
            p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439])
            p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])

        sig_poly = p_sig(sig)
        bak_poly = p_bak(bak)
        ovr_poly = p_ovr(ovr)

        return sig_poly, bak_poly, ovr_poly

    def __call__(self, manifest, is_personalized_mos):
        """
        Score one item and return ``(manifest, scores)``.

        The audio is resampled to ``SAMPLING_RATE``, repeated until it is at
        least ``INPUT_LENGTH`` seconds long, and scored on windows of
        ``INPUT_LENGTH`` seconds taken every second. The per-window scores are
        averaged.

        :param manifest: object exposing ``resample(sr).load_audio()``.
        :param is_personalized_mos: forwarded to :meth:`get_polyfit_val`.
        :return: the unchanged ``manifest`` and a dict with keys ``OVRL``,
            ``SIG`` and ``BAK``.
        :raises ValueError: if the loaded audio contains no samples.
        """
        fs = self.SAMPLING_RATE
        # load_audio() may return a (channels, samples) array; flatten it so
        # that len() counts samples. Without this, len() of a (1, N) array is 1
        # and every input would be doubled by the padding loop below.
        audio = np.asarray(manifest.resample(fs).load_audio()).reshape(-1)
        if audio.size == 0:
            # Doubling an empty array never grows it: the loop below would not end.
            raise ValueError(
                f"Cannot compute DNSMOS for '{getattr(manifest, 'id', manifest)}': "
                f"the audio has no samples."
            )

        len_samples = int(self.INPUT_LENGTH * fs)
        # Repeat short audio until it covers at least one full model window.
        while len(audio) < len_samples:
            audio = np.append(audio, audio)

        num_hops = int(np.floor(len(audio) / fs) - self.INPUT_LENGTH) + 1
        hop_len_samples = fs
        predicted_mos_sig_seg = []
        predicted_mos_bak_seg = []
        predicted_mos_ovr_seg = []

        for idx in range(num_hops):
            seg_start = int(idx * hop_len_samples)
            seg_end = int((idx + self.INPUT_LENGTH) * hop_len_samples)
            audio_seg = audio[seg_start:seg_end]
            if len(audio_seg) < len_samples:
                continue

            # The model takes a float32 batch of shape (1, len_samples).
            input_features = np.array(audio_seg).astype("float32")[np.newaxis, :]
            oi = {"input_1": input_features}
            mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0]
            mos_sig, mos_bak, mos_ovr = self.get_polyfit_val(
                mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_mos
            )
            predicted_mos_sig_seg.append(mos_sig)
            predicted_mos_bak_seg.append(mos_bak)
            predicted_mos_ovr_seg.append(mos_ovr)

        return manifest, {
            "OVRL": np.mean(predicted_mos_ovr_seg),
            "SIG": np.mean(predicted_mos_sig_seg),
            "BAK": np.mean(predicted_mos_bak_seg),
        }


def download_model(
    is_personalized_mos: bool = False,
    download_root: Optional[str] = None,
) -> str:
    """
    Download the DNSMOS P.835 ONNX model and return its local path.

    :param is_personalized_mos: when true, fetch the personalized (``pDNSMOS``)
        model; otherwise fetch the regular (``DNSMOS``) model.
    :param download_root: directory to store the model in; defaults to ``/tmp``.
    :return: path of the downloaded model file.
    """
    download_root = download_root if download_root is not None else "/tmp"
    variant = "pDNSMOS" if is_personalized_mos else "DNSMOS"
    url = (
        "https://github.com/microsoft/DNS-Challenge/raw/refs/heads/master/DNSMOS/"
        f"{variant}/sig_bak_ovr.onnx"
    )
    os.makedirs(download_root, exist_ok=True)
    # Both variants are called ``sig_bak_ovr.onnx`` upstream. Store them under
    # variant-specific names so that a cached file of one variant is never
    # picked up when the other one is requested.
    filename = os.path.join(download_root, f"{variant}_sig_bak_ovr.onnx")
    resumable_download(url, filename=filename)
    return filename


def annotate_dnsmos(
    manifest: Union[RecordingSet, CutSet],
    is_personalized_mos: bool = False,
    download_root: Optional[str] = None,
) -> Generator[MonoCut, None, None]:
    """
    Annotate recordings or cuts with the Microsoft DNSMOS P.835 prediction model.
    The predicted DNSMOS P.835 scores are SIG, BAK, and OVRL.

    See the original repo for more details: https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS

    :param manifest: a ``RecordingSet`` or ``CutSet`` object.
    :param is_personalized_mos: flag to indicate if personalized MOS score is needed or regular.
    :param download_root: if specified, the model will be downloaded to this directory. Otherwise,
        it will be downloaded to /tmp.
    :return: a generator of cuts (use ``CutSet.open_writer()`` to write them).
    """
    # This is a generator: the checks below run when iteration starts,
    # not when the function is called.
    librosa_ok = is_module_available("librosa")
    assert librosa_ok, (
        "This function expects librosa to be installed. "
        "You can install it via 'pip install librosa'"
    )

    onnxruntime_ok = is_module_available("onnxruntime")
    assert onnxruntime_ok, (
        "This function expects onnxruntime to be installed. "
        "You can install it via 'pip install onnxruntime'"
    )

    # Pick the helper that matches the manifest type, then delegate to it.
    if isinstance(manifest, RecordingSet):
        annotate = _annotate_recordings
    elif isinstance(manifest, CutSet):
        annotate = _annotate_cuts
    else:
        raise ValueError("The ``manifest`` must be either a RecordingSet or a CutSet.")

    yield from annotate(manifest, is_personalized_mos, download_root)


def _annotate_recordings(
    recordings: RecordingSet,
    is_personalized_mos: bool = False,
    download_root: Optional[str] = None,
):
    """
    Score every mono recording with DNSMOS P.835 and yield one ``MonoCut`` each.

    Each yielded cut spans the whole recording, carries a single supervision
    covering the same span, and stores the scores in ``custom``.
    Recordings with more than one channel are skipped with a warning.
    """
    scorer = ComputeScore(download_model(is_personalized_mos, download_root))

    with ThreadPoolExecutor() as pool:
        # Submit all work first; results are then consumed in submission order.
        pending = []
        for recording in tqdm(recordings, desc="Distributing tasks"):
            if recording.num_channels <= 1:
                pending.append(pool.submit(scorer, recording, is_personalized_mos))
                continue
            logging.warning(
                f"Skipping recording '{recording.id}'. It has {recording.num_channels} channels, "
                f"but we currently only support mono input."
            )

        for job in tqdm(pending, desc="Processing"):
            scored, scores = job.result()
            whole_span = SupervisionSegment(
                id=scored.id,
                recording_id=scored.id,
                start=0,
                duration=scored.duration,
            )
            yield MonoCut(
                id=scored.id,
                start=0,
                duration=scored.duration,
                channel=0,
                recording=scored,
                supervisions=[whole_span],
                custom=scores,
            )


def _annotate_cuts(
    cuts: CutSet,
    is_personalized_mos: bool = False,
    download_root: Optional[str] = None,
):
    """
    Helper function that annotates a CutSet with DNSMOS P.835 prediction model.

    Yields a copy of every mono cut whose ``custom`` dict contains the
    ``OVRL``, ``SIG`` and ``BAK`` scores. Custom fields already present on the
    cut are kept; only same-named keys are overwritten. Cuts with more than one
    channel are skipped with a warning.
    """
    primary_model_path = download_model(is_personalized_mos, download_root)
    compute_score = ComputeScore(primary_model_path)

    with ThreadPoolExecutor() as ex:
        # Submit all work first; results are then consumed in submission order.
        futures = []
        for cut in tqdm(cuts, desc="Distributing tasks"):
            if cut.num_channels > 1:
                logging.warning(
                    f"Skipping cut '{cut.id}'. It has {cut.num_channels} channels, "
                    f"but we currently only support mono input."
                )
                continue
            futures.append(ex.submit(compute_score, cut, is_personalized_mos))

        for future in tqdm(futures, desc="Processing"):
            cut, result = future.result()
            # Merge rather than replace, so custom fields the cut already
            # carries are not lost.
            existing_custom = getattr(cut, "custom", None) or {}
            new_cut = fastcopy(cut, custom={**existing_custom, **result})
            yield new_cut
Loading