Approximate proportional sampling in BucketingSampler; remaining_duration, remaining_cuts, num_cuts properties for samplers. #372
Conversation
        return self._orig_items.is_lazy

    @property
    def remaining_duration(self) -> Optional[float]:
Is there a reason why this is a property and not a function? E.g. does it indicate that it's expected to be fast to compute?
Yeah, that's the reason.
LGTM, except my only concern is what might happen if, due to floating point roundoff, we get inaccuracies near the last batch remaining. Can you convince yourself that at least it won't lead to a crash? Discarding a few cuts is OK.
I share the concern -- I ran it multiple times with different seeds on 100k+ items without issues, but I'm still not sure. I'll sleep on it and add some safeguards.
Yeah, I don't think testing like that is sufficient; there needs to be logic that handles the case where the duration is wrong. You could temporarily initialize the total duration to the real total duration plus a random number, to test whether that logic works.
I think I've convinced myself that this logic is OK. It ensures that the duration is non-negative (via the property in DataSource), so even if it were incorrect, it would only affect the sampling probabilities. I also followed your suggestion to add random numbers (0-100 s) to the total duration, and it passes all tests and is able to iterate through a 3 million item CutSet.

BTW, using the same CutSet, I checked that if we compute the total duration (3020 h) and then subtract the durations one-by-one in a randomized order, the errors accumulate to only ~2e-7 seconds. For extra safety, I added a flag to disable the proportional sampling in case it causes any issues.
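The roundoff experiment described above can be reproduced with a small self-contained sketch (illustrative names only; this is not Lhotse's code). It subtracts per-item durations from a precomputed total in randomized order and then applies the non-negative clamp that the safeguard relies on:

```python
import random

# Simulate per-cut durations (seconds) and their precomputed total.
durations = [random.uniform(1.0, 30.0) for _ in range(100_000)]
total = sum(durations)

# Subtract durations one-by-one in a randomized order, as the sampler would.
remaining = total
random.shuffle(durations)
for d in durations:
    remaining -= d

# Float roundoff leaves only a tiny residual instead of exactly 0.0.
residual = abs(remaining)

# The safeguard: clamp so the value can never go negative and break sampling.
safe_remaining = max(0.0, remaining)
```

Even over 100k subtractions the residual stays many orders of magnitude below a single cut's duration, which is why the clamp is enough to keep the sampling probabilities sane.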
LGTM!
lhotse/dataset/sampling/bucketing.py
Outdated
        counts = [s.remaining_cuts for _, s in self._nondepleted_samplers_with_idxs]
        if any(c is None for c in counts):
            return None
        return sum(counts)
BTW you could make this more efficient with try-except.
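The try-except variant the reviewer suggests might look like the sketch below (standalone; the stand-in samplers are hypothetical). Instead of materializing a list and scanning it for `None`, `sum()` is allowed to raise `TypeError` the moment it hits a `None`, so the common all-eager case needs only a single pass:

```python
from types import SimpleNamespace
from typing import Optional


def remaining_cuts(samplers_with_idxs) -> Optional[int]:
    try:
        # Single pass over a generator; no intermediate list.
        return sum(s.remaining_cuts for _, s in samplers_with_idxs)
    except TypeError:
        # Some sub-sampler reported None (e.g. a lazy manifest).
        return None


# Tiny stand-ins for sub-samplers, just for illustration:
eager = [(0, SimpleNamespace(remaining_cuts=5)), (1, SimpleNamespace(remaining_cuts=7))]
lazy = [(0, SimpleNamespace(remaining_cuts=5)), (1, SimpleNamespace(remaining_cuts=None))]

print(remaining_cuts(eager))  # 12
print(remaining_cuts(lazy))   # None
```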
I checked the efficiency of the approximate proportional sampling: for a sampler with 7700 batches, I measured at which step the first bucket gets depleted. With 'equal_len' buckets, that step goes up from 2500 to 5100; with 'equal_duration', it goes up from 6800-7200 to 7600. It seems to be working well. @danpovey
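The idea being measured above can be illustrated with a toy sketch (not Lhotse's actual implementation): each step picks a bucket with probability proportional to its remaining duration, so larger buckets deplete later and the first depletion is pushed toward the end of the epoch. The roundoff guard at the bottom mirrors the floating-point concern discussed earlier in this thread:

```python
import random


def pick_bucket(remaining_durations):
    """Pick a bucket index with probability proportional to its remaining duration."""
    total = sum(remaining_durations)
    r = random.uniform(0.0, total)
    acc = 0.0
    for idx, d in enumerate(remaining_durations):
        acc += d
        if r <= acc:
            return idx
    # Guard: float roundoff may leave r marginally above the accumulated sum.
    return len(remaining_durations) - 1


# A bucket holding 90% of the remaining duration should be picked ~90% of the time.
random.seed(0)
counts = [0, 0]
for _ in range(10_000):
    counts[pick_bucket([90.0, 10.0])] += 1
```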
@csukuangfj I added the `num_cuts` property to the sampler that you were asking for. Take note that it may be `None` when the `CutSet` is opened as a lazy manifest.
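Since `num_cuts` may be `None`, callers should guard for the lazy case before using the value; a minimal usage sketch (the `describe` helper and the stand-in sampler are hypothetical, only the `num_cuts` property comes from the PR):

```python
from types import SimpleNamespace


def describe(sampler) -> str:
    # num_cuts is None when the underlying CutSet is a lazy manifest.
    n = sampler.num_cuts
    if n is None:
        return "num_cuts unknown (lazy CutSet)"
    return f"{n} cuts"


print(describe(SimpleNamespace(num_cuts=42)))    # "42 cuts"
print(describe(SimpleNamespace(num_cuts=None)))  # "num_cuts unknown (lazy CutSet)"
```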