Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Approximate proportional sampling in BucketingSampler; remaining_duration, remaining_cuts, num_cuts properties for samplers. #372

Merged
merged 8 commits into from
Aug 16, 2021
35 changes: 31 additions & 4 deletions lhotse/dataset/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ def __init__(
It makes sense to turn it off when iterating the sampler is somewhat costly for any reason;
e.g. because the underlying manifest is lazily loaded from the filesystem/somewhere else.
"""
super().__init__(
data_source=None
) # the "data_source" arg is not used in Sampler...
super().__init__(data_source=None) # the "data_source" arg is not used in Sampler...
self.shuffle = shuffle
self.seed = seed
self.epoch = 0
Expand Down Expand Up @@ -103,7 +101,6 @@ def set_epoch(self, epoch: int) -> None:
:param epoch: Epoch number.
"""
self.epoch = epoch
self.num_batches = None

def filter(self, predicate: Callable[[Cut], bool]) -> None:
"""
Expand Down Expand Up @@ -134,6 +131,36 @@ def _next_batch(self):
"Sub-classes of CutSampler have to implement self._next_batch()"
)

@property
def remaining_duration(self) -> Optional[float]:
"""
Remaining duration of data left in the sampler (may be inexact due to float arithmetic).
Not available when the CutSet is read in lazy mode (returns None).
"""
raise NotImplementedError(
'Sub-classes of CutSampler have to implement self.remaining_duration'
)

@property
def remaining_cuts(self) -> Optional[int]:
"""
Remaining number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
raise NotImplementedError(
'Sub-classes of CutSampler have to implement self.remaining_cuts'
)

@property
def num_cuts(self) -> Optional[int]:
"""
Total number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
raise NotImplementedError(
'Sub-classes of CutSampler have to implement self.num_cuts'
)

def __len__(self) -> int:
if not self.provide_len:
# Fake non-existence of this attribute
Expand Down
79 changes: 77 additions & 2 deletions lhotse/dataset/sampling/bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import reduce
from itertools import chain
from operator import add
from typing import Callable, Dict, List, Tuple, Type
from typing import Callable, Dict, List, Optional, Tuple, Type

import numpy as np
from typing_extensions import Literal
Expand Down Expand Up @@ -119,6 +119,45 @@ def __init__(
self.bucket_rng = random.Random(self.seed + self.epoch)
self.depleted = [False] * num_buckets

@property
def remaining_duration(self) -> Optional[float]:
"""
Remaining duration of data left in the sampler (may be inexact due to float arithmetic).
Not available when the CutSet is read in lazy mode (returns None).

.. note: For BucketingSampler, it's the sum of remaining duration in all buckets.
"""
durs = [s.remaining_duration for _, s in self._nondepleted_samplers_with_idxs]
if any(d is None for d in durs):
return None
return sum(durs)

@property
def remaining_cuts(self) -> Optional[int]:
"""
Remaining number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).

.. note: For BucketingSampler, it's the sum of remaining cuts in all buckets.
"""
counts = [s.remaining_cuts for _, s in self._nondepleted_samplers_with_idxs]
if any(c is None for c in counts):
return None
return sum(counts)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW you could make this more efficient with try-except.


@property
def num_cuts(self) -> Optional[int]:
"""
Total number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).

.. note: For BucketingSampler, it's the sum of num cuts in all buckets.
"""
counts = [s.num_cuts for s in self.bucket_samplers]
if any(c is None for c in counts):
return None
return sum(counts)

def set_epoch(self, epoch: int) -> None:
"""
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
Expand Down Expand Up @@ -157,9 +196,45 @@ def __iter__(self) -> "BucketingSampler":
self.depleted = [False] * self.num_buckets
return self

def _select_bucket_with_idx(self) -> Tuple[int, CutSampler]:
# return self.bucket_rng.choice(self._nondepleted_samplers_with_idxs)
if self.cut_sets[0].is_lazy:
# With lazy CutSets, we simply choose a random bucket,
# because we can't know how much data is left in the buckets.
return self.bucket_rng.choice(self._nondepleted_samplers_with_idxs)
idx_sampler_pairs = self._nondepleted_samplers_with_idxs
if len(idx_sampler_pairs) == 1:
# Only a single bucket left -- choose it.
return idx_sampler_pairs[0]
# If we got there, it means there are at least 2 buckets we can sample from.
# We are going to use approximate proportional sampling:
# for that, we randomly select two buckets, and then assign a higher probability
# to the bucket that has more cumulative data duration left to sample.
# This helps ensure that none of the buckets is depleted much earlier than
# the others.
idx1, sampler1 = self.bucket_rng.choice(idx_sampler_pairs)
idx2, sampler2 = self.bucket_rng.choice(idx_sampler_pairs)
# Note: prob1 is the probability of selecting sampler1
try:
prob1 = sampler1.remaining_duration / (
sampler1.remaining_duration + sampler2.remaining_duration
)
except ZeroDivisionError:
# This will happen when we have already depleted the samplers,
# but the BucketingSampler doesn't know it yet. We only truly
# know that a sampler is depleted when we try to get a batch
# and it raises a StopIteration, which is done after this stage.
# We can't depend on remaining_duration for lazy CutSets.
# If both samplers are zero duration, just return the first one.
return idx1, sampler1
if self.bucket_rng.random() > prob1:
return idx2, sampler2
else:
return idx1, sampler1

