National Speech Corpus data prep, optimizations in Cut, limited export to Kaldi data dir (#149)

* NSC data prep
* Supervisions indexing inside an interval tree in select CutSet operations
* Kaldi data dir export (limited)
* Immutable SupervisionSegment
pzelasko authored Nov 24, 2020
1 parent 76d4e6e commit 21ee379
Showing 11 changed files with 440 additions and 18 deletions.
34 changes: 32 additions & 2 deletions lhotse/audio.py
@@ -104,11 +104,16 @@ class Recording:
duration: Seconds

@staticmethod
def from_sphere(sph_path: Pathlike, relative_path_depth: Optional[int] = None) -> 'Recording':
def from_sphere(
sph_path: Pathlike,
recording_id: Optional[str] = None,
relative_path_depth: Optional[int] = None
) -> 'Recording':
"""
Read a SPHERE file's header and create the corresponding ``Recording``.
:param sph_path: Path to the sphere (.sph) file.
:param recording_id: recording id; when not specified, the filename's stem is used ("x.sph" -> "x").
:param relative_path_depth: optional int specifying how many last parts of the file path
should be retained in the ``AudioSource``. By default, the path is written as-is.
:return: a new ``Recording`` instance pointing to the sphere file.
@@ -117,7 +122,7 @@ def from_sphere(sph_path: Pathlike, relative_path_depth: Optional[int] = None) -
sph_path = Path(sph_path)
sphf = SPHFile(sph_path)
return Recording(
id=sph_path.stem,
id=recording_id if recording_id is not None else sph_path.stem,
sampling_rate=sphf.format['sample_rate'],
num_samples=sphf.format['sample_count'],
duration=sphf.format['sample_count'] / sphf.format['sample_rate'],
Expand All @@ -134,6 +139,31 @@ def from_sphere(sph_path: Pathlike, relative_path_depth: Optional[int] = None) -
]
)

@staticmethod
def from_wav(path: Pathlike, recording_id: Optional[str] = None) -> 'Recording':
"""
Read a WAVE file's header and create the corresponding ``Recording``.
:param path: Path to the WAVE (.wav) file.
:param recording_id: recording id; when not specified, the filename's stem is used ("x.wav" -> "x").
:return: a new ``Recording`` instance pointing to the WAVE file.
"""
from soundfile import SoundFile
with SoundFile(path) as sf:
return Recording(
id=recording_id if recording_id is not None else Path(path).stem,
sampling_rate=sf.samplerate,
num_samples=sf.frames,
duration=sf.frames / sf.samplerate,
sources=[
AudioSource(
type='file',
channels=list(range(sf.channels)),
source=str(path)
)
]
)

@property
def num_channels(self):
return sum(len(source.channels) for source in self.sources)
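As a quick orientation (not part of the diff), here is a minimal usage sketch of the new ``recording_id`` override and the added ``Recording.from_wav`` constructor; the audio paths are hypothetical:

from lhotse.audio import Recording

# Only the WAVE/SPHERE headers are read here; no samples are loaded yet.
wav_rec = Recording.from_wav('audio/session1.wav', recording_id='session1')
sph_rec = Recording.from_sphere('audio/sw02001.sph', recording_id='sw02001')

# Without recording_id, the filename's stem becomes the ID.
print(wav_rec.id, wav_rec.sampling_rate, wav_rec.duration, wav_rec.num_channels)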
28 changes: 24 additions & 4 deletions lhotse/bin/modes/kaldi.py
@@ -2,18 +2,23 @@

import click

from lhotse import load_manifest
from lhotse.bin.modes.cli_base import cli
from lhotse.kaldi import load_kaldi_data_dir
from lhotse.kaldi import export_to_kaldi, load_kaldi_data_dir
from lhotse.utils import Pathlike

__all__ = ['convert_kaldi']

@cli.group()
def kaldi():
"""Kaldi import/export related commands."""
pass

@cli.command()

@kaldi.command(name='import')
@click.argument('data_dir', type=click.Path(exists=True, file_okay=False))
@click.argument('sampling_rate', type=int)
@click.argument('manifest_dir', type=click.Path())
def convert_kaldi(data_dir: Pathlike, sampling_rate: int, manifest_dir: Pathlike):
def import_(data_dir: Pathlike, sampling_rate: int, manifest_dir: Pathlike):
"""
Convert a Kaldi data dir DATA_DIR into a directory MANIFEST_DIR of lhotse manifests. Ignores feats.scp.
The SAMPLING_RATE has to be explicitly specified as it is not available to read from DATA_DIR.
@@ -24,3 +29,18 @@ def convert_kaldi(data_dir: Pathlike, sampling_rate: int, manifest_dir: Pathlike
recording_set.to_json(manifest_dir / 'audio.json')
if maybe_supervision_set is not None:
maybe_supervision_set.to_json(manifest_dir / 'supervision.json')


@kaldi.command()
@click.argument('recordings', type=click.Path(exists=True, dir_okay=False))
@click.argument('supervisions', type=click.Path(exists=True, dir_okay=False))
@click.argument('output_dir', type=click.Path())
def export(recordings: Pathlike, supervisions: Pathlike, output_dir: Pathlike):
"""
Convert a pair of ``RecordingSet`` and ``SupervisionSet`` manifests into a Kaldi-style data directory.
"""
export_to_kaldi(
recordings=load_manifest(recordings),
supervisions=load_manifest(supervisions),
output_dir=output_dir
)
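A hedged sketch (again, not in the diff) of exercising the regrouped commands through ``click``'s test runner; all paths below are made up for illustration:

from click.testing import CliRunner

import lhotse.bin.modes.kaldi  # noqa: F401 -- importing registers the `kaldi` group on the CLI
from lhotse.bin.modes.cli_base import cli

runner = CliRunner()

# `kaldi import` takes over the role of the former top-level Kaldi-import command.
result = runner.invoke(cli, ['kaldi', 'import', 'data/train', '16000', 'manifests/'])
print(result.exit_code)

# `kaldi export` builds a Kaldi-style data dir from a RecordingSet/SupervisionSet pair.
result = runner.invoke(cli, ['kaldi', 'export', 'manifests/audio.json', 'manifests/supervision.json', 'data/train_kaldi'])
print(result.exit_code)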
1 change: 1 addition & 0 deletions lhotse/bin/modes/recipes/__init__.py
@@ -2,5 +2,6 @@
from .heroico import *
from .librimix import *
from .mini_librispeech import *
from .nsc import *
from .switchboard import *
from .tedlium import *
21 changes: 21 additions & 0 deletions lhotse/bin/modes/recipes/nsc.py
@@ -0,0 +1,21 @@
import click

from lhotse.bin.modes import prepare
from lhotse.recipes.nsc import NSC_PARTS, prepare_nsc
from lhotse.utils import Pathlike


@prepare.command(context_settings=dict(show_default=True))
@click.argument('corpus_dir', type=click.Path(exists=True, dir_okay=True))
@click.argument('output_dir', type=click.Path())
@click.option('-p', '--dataset-part', type=click.Choice(NSC_PARTS), default='PART3_SameCloseMic',
help='Which part of NSC should be prepared')
def nsc(
corpus_dir: Pathlike,
output_dir: Pathlike,
dataset_part: str
):
"""
This is a data preparation recipe for the National Speech Corpus of Singapore English.
"""
prepare_nsc(corpus_dir, dataset_part=dataset_part, output_dir=output_dir)
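For completeness, the recipe can also be called from Python directly; a sketch under the assumption that the corpus lives at a hypothetical local path:

from lhotse.recipes.nsc import prepare_nsc

# Expected to prepare the recording/supervision manifests for the chosen NSC part
# and write them under output_dir.
prepare_nsc(
    '/data/corpora/NSC',               # hypothetical corpus location
    dataset_part='PART3_SameCloseMic',
    output_dir='manifests/nsc',
)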
66 changes: 57 additions & 9 deletions lhotse/cut.py
@@ -9,6 +9,7 @@
import numpy as np
from cytoolz import sliding_window
from cytoolz.itertoolz import groupby
from intervaltree import Interval, IntervalTree
from tqdm.auto import tqdm

from lhotse.audio import AudioMixer, Recording, RecordingSet
@@ -351,7 +352,8 @@ def truncate(
offset: Seconds = 0.0,
duration: Optional[Seconds] = None,
keep_excessive_supervisions: bool = True,
preserve_id: bool = False
preserve_id: bool = False,
_supervisions_index: Optional[Dict[str, IntervalTree]] = None
) -> 'Cut':
"""
Returns a new Cut that is a sub-region of the current Cut.
@@ -366,6 +368,8 @@
:param keep_excessive_supervisions: bool. Since trimming may happen inside a SupervisionSegment,
the caller has an option to either keep or discard such supervisions.
:param preserve_id: bool. Should the truncated cut keep the same ID or get a new, random one.
:param _supervisions_index: when passed, speeds up processing of Cuts with a very
large number of supervisions. Intended as an internal parameter.
:return: a new Cut instance. If the current Cut is shorter than the duration, return None.
"""
new_start = self.start + offset
@@ -379,17 +383,32 @@
# Round the duration to avoid the possible loss of a single audio sample due to floating point
# additions and subtractions.
new_duration = round(new_duration, ndigits=8)
new_time_span = TimeSpan(start=0, end=new_duration)
criterion = overlaps if keep_excessive_supervisions else overspans
new_supervisions = (segment.with_offset(-offset) for segment in self.supervisions)

if _supervisions_index is None:
criterion = overlaps if keep_excessive_supervisions else overspans
new_time_span = TimeSpan(start=0, end=new_duration)
new_supervisions = (segment.with_offset(-offset) for segment in self.supervisions)
supervisions = [
segment for segment in new_supervisions if criterion(new_time_span, segment)
]
else:
tree = _supervisions_index[self.id]
# Below we select which method should be called on the IntervalTree object.
# The result of calling that method with a range of (begin, end) is an iterable
# of Intervals that contain the SupervisionSegments matching our criterion.
# We call "interval.data" to obtain the underlying SupervisionSegment.
match_supervisions = tree.overlap if keep_excessive_supervisions else tree.envelop
supervisions = [
interval.data.with_offset(-offset)
for interval in match_supervisions(begin=offset, end=offset + new_duration)
]

return Cut(
id=self.id if preserve_id else str(uuid4()),
start=new_start,
duration=new_duration,
channel=self.channel,
supervisions=[
segment for segment in new_supervisions if criterion(new_time_span, segment)
],
supervisions=sorted(supervisions, key=lambda s: s.start),
features=self.features,
recording=self.recording
)
@@ -516,6 +535,7 @@ def truncate(
duration: Optional[Seconds] = None,
keep_excessive_supervisions: bool = True,
preserve_id: bool = False,
**kwargs
) -> 'PaddingCut':
new_duration = self.duration - offset if duration is None else duration
assert new_duration > 0.0
@@ -677,6 +697,7 @@ def truncate(
duration: Optional[Seconds] = None,
keep_excessive_supervisions: bool = True,
preserve_id: bool = False,
_supervisions_index: Optional[Dict[str, IntervalTree]] = None
) -> 'MixedCut':
"""
Returns a new MixedCut that is a sub-region of the current MixedCut. This method truncates the underlying Cuts
@@ -734,7 +755,8 @@ def truncate(
offset=cut_offset,
duration=new_duration,
keep_excessive_supervisions=keep_excessive_supervisions,
preserve_id=preserve_id
preserve_id=preserve_id,
_supervisions_index=_supervisions_index
),
offset=track_offset,
snr=track.snr
@@ -1138,8 +1160,10 @@ def trim_to_supervisions(self) -> 'CutSet':
:return: a ``CutSet``.
"""
supervisions_index = self.index_supervisions(index_mixed_tracks=True)
return CutSet.from_cuts(
cut.truncate(offset=segment.start, duration=segment.duration)
cut.truncate(offset=segment.start, duration=segment.duration,
_supervisions_index=supervisions_index)
for cut in self
for segment in cut.supervisions
)
@@ -1200,6 +1224,30 @@ def sort_by_duration(self, ascending: bool = False) -> 'CutSet':
"""Sort the CutSet according to cuts duration. Descending by default."""
return CutSet.from_cuts(sorted(self, key=(lambda cut: cut.duration), reverse=not ascending))

def index_supervisions(self, index_mixed_tracks: bool = False) -> Dict[str, IntervalTree]:
"""
Create a two-level index of supervision segments. It is a mapping from a Cut's ID to an
interval tree that contains the supervisions of that Cut.
The interval tree can be efficiently queried for overlapping and/or enveloping segments.
It helps speed up some operations on Cuts of very long recordings (1h+) that contain many
supervisions.
:param index_mixed_tracks: Should the tracks of MixedCuts be indexed as additional, separate entries.
:return: a mapping from Cut ID to an interval tree of SupervisionSegments.
"""
indexed = {
cut.id: IntervalTree(Interval(s.start, s.end, s) for s in cut.supervisions)
for cut in self
}
if index_mixed_tracks:
for cut in self:
if isinstance(cut, MixedCut):
for track in cut.tracks:
indexed[track.cut.id] = IntervalTree(
Interval(s.start, s.end, s) for s in track.cut.supervisions)
return indexed

def pad(
self,
duration: Seconds = None,
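To make the ``overlap`` vs ``envelop`` choice in ``truncate`` concrete, here is a small self-contained sketch with the ``intervaltree`` package (toy segment times, not taken from the diff); it mirrors what happens when ``_supervisions_index`` is supplied and ``keep_excessive_supervisions`` is toggled:

from intervaltree import Interval, IntervalTree

# Toy stand-ins for the SupervisionSegments of a single cut: (start, end, label).
segments = [(0.0, 2.0, 'a'), (1.5, 4.0, 'b'), (5.0, 6.0, 'c')]
tree = IntervalTree(Interval(start, end, label) for start, end, label in segments)

offset, new_duration = 1.0, 3.0  # truncate to the region [1.0, 4.0)

# keep_excessive_supervisions=True -> keep anything that overlaps the region
print(sorted(iv.data for iv in tree.overlap(offset, offset + new_duration)))   # ['a', 'b']

# keep_excessive_supervisions=False -> keep only segments fully inside the region
print(sorted(iv.data for iv in tree.envelop(offset, offset + new_duration)))   # ['b']

Building the tree once per cut in ``index_supervisions`` means each truncation queries the tree instead of re-scanning the full supervision list, which is what makes ``trim_to_supervisions`` faster on long recordings.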
81 changes: 80 additions & 1 deletion lhotse/kaldi.py
@@ -1,7 +1,8 @@
from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple

from lhotse import CutSet
from lhotse.audio import AudioSource, Recording, RecordingSet
from lhotse.supervision import SupervisionSegment, SupervisionSet
from lhotse.utils import Pathlike
@@ -78,6 +79,77 @@ def load_kaldi_data_dir(path: Pathlike, sampling_rate: int) -> Tuple[RecordingSe
return audio_set, supervision_set


def export_to_kaldi(recordings: RecordingSet, supervisions: SupervisionSet, output_dir: Pathlike):
"""
Export a pair of ``RecordingSet`` and ``SupervisionSet`` to a Kaldi data directory.
Currently, it only supports single-channel recordings that have a single ``AudioSource``.
The ``RecordingSet`` and ``SupervisionSet`` must be compatible, i.e. it must be possible to create a
``CutSet`` out of them.
:param recordings: a ``RecordingSet`` manifest.
:param supervisions: a ``SupervisionSet`` manifest.
:param output_dir: path where the Kaldi-style data directory will be created.
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

assert all(len(r.sources) == 1 for r in recordings), "Kaldi export of Recordings with multiple audio sources " \
"is currently not supported."
assert all(r.num_channels == 1 for r in recordings), "Kaldi export of multi-channel Recordings is currently " \
"not supported."

# Create a simple CutSet that ties together the recording <-> supervision information.
cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions).trim_to_supervisions()

# wav.scp
save_kaldi_text_mapping(
data={
recording.id: f'{source.source} |' if source.type == 'command' else source.source
for recording in recordings
for src_idx, source in enumerate(recording.sources)
},
path=output_dir / 'wav.scp'
)
# segments
save_kaldi_text_mapping(
data={cut.supervisions[0].id: f'{cut.recording_id} {cut.start} {cut.end}' for cut in cuts},
path=output_dir / 'segments'
)
# text
save_kaldi_text_mapping(
data={cut.supervisions[0].id: cut.supervisions[0].text for cut in cuts},
path=output_dir / 'text'
)
# utt2spk
save_kaldi_text_mapping(
data={cut.supervisions[0].id: cut.supervisions[0].speaker for cut in cuts},
path=output_dir / 'utt2spk'
)
# utt2dur
save_kaldi_text_mapping(
data={cut.supervisions[0].id: cut.duration for cut in cuts},
path=output_dir / 'utt2dur'
)
# reco2dur
save_kaldi_text_mapping(
data={recording.id: recording.duration for recording in recordings},
path=output_dir / 'reco2dur'
)
# utt2lang [optional]
if all(s.language is not None for s in supervisions):
save_kaldi_text_mapping(
data={cut.supervisions[0].id: cut.supervisions[0].language for cut in cuts},
path=output_dir / 'utt2lang'
)
# utt2gender [optional]
if all(s.gender is not None for s in supervisions):
save_kaldi_text_mapping(
data={cut.supervisions[0].id: cut.supervisions[0].gender for cut in cuts},
path=output_dir / 'utt2gender'
)


def load_kaldi_text_mapping(path: Path, must_exist: bool = False) -> Dict[str, Optional[str]]:
"""Load Kaldi files such as utt2spk, spk2gender, text, etc. as a dict."""
mapping = defaultdict(lambda: None)
@@ -87,3 +159,10 @@ def load_kaldi_text_mapping(path: Path, must_exist: bool = False) -> Dict[str, O
elif must_exist:
raise ValueError(f"No such file: {path}")
return mapping


def save_kaldi_text_mapping(data: Dict[str, Any], path: Path):
"""Save flat dicts to Kaldi files such as utt2spk, spk2gender, text, etc."""
with path.open('w') as f:
for key, value in data.items():
print(key, value, file=f)
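A usage sketch for the Python-level entry point (manifest paths are hypothetical):

from lhotse import load_manifest
from lhotse.kaldi import export_to_kaldi

# Both manifests must describe the same data, i.e. a CutSet can be built from them.
recordings = load_manifest('manifests/audio.json')
supervisions = load_manifest('manifests/supervision.json')

# Writes wav.scp, segments, text, utt2spk, utt2dur and reco2dur
# (plus utt2lang/utt2gender when every supervision carries them) into output_dir.
export_to_kaldi(recordings, supervisions, output_dir='data/train_kaldi')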