From e013ba2d3ac1ebf63d10b343676bfe39b511edd9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 7 Jan 2025 15:21:38 -0500 Subject: [PATCH] fix --- fast_llm/data/dataset/gpt/config.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index f9c78d4..d475f74 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -22,7 +22,7 @@ from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset from fast_llm.data.dataset.gpt.concatenated import GPTConcatenatedDataset - from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset + from fast_llm.data.dataset.gpt.dummy import GPTDummySampledDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset @@ -102,15 +102,32 @@ def build(self, data: "GPTData") -> "GPTIndexedDataset": @config_class() -class GPTDummyDatasetConfig(GPTSamplableDatasetConfig, type_="dummy"): +class GPTDummyDatasetConfig(GPTSampledSplitDatasetConfig, type_="dummy"): + # NA -> (unsampled, unsplit) + _abstract = False name: str = Field( default="dummy", desc="The name of the dataset.", hint=FieldHint.core, ) - def build(self, data: "GPTData") -> "GPTDummyDataset": - return GPTDummyDataset(self.name, data.max_sequence_length, data.vocab_size) + def build_split_sample( + self, + data: "GPTData", + config: PhaseSplits[GPTSamplingConfig], + default_phase: PhaseType = PhaseType.training, + ) -> "SampledSplitDataset[GPTDummySampledDataset]": + from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset, GPTDummySampledDataset + + return SampledSplitDataset[GPTDummySampledDataset]( + self.name, + { + phase: GPTDummyDataset(f"{self.name}_{phase.value}", data.max_sequence_length, data.vocab_size).sample( + phase_config, data + ) + for phase, phase_config in config.items() + }, + ) @config_class() @@ -191,7 +208,7 @@ def build_split( @config_class() class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig, type_="blended"): _abstract = False - datasets: list[GPTSampledDatasetConfig] = FieldUpdate(desc="UINGBRI") + datasets: list[GPTSampledDatasetConfig] = FieldUpdate() class LegacyDatasetSource(str, enum.Enum):