Review notes (free text ahead of the first "diff --git" header).

Patch to src/axolotl/datasets.py. It rewrites __iter__: instead of
tokenizing examples one at a time inside a generator, the method returns an
iterator over self.dataset.map(self.prompt_tokenizer.tokenize_prompt, ...),
passing num_proc=os.cpu_count() and removing the input feature columns.
The InvalidDataException import is dropped together with the try/except
that used it, and "import os" is added for the CPU count.

NOTE(review): the removed loop skipped examples that raised
InvalidDataException and raised RuntimeError when nothing was yielded; the
added code contains neither behaviour. Confirm this is intended.
NOTE(review): os.cpu_count() can return None; confirm map() accepts
num_proc=None on the datasets version in use.

diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 911df8f50..fd82db6cb 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -1,12 +1,13 @@
 """Module containing Dataset functionality"""
 
 import logging
+import os
 from typing import List
 
 import torch
 from datasets import IterableDataset
 
-from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
+from .prompt_tokenizers import PromptTokenizingStrategy
 
 # We want this to be a wrapper for an existing dataset that we have loaded
 # lets use the concept of middlewares to wrap each dataset, for example
@@ -34,17 +35,15 @@ def __init__( # pylint: disable=super-init-not-called
         self.dataset = dataset
 
     def __iter__(self):
-        iterator = iter(self.dataset)
-        count = 0
-        # Loop through the entire dataset
-        for example in iterator:
-            try:
-                yield self.prompt_tokenizer.tokenize_prompt(example)
-                count += 1
-            except InvalidDataException:
-                pass
-        if count == 0:
-            raise RuntimeError("Expected at least one datapoint in dataset.")
+        features = self.dataset.features.keys()
+        num_proc = os.cpu_count()
+        return iter(
+            self.dataset.map(
+                self.prompt_tokenizer.tokenize_prompt,
+                num_proc=num_proc,
+                remove_columns=features,
+            )
+        )
 
 
 # TODO this isn't the best since it can't interleave datasets