diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index e1591f2a166e..bae1a4dd20ac 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1276,9 +1276,13 @@ def _iter(self): function_args.append([current_idx + i for i in range(batch_len)]) mask = self.function(*function_args, **self.fn_kwargs) # yield one example at a time from the batch - example_keys = combined_key.split("_") examples = _batch_to_examples(batch) - for key, example, to_keep in zip(example_keys, examples, mask): + # TODO: nicer way to handle keys? + if not self.formatting: + keys = combined_key.split("_") + else: + keys = [combined_key] * len(mask) + for key, example, to_keep in zip(keys, examples, mask): current_idx += 1 if self._state_dict: self._state_dict["num_examples_since_previous_state"] += 1