Improving and fixing trim_to_supervisions, adding more documentation, fixes to filter_supervisions (#330)

* Fixes, more tests, and documentation for `trim_to_supervisions`; add `drop_recording` and `drop_supervisions`

* Add docs about `drop_X` and cuts not being mutated by their member methods
pzelasko authored Jul 4, 2021
1 parent 0749ab9 commit 28e56fd
Showing 3 changed files with 375 additions and 51 deletions.
153 changes: 130 additions & 23 deletions lhotse/cut.py
@@ -123,6 +123,16 @@ class Cut:
All cut transformations are performed lazily, on-the-fly, upon calling ``load_audio`` or ``load_features``.
The stored waveforms and features are untouched.
.. caution::
Operations on cuts are not mutating -- they return modified copies of :class:`.Cut` objects,
leaving the original object unmodified.
Cuts can be detached from parts of their metadata::
>>> cut_no_feat = cut.drop_features()
>>> cut_no_rec = cut.drop_recording()
>>> cut_no_sup = cut.drop_supervisions()
Finally, cuts provide convenience methods to compute feature frame and audio sample masks for supervised regions::
>>> sup_frames = cut.supervisions_feature_mask()
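As a quick illustration of the caution note above, here is a minimal sketch (built with the `dummy_cut` / `dummy_supervision` helpers from `lhotse.testing.dummies`, the same ones used by the new tests below) showing that the `drop_*` methods return modified copies:

from lhotse.testing.dummies import dummy_cut, dummy_supervision

cut = dummy_cut(0, supervisions=[dummy_supervision(0)])
cut_no_feat = cut.drop_features()   # a new MonoCut with features detached
assert cut.has_features             # the original cut is left unmodified
assert not cut_no_feat.has_features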
@@ -138,6 +148,7 @@ class Cut:
# The following is the list of members and properties implemented by the child classes.
# They are not abstract properties because dataclasses do not work well with the "abc" module.
id: str
start: Seconds
duration: Seconds
sampling_rate: int
supervisions: List[SupervisionSegment]
@@ -157,6 +168,8 @@ class Cut:
load_features: Callable[[], np.ndarray]
compute_and_store_features: Callable
drop_features: Callable
drop_recording: Callable
drop_supervisions: Callable
truncate: Callable
pad: Callable
resample: Callable
@@ -521,10 +534,19 @@ def load_audio(self) -> Optional[np.ndarray]:
return None

def drop_features(self) -> 'MonoCut':
"""Return a copy of the current :class:`MonoCut`, detached from ``features``."""
"""Return a copy of the current :class:`.MonoCut`, detached from ``features``."""
assert self.has_recording, f"Cannot detach features from a MonoCut with no Recording (cut ID = {self.id})."
return fastcopy(self, features=None)

def drop_recording(self) -> 'MonoCut':
"""Return a copy of the current :class:`.MonoCut`, detached from ``recording``."""
assert self.has_features, f"Cannot detach recording from a MonoCut with no Features (cut ID = {self.id})."
return fastcopy(self, recording=None)

def drop_supervisions(self) -> 'MonoCut':
"""Return a copy of the current :class:`.MonoCut`, detached from ``supervisions``."""
return fastcopy(self, supervisions=[])
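A short usage sketch for the two new MonoCut methods (assuming `cut` is a MonoCut that has both a recording and pre-computed features attached):

# Detaching the recording is only allowed while features are present,
# so the copy is never left without any data source to load from.
cut_feat_only = cut.drop_recording()
assert cut_feat_only.has_features and not cut_feat_only.has_recording

# Detaching supervisions is always allowed and yields an empty list.
cut_no_sup = cut.drop_supervisions()
assert cut_no_sup.supervisions == []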

def compute_and_store_features(
self,
extractor: FeatureExtractor,
@@ -610,7 +632,7 @@ def truncate(
# We are going to measure the overlap ratio of the supervision with the "truncated" cut
# and reject segments that overlap less than 1%. This way we can avoid quirks and errors
# of limited float precision.
olap_ratio = measure_overlap(interval.data, TimeSpan(new_start, new_start + new_duration))
olap_ratio = measure_overlap(interval.data, TimeSpan(offset, offset + new_duration))
if olap_ratio > 0.01:
supervisions.append(interval.data.with_offset(-offset))
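The intent of the 1% threshold is easier to see with numbers; the sketch below is illustration only (it re-implements a simple overlap fraction with plain floats rather than calling lhotse's `measure_overlap` / `TimeSpan` helpers, whose exact normalization may differ):

def overlap_fraction(a_start, a_end, b_start, b_end):
    # Overlap duration normalized by the shorter of the two spans.
    overlap = min(a_end, b_end) - max(a_start, b_start)
    shorter = min(a_end - a_start, b_end - b_start)
    return max(overlap, 0.0) / shorter

# A supervision [0.0, 5.0] vs. a truncated region [4.999, 10.0]:
# the ~1 ms sliver of overlap (a float-precision leftover) is ~0.02%,
# well under the 1% threshold, so the supervision is rejected.
assert overlap_fraction(0.0, 5.0, 4.999, 10.0) < 0.01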

@@ -948,10 +970,19 @@ def perturb_speed(self, factor: float, affix_id: bool = True) -> 'PaddingCut':
)

def drop_features(self) -> 'PaddingCut':
"""Return a copy of the current :class:`PaddingCut`, detached from ``features``."""
"""Return a copy of the current :class:`.PaddingCut`, detached from ``features``."""
assert self.has_recording, f"Cannot detach features from a MonoCut with no Recording (cut ID = {self.id})."
return fastcopy(self, num_frames=None, num_features=None, frame_shift=None)

def drop_recording(self) -> 'PaddingCut':
"""Return a copy of the current :class:`.PaddingCut`, detached from ``recording``."""
assert self.has_features, f"Cannot detach recording from a PaddingCut with no Features (cut ID = {self.id})."
return fastcopy(self, num_samples=None)

def drop_supervisions(self) -> 'PaddingCut':
"""Return a copy of the current :class:`.PaddingCut`, detached from ``supervisions``."""
return self

def compute_and_store_features(self, extractor: FeatureExtractor, *args, **kwargs) -> Cut:
"""
Returns a new PaddingCut with updated information about the feature dimension and number of
@@ -1146,7 +1177,7 @@ def truncate(
# when the track offset is larger than the truncation offset, we are not truncating the cut;
# just decreasing the track offset.

# 'cut_offset' determines how much we're going to truncate the MonoCut for the current track.
# 'cut_offset' determines how much we're going to truncate the Cut for the current track.
cut_offset = max(offset - track.offset, 0)
# 'track_offset' determines the new track's offset after truncation.
track_offset = max(track.offset - offset, 0)
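A small worked example with made-up numbers makes the two offsets concrete:

# Truncating a mix at offset = 3.0 s.
offset = 3.0

# Track A starts at 0.0 s: its cut loses the first 3.0 s and the track
# stays at the very beginning of the truncated mix.
cut_offset_a = max(offset - 0.0, 0)    # 3.0 -> truncate the cut
track_offset_a = max(0.0 - offset, 0)  # 0.0

# Track B starts at 5.0 s: its cut is untouched; only its position
# within the mix moves from 5.0 s to 2.0 s.
cut_offset_b = max(offset - 5.0, 0)    # 0.0 -> cut not truncated
track_offset_b = max(5.0 - offset, 0)  # 2.0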
@@ -1406,9 +1437,18 @@ def plot_tracks_audio(self):

def drop_features(self) -> 'MixedCut':
"""Return a copy of the current :class:`MixedCut`, detached from ``features``."""
assert self.has_recording, f"Cannot detach features from a MonoCut with no Recording (cut ID = {self.id})."
assert self.has_recording, f"Cannot detach features from a MixedCut with no Recording (cut ID = {self.id})."
return fastcopy(self, tracks=[fastcopy(t, cut=t.cut.drop_features()) for t in self.tracks])

def drop_recording(self) -> 'MixedCut':
"""Return a copy of the current :class:`.MixedCut`, detached from ``recording``."""
assert self.has_features, f"Cannot detach recording from a MixedCut with no Features (cut ID = {self.id})."
return fastcopy(self, tracks=[fastcopy(t, cut=t.cut.drop_recording()) for t in self.tracks])

def drop_supervisions(self) -> 'MixedCut':
"""Return a copy of the current :class:`.MixedCut`, detached from ``supervisions``."""
return fastcopy(self, tracks=[fastcopy(t, cut=t.cut.drop_supervisions()) for t in self.tracks])

def compute_and_store_features(
self,
extractor: FeatureExtractor,
@@ -1487,9 +1527,13 @@ def filter_supervisions(self, predicate: Callable[[SupervisionSegment], bool]) -
:param predicate: A callable that accepts `SupervisionSegment` and returns bool
:return: a modified MixedCut
"""
new_mixed_cut = fastcopy(self)
for track in new_mixed_cut.tracks:
track.cut = track.cut.filter_supervisions(predicate)
new_mixed_cut = fastcopy(
self,
tracks=[
fastcopy(track, cut=track.cut.filter_supervisions(predicate))
for track in self.tracks
]
)
return new_mixed_cut
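For example (a sketch; `mixed_cut` is assumed to be a MixedCut whose supervisions carry the standard `duration` attribute):

# Keep only supervisions at least one second long; the original cut
# and its tracks are left untouched, a filtered copy is returned.
filtered = mixed_cut.filter_supervisions(lambda s: s.duration >= 1.0)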

@staticmethod
@@ -1577,6 +1621,10 @@ class CutSet(Serializable, Sequence[Cut]):
>>> cuts_B = cuts.cut_into_windows(duration=5.0)
>>> cuts_C = cuts.trim_to_unsupervised_segments()
.. caution::
Operations on cut sets are not mutating -- they return modified copies of :class:`.CutSet` objects,
leaving the original object unmodified (and all of its cuts are also unmodified).
:class:`~lhotse.cut.CutSet` can be stored and read from JSON, JSONL, etc. and supports optional gzip compression::
>>> cuts.to_file('cuts.jsonl.gz')
@@ -1621,6 +1669,12 @@ class CutSet(Serializable, Sequence[Cut]):
>>> random_sample = cuts.sample(n_cuts=10)
>>> new_ids = cuts.modify_ids(lambda c: c.id + '-newid')
Cuts in a :class:`.CutSet` can be detached from parts of their metadata::
>>> cuts_no_feat = cuts.drop_features()
>>> cuts_no_rec = cuts.drop_recordings()
>>> cuts_no_sup = cuts.drop_supervisions()
Sometimes specific sorting patterns are useful when a small CutSet represents a mini-batch::
>>> cuts = cuts.sort_by_duration(ascending=False)
@@ -1915,19 +1969,60 @@ def filter(self, predicate: Callable[[Cut], bool]) -> 'CutSet':
"""
return CutSet.from_cuts(cut for cut in self if predicate(cut))

def trim_to_supervisions(self) -> 'CutSet':
def trim_to_supervisions(self, keep_overlapping: bool = True) -> 'CutSet':
"""
Return a new CutSet with Cuts that have spans identical to their supervisions.
For example, the following cut::

    Cut
    |-----------------|
    Sup1
    |----|  Sup2
       |-----------|

is transformed into two cuts::

    Cut1
    |----|
    Sup1
    |----|
       Sup2
       |-|
       Cut2
       |-----------|
       Sup1
       |-|
       Sup2
       |-----------|
:param keep_overlapping: when ``False``, it will discard parts of other supervisions that overlap with the
main supervision. In the illustration above, it would discard ``Sup2`` in ``Cut1`` and ``Sup1`` in ``Cut2``.
:return: a ``CutSet``.
"""
supervisions_index = self.index_supervisions(index_mixed_tracks=True)
return CutSet.from_cuts(
cut.truncate(offset=segment.start, duration=segment.duration,
_supervisions_index=supervisions_index)
for cut in self
for segment in cut.supervisions
)
if keep_overlapping:
supervisions_index = self.index_supervisions(index_mixed_tracks=True)
return CutSet.from_cuts(
cut.truncate(offset=segment.start, duration=segment.duration,
_supervisions_index=supervisions_index)
for cut in self
for segment in cut.supervisions
)
else:
# If we're not going to keep overlapping supervisions, we can use a slightly faster variant
# that doesn't require indexing and search of supervisions in an interval tree.
return CutSet.from_cuts(
(
cut.filter_supervisions(
lambda s: s.id == segment.id
).truncate(
offset=segment.start,
duration=segment.duration
)
)
for cut in self
for segment in cut.supervisions
)
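A brief usage sketch of the new flag (assuming `cuts` is a CutSet whose supervisions may overlap):

# One cut per supervision; overlapping parts of other supervisions are
# kept and shifted into the new cut's coordinates (docstring diagram above).
trimmed = cuts.trim_to_supervisions()

# Faster variant: each new cut keeps only its own "main" supervision.
trimmed_clean = cuts.trim_to_supervisions(keep_overlapping=False)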

def trim_to_unsupervised_segments(self) -> 'CutSet':
"""
@@ -1996,8 +2091,8 @@ def sort_like(self, other: 'CutSet') -> 'CutSet':

def index_supervisions(self, index_mixed_tracks: bool = False) -> Dict[str, IntervalTree]:
"""
Create a two-level index of supervision segments. It is a mapping from a MonoCut's ID to an
interval tree that contains the supervisions of that MonoCut.
Create a two-level index of supervision segments. It is a mapping from a Cut's ID to an
interval tree that contains the supervisions of that Cut.
The interval tree can be efficiently queried for overlapping and/or enveloping segments.
It helps speed up some operations on Cuts of very long recordings (1h+) that contain many
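The returned trees come from the third-party `intervaltree` package, so they can be queried directly. A minimal sketch (the slicing query and the assumption that intervals are keyed by time within the cut, with the SupervisionSegment stored in `.data`, are inferred from the truncation code above rather than a documented contract):

index = cuts.index_supervisions(index_mixed_tracks=True)
tree = index[cut.id]

# Supervisions overlapping a given time region of that cut.
overlapping_sups = [interval.data for interval in tree[2.0:4.0]]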
Expand Down Expand Up @@ -2244,10 +2339,22 @@ def mix(

def drop_features(self) -> 'CutSet':
"""
Return a new :class:`CutSet`, where each MonoCut is copied and detached from its extracted features.
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its extracted features.
"""
return CutSet.from_cuts(c.drop_features() for c in self)

def drop_recordings(self) -> 'CutSet':
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its recordings.
"""
return CutSet.from_cuts(c.drop_recording() for c in self)

def drop_supervisions(self) -> 'CutSet':
"""
Return a new :class:`.CutSet`, where each :class:`.Cut` is copied and detached from its supervisions.
"""
return CutSet.from_cuts(c.drop_supervisions() for c in self)
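A small usage sketch (assuming `cuts` has both recordings and pre-computed features attached):

# Shrink the manifests for a training job that only needs
# pre-computed features and transcripts.
cuts_light = cuts.drop_recordings()

# Or strip supervisions entirely for unsupervised processing.
cuts_unsup = cuts.drop_supervisions()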

def compute_and_store_features(
self,
extractor: FeatureExtractor,
@@ -2424,7 +2531,7 @@ def compute_and_store_recordings(
executor: Optional[Executor] = None,
augment_fn: Optional[AugmentFn] = None,
progress_bar: bool = True
) -> 'CutSet':
) -> 'CutSet':
"""
Store waveforms of all cuts as audio recordings to disk.
Expand Down Expand Up @@ -2554,16 +2661,16 @@ def with_recording_path_prefix(self, path: Pathlike) -> 'CutSet':

def map(self, transform_fn: Callable[[Cut], Cut]) -> 'CutSet':
"""
Modify the cuts in this ``CutSet`` and return a new ``CutSet``.
Apply `transform_fn` to the cuts in this :class:`.CutSet` and return a new :class:`.CutSet`.
:param transform_fn: A callable (function) that accepts a single cut instance
and returns a single cut instance.
:return: a new ``CutSet`` with modified cuts.
:return: a new ``CutSet`` with transformed cuts.
"""

def verified(mapped: Any) -> Cut:
assert isinstance(mapped, (MonoCut, MixedCut, PaddingCut)), \
"The callable passed to CutSet.map() must return a MonoCut class instance."
"The callable passed to CutSet.map() must return a Cut class instance."
return mapped

return CutSet.from_cuts(verified(transform_fn(c)) for c in self)
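For instance (a sketch; it assumes `Cut.pad` accepts a target `duration`, as the `pad` member declared on the Cut class suggests):

# Pad every cut to at least 10 seconds; map() verifies that the callable
# returns a Cut instance and builds a new CutSet from the results.
padded_cuts = cuts.map(lambda c: c.pad(duration=10.0))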
83 changes: 83 additions & 0 deletions test/cut/test_cut_drop_attributes.py
@@ -0,0 +1,83 @@
import pytest

from lhotse import CutSet
from lhotse.cut import PaddingCut
from lhotse.testing.dummies import dummy_cut, dummy_supervision

parametrize_on_cut_types = pytest.mark.parametrize(
'cut', [
# MonoCut
dummy_cut(0, supervisions=[dummy_supervision(0)]),
# PaddingCut
PaddingCut('pad', duration=1.0, sampling_rate=16000, feat_value=-100,
num_frames=100, frame_shift=0.01, num_features=80, num_samples=16000),
# MixedCut
dummy_cut(0, supervisions=[dummy_supervision(0)]).mix(
dummy_cut(1, supervisions=[dummy_supervision(1)]),
offset_other_by=0.5,
snr=10
)
]
)


@parametrize_on_cut_types
def test_drop_features(cut):
assert cut.has_features
cut_drop = cut.drop_features()
assert cut.has_features
assert not cut_drop.has_features


@parametrize_on_cut_types
def test_drop_recording(cut):
assert cut.has_recording
cut_drop = cut.drop_recording()
assert cut.has_recording
assert not cut_drop.has_recording


@parametrize_on_cut_types
def test_drop_supervisions(cut):
assert len(cut.supervisions) > 0 or isinstance(cut, PaddingCut)
cut_drop = cut.drop_supervisions()
assert len(cut.supervisions) > 0 or isinstance(cut, PaddingCut)
assert len(cut_drop.supervisions) == 0


@pytest.fixture()
def cutset():
return CutSet.from_cuts([
# MonoCut
dummy_cut(0, supervisions=[dummy_supervision(0)]),
# PaddingCut
PaddingCut('pad', duration=1.0, sampling_rate=16000, feat_value=-100,
num_frames=100, frame_shift=0.01, num_features=80, num_samples=16000),
# MixedCut
dummy_cut(0, supervisions=[dummy_supervision(0)]).mix(
dummy_cut(1, supervisions=[dummy_supervision(1)]),
offset_other_by=0.5,
snr=10
)
])


def test_drop_features_cutset(cutset):
assert any(cut.has_features for cut in cutset)
cutset_drop = cutset.drop_features()
assert any(cut.has_features for cut in cutset)
assert all(not cut.has_features for cut in cutset_drop)


def test_drop_recordings_cutset(cutset):
assert any(cut.has_recording for cut in cutset)
cutset_drop = cutset.drop_recordings()
assert any(cut.has_recording for cut in cutset)
assert all(not cut.has_recording for cut in cutset_drop)


def test_drop_supervisions_cutset(cutset):
assert any(len(cut.supervisions) > 0 for cut in cutset)
cutset_drop = cutset.drop_supervisions()
assert any(len(cut.supervisions) > 0 for cut in cutset)
assert all(len(cut.supervisions) == 0 for cut in cutset_drop)