Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
jlamypoirier committed Jan 7, 2025
1 parent c41a2c5 commit e013ba2
Showing 1 changed file with 22 additions and 5 deletions.
27 changes (22 additions & 5 deletions) in fast_llm/data/dataset/gpt/config.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,7 +22,7 @@
from fast_llm.data.data.gpt.data import GPTData
from fast_llm.data.dataset.gpt.abstract import GPTIndexedDataset
from fast_llm.data.dataset.gpt.concatenated import GPTConcatenatedDataset
from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset
from fast_llm.data.dataset.gpt.dummy import GPTDummySampledDataset
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset


Expand Down Expand Up @@ -102,15 +102,32 @@ def build(self, data: "GPTData") -> "GPTIndexedDataset":


@config_class()
class GPTDummyDatasetConfig(GPTSamplableDatasetConfig, type_="dummy"):
class GPTDummyDatasetConfig(GPTSampledSplitDatasetConfig, type_="dummy"):
# NA -> (unsampled, unsplit)
_abstract = False
name: str = Field(
default="dummy",
desc="The name of the dataset.",
hint=FieldHint.core,
)

def build(self, data: "GPTData") -> "GPTDummyDataset":
return GPTDummyDataset(self.name, data.max_sequence_length, data.vocab_size)
def build_split_sample(
self,
data: "GPTData",
config: PhaseSplits[GPTSamplingConfig],
default_phase: PhaseType = PhaseType.training,
) -> "SampledSplitDataset[GPTDummySampledDataset]":
from fast_llm.data.dataset.gpt.dummy import GPTDummyDataset, GPTDummySampledDataset

return SampledSplitDataset[GPTDummySampledDataset](
self.name,
{
phase: GPTDummyDataset(f"{self.name}_{phase.value}", data.max_sequence_length, data.vocab_size).sample(
phase_config, data
)
for phase, phase_config in config.items()
},
)


@config_class()
Expand Down Expand Up @@ -191,7 +208,7 @@ def build_split(
@config_class()
class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig, type_="blended"):
_abstract = False
datasets: list[GPTSampledDatasetConfig] = FieldUpdate(desc="UINGBRI")
datasets: list[GPTSampledDatasetConfig] = FieldUpdate()


class LegacyDatasetSource(str, enum.Enum):
Expand Down

0 comments on commit e013ba2

Please sign in to comment.