def _next_batch(self):
while not self.is_depleted:
idx, sampler = self.bucket_rng.choice(self._nondepleted_samplers_with_idxs)
idx, sampler = self._select_bucket_with_idx()
try:
return next(sampler)
except StopIteration:
Expand Down
28 changes: 28 additions & 0 deletions lhotse/dataset/sampling/cut_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,34 @@ def __init__(
self.max_cuts = max_cuts
self.drop_last = drop_last

@property
def remaining_duration(self) -> Optional[float]:
"""
Remaining duration of data left in the sampler (may be inexact due to float arithmetic).
Not available when the CutSet is read in lazy mode (returns None).

.. note: For :class:`.CutPairsSampler` we return the source cuts duration.
"""
return self.source_cuts.remaining_duration

@property
def remaining_cuts(self) -> Optional[int]:
"""
Remaining number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
return self.source_cuts.remaining_cuts

@property
def num_cuts(self) -> Optional[int]:
"""
Total number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
if self.source_cuts.is_lazy:
return None
return len(self.source_cuts)

def __iter__(self) -> "CutPairsSampler":
"""
Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
Expand Down
37 changes: 34 additions & 3 deletions lhotse/dataset/sampling/data_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from collections import deque
from typing import Generator, Iterable
from typing import Generator, Iterable, Optional

from lhotse import CutSet
from lhotse.cut import Cut
Expand All @@ -18,6 +18,27 @@ def __init__(self, items: CutSet):
self._shuffled_items = self._orig_items
self._iter = None
self._reusable = deque()
# Add duration tracking for non-lazy CutSets
if not self.is_lazy:
self._total_duration = sum(c.duration for c in self._orig_items)
self._total_cuts = len(self._orig_items)
else:
self._total_duration = None
self._total_cuts = None
self._remaining_duration = self._total_duration
self.remaining_cuts = self._total_cuts

@property
def is_lazy(self) -> bool:
return self._orig_items.is_lazy

@property
def remaining_duration(self) -> Optional[float]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why this is a property and not a function? E.g. does it indicate that it's expected to be fast to compute?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's the reason.

# Paranoia mode: float arithmetic is imprecise, so we'll make sure
# that the returned duration is non-negative.
if self._remaining_duration is None:
return None
return max(0, self._remaining_duration)

def shuffle(self, seed: int) -> "DataSource":
"""
Expand Down Expand Up @@ -45,11 +66,16 @@ def sort_like(self, other: "DataSource") -> "DataSource":
def take_back(self, cut: Cut) -> None:
"""Push the cut in front of other cuts to be sampled again."""
self._reusable.append(cut)
if not self.is_lazy:
self._remaining_duration += cut.duration
self.remaining_cuts += 1

def reset(self) -> None:
"""Reset the iterable state of DataSource."""
self._iter = None
self._reusable.clear()
self._remaining_duration = self._total_duration
self.remaining_cuts = self._total_cuts

def __iter__(self) -> "DataSource":
self.reset()
Expand All @@ -58,8 +84,13 @@ def __iter__(self) -> "DataSource":

def __next__(self) -> Cut:
if self._reusable:
return self._reusable.popleft()
return next(self._iter)
next_cut = self._reusable.popleft()
else:
next_cut = next(self._iter)
if not self.is_lazy:
self._remaining_duration -= next_cut.duration
self.remaining_cuts -= 1
return next_cut

def __len__(self) -> int:
return len(self._shuffled_items)
Expand Down
26 changes: 26 additions & 0 deletions lhotse/dataset/sampling/single_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,32 @@ def __init__(
# Constraints
assert is_none_or_gt(self.max_cuts, 0)

@property
def remaining_duration(self) -> Optional[float]:
"""
Remaining duration of data left in the sampler (may be inexact due to float arithmetic).
Not available when the CutSet is read in lazy mode (returns None).
"""
return self.data_source.remaining_duration

@property
def remaining_cuts(self) -> Optional[int]:
"""
Remaining number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
return self.data_source.remaining_cuts

@property
def num_cuts(self) -> Optional[int]:
"""
Total number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
if self.data_source.is_lazy:
return None
return len(self.data_source)

def __iter__(self) -> "SingleCutSampler":
"""
Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
Expand Down
40 changes: 39 additions & 1 deletion lhotse/dataset/sampling/zip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import reduce
from operator import add
from typing import Callable
from typing import Callable, Optional

from lhotse import CutSet
from lhotse.cut import Cut
Expand Down Expand Up @@ -32,6 +32,44 @@ def __init__(self, *samplers: CutSampler) -> None:
super().__init__()
self.samplers = samplers

@property
def remaining_duration(self) -> Optional[float]:
"""
Remaining duration of data left in the sampler (may be inexact due to float arithmetic).

.. note: For ZipSampler, it's the minimum of remaining durations in its sub-samplers.
"""
durs = [s.remaining_duration for s in self.samplers]
if any(d is None for d in durs):
return None
return min(durs)

@property
def remaining_cuts(self) -> Optional[int]:
"""
Remaining number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).

.. note: For ZipSampler, it's the minimum of remaining cuts in its sub-samplers.
"""
counts = [s.remaining_cuts for s in self.samplers]
if any(c is None for c in counts):
return None
return min(counts)

@property
def num_cuts(self) -> Optional[int]:
"""
Total number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).

.. note: For ZipSampler, it's the minimum of num cuts in its sub-samplers.
"""
counts = [s.num_cuts for s in self.samplers]
if any(c is None for c in counts):
return None
return min(counts)

def __iter__(self):
for sampler in self.samplers:
iter(sampler)
Expand Down
Loading