Skip to content

Commit

Permalink
formatted iterator for filter
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hh committed Oct 8, 2024
1 parent 4a761a9 commit 3b65d99
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,46 +1245,48 @@ def _iter(self):
num_examples_to_skip = 0
iterator = iter(self.ex_iterable)

if self.formatting:
formatter = get_formatter(self.formatting.format_type)
format_dict = (
formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else cast_to_python_objects
formatter = get_formatter(self.formatting.format_type) if self.formatting else None
if self.formatting and self.ex_iterable.iter_arrow:
# we still want to use an arrow iterator, yielding single batches of size self.batch_size
# to which the formatter can be applied
ex_iterable = RebatchedArrowExamplesIterable(
self.ex_iterable, batch_size=self.batch_size if self.batched else 1, drop_last_batch=False
)
batched_examples_iterator = formatted_arrow_examples_iterator(ex_iterable, formatter, batched=self.batched)

else:
format_dict = None
batched_examples_iterator = formatted_python_examples_iterator(
self.ex_iterable, batch_size=self.batch_size, formatter=formatter, batched=self.batched
)

if self.batched:
if self._state_dict:
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
self._state_dict["num_examples_since_previous_state"] = 0
self._state_dict["previous_state_example_idx"] = current_idx
for key, example in iterator:
# If `batched`, first build the batch; if `batch_size` is None or <= 0, the batch is the whole dataset
iterator_batch = (
iterator
if self.batch_size is None or self.batch_size <= 0
else islice(iterator, self.batch_size - 1)
)
key_examples_list = [(key, example)] + list(iterator_batch)
keys, examples = zip(*key_examples_list)
batch = _examples_to_batch(examples)
batch = format_dict(batch) if format_dict else batch
for combined_key, batch in batched_examples_iterator:
if batch:
batch_len = len(batch[next(iter(batch))])
else:
batch_len = 0
# then compute the mask for the batch
inputs = batch
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
if self.with_indices:
function_args.append([current_idx + i for i in range(len(key_examples_list))])
function_args.append([current_idx + i for i in range(batch_len)])
mask = self.function(*function_args, **self.fn_kwargs)
# yield one example at a time from the batch
for key_example, to_keep in zip(key_examples_list, mask):
example_keys = combined_key.split("_")
examples = _batch_to_examples(batch)
for key, example, to_keep in zip(example_keys, examples, mask):
current_idx += 1
if self._state_dict:
self._state_dict["num_examples_since_previous_state"] += 1
if num_examples_to_skip > 0:
num_examples_to_skip -= 1
continue
if to_keep:
yield key_example
yield key, example
if self._state_dict:
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
self._state_dict["num_examples_since_previous_state"] = 0
Expand All @@ -1293,7 +1295,7 @@ def _iter(self):
for key, example in iterator:
# If not batched, we can apply the filtering function directly
example = dict(example)
inputs = format_dict(example) if format_dict else example
inputs = example
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
if self.with_indices:
function_args.append(current_idx)
Expand Down

0 comments on commit 3b65d99

Please sign in to comment.