From 187160c18256d0ac6e1161577f1f932746268683 Mon Sep 17 00:00:00 2001 From: Elad Date: Sun, 21 Nov 2021 00:57:46 +0200 Subject: [PATCH] batch sampling: cut sampled sampling-rate to (0,1] --- ndfa/misc/tensors_data_class/batch_flattened_seq.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ndfa/misc/tensors_data_class/batch_flattened_seq.py b/ndfa/misc/tensors_data_class/batch_flattened_seq.py index 34a1da3..96b6265 100644 --- a/ndfa/misc/tensors_data_class/batch_flattened_seq.py +++ b/ndfa/misc/tensors_data_class/batch_flattened_seq.py @@ -69,16 +69,18 @@ def _collate_first_pass(cls, inputs: List['BatchFlattenedSequencesDataClassMixin for sequences_field in sequences_fields: tensors = getattr(inp, sequences_field.name) nr_tensors = tensors.size(0) if isinstance(tensors, torch.Tensor) else len(tensors) + if nr_tensors < 1: + continue if sequences_per_example_sampling.min_nr_items_to_sample_by_rate is not None and \ nr_tensors < sequences_per_example_sampling.min_nr_items_to_sample_by_rate: continue random_state = np.random.RandomState(random_seed_per_example[example_idx]) nr_sequences_to_sample_per_example = nr_tensors if sequences_per_example_sampling.distribution_for_rate_to_sample_by is not None: - sampling_rate = _sample_by_distribution_params( + sampling_rate = max(1 / nr_tensors, min(1, _sample_by_distribution_params( params=sequences_per_example_sampling.distribution_for_rate_to_sample_by, - rng=random_state) - nr_sequences_to_sample_per_example = round(sampling_rate * nr_tensors) + rng=random_state))) + nr_sequences_to_sample_per_example = max(1, min(nr_tensors, round(sampling_rate * nr_tensors))) if sequences_per_example_sampling.max_nr_items is not None: nr_sequences_to_sample_per_example = \ min(nr_sequences_to_sample_per_example, sequences_per_example_sampling.max_nr_items)