
Commit

fixes
jlamypoirier committed Jan 7, 2025
1 parent 147e33b commit c41a2c5
Showing 8 changed files with 119 additions and 56 deletions.
10 changes: 5 additions & 5 deletions examples/mistral-4-node-benchmark.yaml
@@ -11,8 +11,8 @@ batch:
micro_batch_size: 1
batch_size: 32
data:
format: random
split: [1, 0, 0]
dataset:
type: dummy
optimizer:
learning_rate:
base: 1.0e-05
@@ -27,18 +27,18 @@ model:
normalization:
type: rms_norm
epsilon: 1.0e-05
rotary:
type: default
theta: 10000
num_layers: 32
hidden_size: 4096
ffn_hidden_size: 14336
num_attention_heads: 32
head_groups: 8
add_linear_biases: false
use_rotary_embeddings: true
gated: true
activation_type: silu
triton_rotary: true
kv_channels: 128
rotary_embedding_scale: -9.210340371976184
window_size: 4096
init_method_std: 0.009021
attention_dropout: 0.0
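
Two things change in this example config: the `data` section now uses the nested `dataset:` block (`type: dummy`) instead of the legacy flat `format`/`split` keys, and the flat rotary flags (`use_rotary_embeddings`, `triton_rotary`, `rotary_embedding_scale`) are replaced by a nested `rotary:` block. The removed scale of -9.210340371976184 appears to be -ln(theta), so `theta: 10000` keeps the same rotary frequencies; a quick check of that arithmetic (the exact relationship is inferred from the values, not stated in the diff):

```python
import math

# Old config: rotary_embedding_scale = -9.210340371976184 (a log-space value).
# New config: rotary.theta = 10000.
old_scale = -9.210340371976184
theta = 10000

print(math.isclose(math.exp(-old_scale), theta))  # True: exp(9.2103...) == 10000
print(-math.log(theta))                           # -9.210340371976184
```
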
33 changes: 21 additions & 12 deletions fast_llm/config.py
@@ -287,7 +287,6 @@ def __post_init__(self):
In general this should not be overridden in derived classes,
and all post-processing should be done in `_validate`
"""
self._check_abstract()
self._validated = False
if _AUTO_VALIDATE:
self.validate()
@@ -343,6 +342,7 @@ def _validate(self):
Can be extended to add custom post-processing (typically before the super() call)
and validation (typically after)
"""
self._check_abstract()
errors = []
for name, field in self.fields():
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
@@ -604,17 +604,12 @@ def _add_field_to_args(
else:
field_value = value
if serializable:
if hasattr(value, "__fast_llm_serialize__"):
field_value = field_value.__fast_llm_serialize__()
if isinstance(value, enum.Enum):
field_value = field_value.value
# Tag is not actually serializable, but needs to be kept as-is for config processing,
# and should be absent for valid configs.
elif not isinstance(value, int | float | bool | str | Tag | None):
field_value = str(field_value)
field_value = cls._serialize_value(value)
if format_ == _ConfigDictFormat.tuple:
field_value = {(): field_value}

if serializable:
name = cls._serialize_value(name)
if format_ == _ConfigDictFormat.tuple:
args.update({(name,) + name_: value_ for name_, value_ in field_value.items()})
elif format_ == _ConfigDictFormat.nested:
@@ -626,6 +621,19 @@
else:
raise NotImplementedError(format_)

@classmethod
def _serialize_value(cls, value):
value = value
if hasattr(value, "__fast_llm_serialize__"):
value = value.__fast_llm_serialize__()
if isinstance(value, enum.Enum):
value = value.value
# Tag is not actually serializable, but needs to be kept as-is for config processing,
# and should be absent for valid configs.
elif not isinstance(value, int | float | bool | str | Tag | None):
value = str(value)
return value

def to_copy(
self,
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
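
For context, `_serialize_value` factors out the rules that were previously inlined in `_add_field_to_args`, and the same helper is now applied to field names as well as values when serializing. A standalone sketch of the precedence (a `__fast_llm_serialize__` hook first, then enum `.value`, primitives pass through, anything else falls back to `str()`); the `Tag` class below is a placeholder, not the real fast_llm sentinel:

```python
import enum


class Tag:
    # Placeholder for fast_llm's Tag sentinel, which passes through unchanged.
    pass


def serialize_value(value):
    # Mirrors Config._serialize_value: hook > enum > primitive passthrough > str().
    if hasattr(value, "__fast_llm_serialize__"):
        value = value.__fast_llm_serialize__()
    if isinstance(value, enum.Enum):
        value = value.value
    elif not isinstance(value, int | float | bool | str | Tag | None):
        value = str(value)
    return value


class Color(enum.Enum):
    RED = "red"


print(serialize_value(Color.RED))  # "red"
print(serialize_value(3.5))        # 3.5, unchanged
print(serialize_value([1, 2]))     # "[1, 2]", stringified fallback
```
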
@@ -690,7 +698,6 @@ def _from_dict(
strict: bool = True,
flat: bool = False,
):
cls._check_abstract()
# TODO v0.3: Remove flat format
out_arg_dict = {}

@@ -841,9 +848,11 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
@classmethod
def _check_abstract(cls):
if cls._abstract:
raise RuntimeError(f"{cls.__name__} is abstract")
raise ValidationError(f"{cls.__name__} is abstract")
if not cls.__class_validated__:
raise RuntimeError(f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator.")
raise ValidationError(
f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator."
)

def __init_subclass__(cls):
"""
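
Taken together, the config.py changes move the abstract-class check out of `__post_init__` and `_from_dict` and into `_validate`, and switch it from `RuntimeError` to `ValidationError`, so an abstract or undecorated config class can be constructed and only fails once validation runs. A toy sketch of the resulting behavior, not the real `Config` class (the real `ValidationError` may not derive from `ValueError`):

```python
class ValidationError(ValueError):
    # Stand-in for fast_llm's ValidationError.
    pass


class ConfigSketch:
    """Toy config: abstractness is checked during validation, not construction."""

    _abstract = True
    __class_validated__ = True

    def __init__(self):
        # Before this commit the check ran here (via __post_init__), so abstract
        # configs could not even be instantiated.
        self._validated = False

    def validate(self):
        self._check_abstract()
        self._validated = True

    @classmethod
    def _check_abstract(cls):
        if cls._abstract:
            raise ValidationError(f"{cls.__name__} is abstract")
        if not cls.__class_validated__:
            raise ValidationError(
                f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator."
            )


cfg = ConfigSketch()  # construction succeeds now
try:
    cfg.validate()    # the check fires here instead
except ValidationError as error:
    print(error)      # "ConfigSketch is abstract"
```
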
9 changes: 5 additions & 4 deletions fast_llm/data/data/gpt/config.py
@@ -26,7 +26,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
hint=FieldHint.feature,
)
dataset: GPTSampledSplitDatasetConfig = Field(
default=None,
default_factory=GPTSampledSplitDatasetConfig,
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
)
@@ -47,11 +47,12 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
hint=FieldHint.expert,
)

def __post_init__(self):
if self.dataset is None:
def _validate(self):
if self.dataset.type is None:
logger.warning("Using the legacy dataset definition format." " Specify it through `data.dataset` instead.")
self.dataset = GPTLegacyDatasetConfig(
split=self.split,
ratio=self.ratio,
format=self.format,
path=self.path,
)
super()._validate()
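
The GPTDataConfig changes replace the `None` default with a `default_factory`, and move the legacy-format fallback from `__post_init__` into `_validate`, keyed on `dataset.type` being unset. A simplified dataclass sketch of that flow (field names follow the diff, but the real config classes carry more fields, including `ratio` and `path`):

```python
import dataclasses
import logging

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class DatasetSpec:
    # Stand-in for GPTSampledSplitDatasetConfig: `type` stays None until set explicitly.
    type: str | None = None


@dataclasses.dataclass
class LegacyDatasetSpec(DatasetSpec):
    # Stand-in for GPTLegacyDatasetConfig, assembled from the old flat fields.
    format: str = "random"
    split: tuple[float, ...] = (1.0, 0.0, 0.0)


@dataclasses.dataclass
class DataConfigSketch:
    # default_factory gives every config its own dataset object instead of None,
    # so validation can look at dataset.type directly.
    dataset: DatasetSpec = dataclasses.field(default_factory=DatasetSpec)
    format: str = "random"
    split: tuple[float, ...] = (1.0, 0.0, 0.0)

    def validate(self) -> None:
        # Mirrors the new _validate(): fall back to the legacy flat fields only
        # when no dataset type was given, and warn about the deprecated format.
        if self.dataset.type is None:
            logger.warning(
                "Using the legacy dataset definition format. Specify it through `data.dataset` instead."
            )
            self.dataset = LegacyDatasetSpec(type="legacy", format=self.format, split=self.split)


cfg = DataConfigSketch()
cfg.validate()
print(cfg.dataset)  # LegacyDatasetSpec(type='legacy', format='random', split=(1.0, 0.0, 0.0))
```
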
1 change: 1 addition & 0 deletions fast_llm/data/data/gpt/data.py
@@ -86,6 +86,7 @@ def setup(self, distributed: Distributed, samples_per_phase: PhaseSplits[int]):
}
)
self._datasets = self._config.dataset.build_split_sample(self, sampling_config)
self._is_setup = True

@property
def tokenizer(self):
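
The one-line change to data.py sets an `_is_setup` flag at the end of `setup()`. Presumably this lets accessors such as the `tokenizer` property guard against use before setup; a minimal sketch of that pattern (the guard itself is an assumption, not shown in this diff):

```python
class DataSketch:
    def __init__(self):
        self._is_setup = False
        self._tokenizer = None

    def setup(self, tokenizer):
        # ... build datasets ..., then mark the instance as ready.
        self._tokenizer = tokenizer
        self._is_setup = True

    @property
    def tokenizer(self):
        # Hypothetical guard: refuse to serve dependent state before setup() ran.
        assert self._is_setup, "call setup() before accessing the tokenizer"
        return self._tokenizer


data = DataSketch()
data.setup(tokenizer="dummy-tokenizer")
print(data.tokenizer)
```
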
27 changes: 16 additions & 11 deletions fast_llm/data/dataset/config.py
@@ -1,27 +1,26 @@
import functools
import math
import typing

from fast_llm.config import Config, Field, FieldHint, config_class
from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.config import SamplingConfig
from fast_llm.data.dataset.abstract import (
PhaseSplits,
SamplableDataset,
SamplableSplitDataset,
SampledDataset,
SampledSplitDataset,
)
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.abstract import (
PhaseSplits,
SamplableDataset,
SamplableSplitDataset,
SampledDataset,
SampledSplitDataset,
)


@config_class()
class DatasetConfig(Config):
_abstract = True


@config_class()
class SampledSplitDatasetConfig(DatasetConfig):

def build_split_sample(
@@ -43,6 +42,7 @@ def split(self):
return True


@config_class()
class SampledDatasetConfig(SampledSplitDatasetConfig):
"""
A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training.
@@ -70,6 +70,7 @@ def split(self):
return False


@config_class()
class SamplableSplitDatasetConfig(SampledSplitDatasetConfig):

def build_split(
@@ -102,6 +103,7 @@ def split(self):
return True


@config_class()
class SamplableDatasetConfig(SampledDatasetConfig, SamplableSplitDatasetConfig):
def build(self, data: Data) -> SamplableDataset:
raise NotImplementedError()
@@ -135,10 +137,13 @@ class BlendedDatasetConfig(SampledDatasetConfig):
hint=FieldHint.core,
)
datasets: list[SampledDatasetConfig] = Field(
default_factory=list,
desc="The datasets to blend.",
hint=FieldHint.core,
valid=check_field(functools.partial(Assert.custom, lambda x: len(x) > 0)),
)
weights: list[float] = Field(
default_factory=list,
desc="The blending weight of each dataset.",
hint=FieldHint.core,
)
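
In dataset/config.py, the dataset config classes gain explicit `@config_class()` decorators (which the stricter `_check_abstract` now effectively requires), and `BlendedDatasetConfig` gets `default_factory=list` defaults plus a validator that rejects an empty `datasets` list. A toy sketch of what the `check_field(functools.partial(Assert.custom, lambda x: len(x) > 0))` expression amounts to; the real `check_field` and `Assert.custom` signatures in fast_llm may differ:

```python
import functools


class Assert:
    # Stand-in for fast_llm.utils.Assert; only the piece used by this validator.
    @staticmethod
    def custom(fn, value):
        assert fn(value), f"custom assertion failed for {value!r}"
        return value


def check_field(fn):
    # Stand-in for fast_llm.config.check_field: wraps a checker for use as a field validator.
    def check(value):
        fn(value)
        return value
    return check


non_empty = check_field(functools.partial(Assert.custom, lambda x: len(x) > 0))

print(non_empty([0.7, 0.3]))  # passes through
try:
    non_empty([])             # an empty blend is now rejected at validation time
except AssertionError as error:
    print("rejected:", error)
```
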
(Diffs for the remaining 3 changed files were not loaded on this page.)
