6 changes: 3 additions & 3 deletions examples/mistral.yaml
@@ -27,7 +27,7 @@ optimizer:
beta_2: 0.95
model:
base_model:
embeddings_layer:
embeddings:
hidden_size: 4096
vocab_size: 32000
dropout: 0.0
@@ -54,11 +54,11 @@ model:
epsilon: 1.0e-05
dropout: 0.0
num_blocks: 32
output_layer:
tied_weight: false
head:
normalization:
type: rms_norm
epsilon: 1.0e-05
tied_embedding_weight: false
multi_stage:
zero_stage: 2
distributed:
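Note: the renames above (`embeddings_layer` → `embeddings`, `output_layer` → `head`, `tied_weight` → `tied_embedding_weight` inside the head section) break older YAML configs. A minimal migration sketch, assuming the keys sit under `model.base_model` as in this example; the helper name and approach are illustrative only, not part of the repository:

def migrate_base_model_section(base_model: dict) -> dict:
    # Hypothetical helper: renames the pre-refactor keys found under `model.base_model`.
    if "embeddings_layer" in base_model:
        base_model["embeddings"] = base_model.pop("embeddings_layer")
    if "output_layer" in base_model:
        head = base_model.pop("output_layer")
        # `tied_weight` becomes `tied_embedding_weight` inside the new `head` section.
        if "tied_weight" in head:
            head["tied_embedding_weight"] = head.pop("tied_weight")
        base_model["head"] = head
    return base_model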
8 changes: 4 additions & 4 deletions fast_llm/core/ops.py
@@ -26,7 +26,7 @@ def reduce_op(
return (input_, handle) if async_op else input_


def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]:
def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor:
"""Split the tensor along its last dimension and keep the
corresponding slice."""
if group:
@@ -139,11 +139,11 @@ class _Split(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""

@staticmethod
def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa
def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa
return split_op(input_, group, dim)

@staticmethod
def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa
def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa
ctx.group = group
ctx.dim = dim
return split_op(input_, group, dim)
@@ -209,7 +209,7 @@ def reduce_backward(input_: torch.Tensor, group: ProcessGroup | None) -> torch.T


@torch._dynamo.disable # noqa
def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]:
def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor:
return _Split.apply(input_, group, dim)


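Note: the annotation change reflects that `split_op` (and the `split` autograd wrapper) return a single tensor, the slice owned by the local rank, not a list. A minimal sketch of the expected semantics, assuming `group.size()` and `group.rank()` follow the usual `torch.distributed.ProcessGroup` API; this is not the actual implementation:

import torch

def split_op_sketch(input_: torch.Tensor, group, dim: int) -> torch.Tensor:
    # No group (or a single rank): the tensor is returned unchanged.
    if group is None or group.size() == 1:
        return input_
    # Otherwise, chunk along `dim` and keep only this rank's slice.
    return input_.chunk(group.size(), dim=dim)[group.rank()].contiguous()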
134 changes: 86 additions & 48 deletions fast_llm/engine/base_model/base_model.py
@@ -1,23 +1,19 @@
import abc
import typing

import torch
import torch.nn

from fast_llm.config import Configurable
from fast_llm.engine.base_model.config import BaseModelConfig, ResourceUsageConfig
from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, ResourceUsageConfig
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.tensor import ParameterMeta, TensorMeta
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.inference.runner import InferenceRunner


class Module(torch.nn.Module, abc.ABC):
""" """

class LayerBase(torch.nn.Module, abc.ABC):
_is_setup: bool = False
_distributed: Distributed

@@ -27,85 +27,121 @@ def __init__(self, distributed_config: DistributedConfig):

def setup(self, distributed: Distributed) -> None:
assert not self._is_setup
for layer in self.get_layers():
if layer is not self:
layer.setup(distributed)
distributed.check_config(self._distributed_config)
self._distributed = distributed
self._is_setup = True

@abc.abstractmethod
def get_layers(self) -> list["Layer"]:
"""
The list of layers as seen by the Fast-LLM engine.
May differ from the module structure seen by PyTorch.
"""

class Layer(Module):
# Weight used to determine the stage size
def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
out = 0
for layer in self.get_layers():
if layer is self:
raise NotImplementedError()
out += layer.get_compute_usage(input_, kwargs, config)
return out

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
losses = []
for layer in self.get_layers():
if layer is not self:
losses += layer.get_loss_definitions(count)
return losses

def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
for layer in self.get_layers():
if layer is not self:
layer.preprocess(batch, kwargs)


class Layer(LayerBase):
# Weight used to determine the stage size.
layer_count: float = 1.0

def get_layers(self) -> list["Layer"]:
# Return a breakdown of the layer into atomic ones,
# i.e. the list of layers as seen from the Fast-LLM model.
return [self]

@abc.abstractmethod
def forward(
self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
) -> torch.Tensor:
pass

def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
raise NotImplementedError()

def unwrap(self) -> "Layer":
# Get the actual module contained in this layer,
# undoing any wrapping for the Fast-LLM engine (ex. `LayerWithNamespace`)
return self

class Sequential(Layer):
def __init__(self, distributed_config: DistributedConfig):
super().__init__(distributed_config)
self.layers = torch.nn.ModuleList(self.get_layers())

def __getitem__(self, item):
return self.layers[item]
class LayerWithNamespace(Layer):
"""
A layer with its own namespace for preprocessing (kwargs),
so that it doesn't inadvertently interact with other layers.
TODO: Consider namespace for losses and metrics?
"""

def __iter__(self):
return iter(self.layers)
def __init__(self, layer: Layer, namespace: str = None):
super().__init__(layer._distributed_config)
self._layer = layer
self._namespace = namespace
self.layer_count = self._layer.layer_count
self.get_compute_usage = self._layer.get_compute_usage
self.module_name = self._layer.module_name

def __len__(self):
return len(self.layers)
def setup(self, distributed: Distributed) -> None:
self._layer.setup(distributed)
super().setup(distributed)

def forward(
self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None
) -> torch.Tensor:
for layer in self.layers:
input_ = layer(input_, kwargs, losses, metrics)
return input_
if self._namespace in kwargs:
kwargs = kwargs[self._namespace]
else:
# TODO: Forward meta doesn't go through preprocessing so doesn't have a namespace.
# Using kwargs as-is since it's generally unused.
assert isinstance(input_, TensorMeta)
return self._layer.forward(input_, kwargs, losses, metrics)

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass
def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
assert self._namespace not in kwargs
kwargs[self._namespace] = kwargs.copy()
self._layer.preprocess(batch, kwargs[self._namespace])

def setup(self, distributed: Distributed) -> None:
super().setup(distributed)
for layer in self.layers:
layer.setup(distributed)
def unwrap(self) -> "Layer":
return self._layer.unwrap()


class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential):
class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase):

def __init__(
self,
config: BaseModelConfig,
distributed_config: DistributedConfig,
):
super().__init__(config, distributed_config)
for key, value in self.named_modules():
value.module_name = key
for key, value in self.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
# Rename to the parameter full name
value.tensor_name = key

# Reference models
# TODO: Add basic handling (preprocessor) in this class.
self._reference_models: dict[str, "InferenceRunner"] = {}

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass

@abc.abstractmethod
def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]:
# TODO Remove (Move batch splitting elsewhere)
pass

@abc.abstractmethod
def preprocess(
def preprocess_batch(
self,
batch: typing.Any,
preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None,
@@ -114,13 +146,19 @@ def preprocess(
iteration: int,
metrics: dict | None = None,
) -> list[tuple[torch.Tensor, dict]]:
# TODO Move batch splitting elsewhere, align interface with LayerBase
pass

def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
# For each tied weight, return the weight and the tuple of layers sharing it.
# The weight should be defined in the first layer in the set.
# Warning: This may return buffers instead of metas after stage setup.
# The name (dict key) is used to insert the weight in the kwargs of the forward pass.
def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]:
"""
Return lists of independently defined metas to tie together.
Metas should be compatible, i.e. have the same tensor dimensions.
Tied weights are named (dict keys) for convenience only.
Warning: Initialization and optimization properties are defined on the first appearance of the tied weight.
To prevent any confusion, the metas should be provided in the same order they appear in the model.
TODO: Improve?
Note: This may return buffers instead of metas after stage setup.
"""
return {}

def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None:
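Note: `LayerWithNamespace.preprocess` stores a copy of the shared kwargs under its namespace key, and `forward` looks that sub-dict up again, so a wrapped layer only reads and writes its own view. A simplified standalone sketch of that behaviour (a toy stand-in, not the real class; the real wrapper also forwards `layer_count`, `get_compute_usage` and `module_name` from the wrapped layer):

class NamespacedKwargsSketch:
    """Toy illustration of the kwargs-namespacing idea used by LayerWithNamespace."""

    def __init__(self, namespace: str):
        self._namespace = namespace

    def preprocess(self, kwargs: dict) -> None:
        # The wrapped layer writes into a private copy of the shared kwargs.
        assert self._namespace not in kwargs
        kwargs[self._namespace] = kwargs.copy()
        kwargs[self._namespace]["layer_specific_entry"] = 123

    def forward_kwargs(self, kwargs: dict) -> dict:
        # At forward time only the private view is handed to the wrapped layer.
        return kwargs[self._namespace]

shared = {"sequence_length": 4096}
ns = NamespacedKwargsSketch("block_0")
ns.preprocess(shared)
assert ns.forward_kwargs(shared)["layer_specific_entry"] == 123
assert "layer_specific_entry" not in shared  # other layers never see it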
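Note: `get_tied_parameters` replaces `get_tied_weights`; instead of a weight plus a tuple of layer indices, it now returns, per tied-weight name, the list of independently defined metas to merge. A hedged sketch of an override, where the attribute and config names (`embeddings.weight`, `head.output_weight`, `head.tied_embedding_weight`) are assumptions for illustration and not taken from this diff:

def get_tied_parameters(self) -> dict[str, list["ParameterMeta"]]:
    # The listed metas must share tensor dimensions; the first one defines
    # initialization and optimization properties for the tied weight.
    if not self._config.head.tied_embedding_weight:
        return {}
    return {"word_embeddings_weight": [self.embeddings.weight, self.head.output_weight]}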
34 changes: 26 additions & 8 deletions fast_llm/engine/base_model/config.py
@@ -4,14 +4,15 @@

from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import compare_nested, log
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.utils import Assert, compare_nested, log

if typing.TYPE_CHECKING:
import torch
from fast_llm.engine.base_model.base_model import BaseModel


@config_class()
class BaseModelConfig(Config):
class ModuleConfig(Config):
"""
Abstract config class for a base model.
# TODO: Find better name?
@@ -43,7 +44,7 @@ def _get_architecture(self) -> dict[str, typing.Any]:
return architecture

def _serialize_architecture_field(self, value: typing.Any) -> typing.Any:
if isinstance(value, BaseModelConfig):
if isinstance(value, ModuleConfig):
# TODO: Make sure all nested configs have an architecture type hint?
return value._get_architecture()
elif isinstance(value, Config):
@@ -57,12 +58,29 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any:
return self._serialize_value(value)


class Preprocessor(abc.ABC):
def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
pass
@config_class()
class BaseModelConfig(ModuleConfig):
"""
Abstract config class for a base model.
"""

def get_base_model(self, distributed_config: DistributedConfig) -> "BaseModel":
from fast_llm.tensor import ParameterMeta

model = self.base_model_class(self, distributed_config)
# Storing the global name of each module and tensor.
# Done here because it needs to run right after `model.__init__()`
for key, value in model.named_modules():
value.module_name = key
for key, value in model.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
# Rename to the parameter's full name
value.tensor_name = key
return model

@property
@abc.abstractmethod
def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
def base_model_class(self) -> type["BaseModel"]:
pass


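Note: model construction, plus the `module_name` / `tensor_name` assignment that used to live in `BaseModel.__init__`, now goes through `BaseModelConfig.get_base_model`, and subclasses only need to provide `base_model_class`. A hedged sketch of a concrete config; `MyModelConfig`, `MyBaseModel`, `my_package` and `my_model_config` are placeholder names, not part of this diff:

@config_class()
class MyModelConfig(BaseModelConfig):
    @property
    def base_model_class(self) -> type["BaseModel"]:
        from my_package.model import MyBaseModel  # hypothetical import

        return MyBaseModel

# get_base_model() instantiates the model, then assigns `module_name` /
# `tensor_name` to every module and parameter right after __init__.
model = my_model_config.get_base_model(distributed_config)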
12 changes: 7 additions & 5 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -31,14 +31,18 @@ def export_config(cls, config: BaseModelConfig) -> dict:

@classmethod
@abc.abstractmethod
def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]:
def get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]:
pass


class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC):
architecture: typing.ClassVar[str]
base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]]

def __init__(self, model: "FastLLMModel"):
self._exported_config = self._export_config(model.config)
super().__init__(model)

@classmethod
@abc.abstractmethod
def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]:
@@ -126,10 +130,8 @@ def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
Assert.eq(config["architecture"], cls.architecture)
return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)})

def _create_weight_converters(
self,
) -> list[WeightConverter]:
return self.base_model_converter_class.get_converters(self._model.config.base_model)
def _create_weight_converters(self) -> list[WeightConverter]:
return self.base_model_converter_class.get_converters(self._model.config.base_model, self._exported_config)

def _load_weights(
self, config: CheckpointLoadConfig, device
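Note: `get_converters` now also receives the exported (Hugging Face side) config dict, which the handler pre-computes in `__init__` and passes along in `_create_weight_converters`. A hedged sketch of a converter using it; the class name, the `tie_word_embeddings` key, and the `WeightConverter(...)` arguments are assumed for illustration and not taken from this diff:

class MyBaseModelConverter(HuggingFaceBaseModelConverter):
    @classmethod
    def get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]:
        converters: list[WeightConverter] = []
        # Example: only convert a separate output head when the exported config
        # says the embeddings are untied.
        if not exported_config.get("tie_word_embeddings", False):
            converters.append(WeightConverter("head.output_weight", "lm_head.weight"))
        return converters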
2 changes: 1 addition & 1 deletion fast_llm/engine/config_utils/data_type.py
@@ -9,7 +9,7 @@
from triton import language as tl


class DataType(str, enum.Enum):
class DataType(enum.StrEnum):
"""
An enum to represent data types independently of third party libraries,
so we can swap them more easily and allow for lazy imports.
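Note: switching from the `(str, enum.Enum)` mixin to `enum.StrEnum` (Python 3.11+) changes how members stringify, which matters wherever data-type names end up serialized. A small sketch, with `float32` as an assumed member name:

import enum

class AsMixin(str, enum.Enum):
    float32 = "float32"

class AsStrEnum(enum.StrEnum):  # requires Python 3.11+
    float32 = "float32"

print(str(AsMixin.float32))    # 'AsMixin.float32'
print(str(AsStrEnum.float32))  # 'float32'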
2 changes: 1 addition & 1 deletion fast_llm/engine/config_utils/run.py
@@ -136,7 +136,7 @@ def __init__(
self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0
)
config_dict = config.to_dict()
config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance)
config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug)

if self._config.experiment_dir is not None:
self._experiment_directory = self._config.experiment_dir.resolve()
2 changes: 1 addition & 1 deletion fast_llm/engine/evaluation/evaluator.py
@@ -116,7 +116,7 @@ def setup(
phase=PhaseType.validation,
)

self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions()
self._loss_defs = self._multi_stage.base_model.get_loss_definitions()
self._evaluation_iterator = None
self._is_setup = True
