Skip to content

Commit

Permalink
batch sampling: cut sampled sampling-rate to (0,1]
Browse files Browse the repository at this point in the history
  • Loading branch information
eladn committed Nov 20, 2021
1 parent 0422d84 commit 187160c
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions ndfa/misc/tensors_data_class/batch_flattened_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,18 @@ def _collate_first_pass(cls, inputs: List['BatchFlattenedSequencesDataClassMixin
for sequences_field in sequences_fields:
tensors = getattr(inp, sequences_field.name)
nr_tensors = tensors.size(0) if isinstance(tensors, torch.Tensor) else len(tensors)
if nr_tensors < 1:
continue
if sequences_per_example_sampling.min_nr_items_to_sample_by_rate is not None and \
nr_tensors < sequences_per_example_sampling.min_nr_items_to_sample_by_rate:
continue
random_state = np.random.RandomState(random_seed_per_example[example_idx])
nr_sequences_to_sample_per_example = nr_tensors
if sequences_per_example_sampling.distribution_for_rate_to_sample_by is not None:
sampling_rate = _sample_by_distribution_params(
sampling_rate = max(1 / nr_tensors, min(1, _sample_by_distribution_params(
params=sequences_per_example_sampling.distribution_for_rate_to_sample_by,
rng=random_state)
nr_sequences_to_sample_per_example = round(sampling_rate * nr_tensors)
rng=random_state)))
nr_sequences_to_sample_per_example = max(1, min(nr_tensors, round(sampling_rate * nr_tensors)))
if sequences_per_example_sampling.max_nr_items is not None:
nr_sequences_to_sample_per_example = \
min(nr_sequences_to_sample_per_example, sequences_per_example_sampling.max_nr_items)
Expand Down

0 comments on commit 187160c

Please sign in to comment.