diff --git a/Megatron-LM b/Megatron-LM index 30e7aeccd..dee27459d 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 30e7aeccd87ec22e424f35c6e61f05ceb878a8df +Subproject commit dee27459d46fecc513be76732a0095bb38be32fb diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 88655954f..2e4a57de7 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -27,7 +27,7 @@ optimizer: beta_2: 0.95 model: base_model: - embeddings_layer: + embeddings: hidden_size: 4096 vocab_size: 32000 dropout: 0.0 @@ -54,11 +54,11 @@ model: epsilon: 1.0e-05 dropout: 0.0 num_blocks: 32 - output_layer: - tied_weight: false + head: normalization: type: rms_norm epsilon: 1.0e-05 + tied_embedding_weight: false multi_stage: zero_stage: 2 distributed: diff --git a/fast_llm/core/ops.py b/fast_llm/core/ops.py index a7492daa5..bb61aadd0 100644 --- a/fast_llm/core/ops.py +++ b/fast_llm/core/ops.py @@ -26,7 +26,7 @@ def reduce_op( return (input_, handle) if async_op else input_ -def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: +def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: """Split the tensor along its last dimension and keep the corresponding slice.""" if group: @@ -139,11 +139,11 @@ class _Split(torch.autograd.Function): """Split the input and keep only the corresponding chuck to the rank.""" @staticmethod - def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa + def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa return split_op(input_, group, dim) @staticmethod - def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa + def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa ctx.group = group ctx.dim = dim return split_op(input_, group, dim) @@ -209,7 +209,7 @@ def reduce_backward(input_: torch.Tensor, group: ProcessGroup | None) -> torch.T @torch._dynamo.disable # noqa -def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: +def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: return _Split.apply(input_, group, dim) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 0a3f8d1ce..5df59d4cd 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -1,23 +1,19 @@ import abc import typing -import torch import torch.nn from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import BaseModelConfig, ResourceUsageConfig +from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta -from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.engine.inference.runner import InferenceRunner -class Module(torch.nn.Module, abc.ABC): - """ """ - +class LayerBase(torch.nn.Module, abc.ABC): _is_setup: bool = False _distributed: Distributed @@ -27,57 +23,102 @@ def __init__(self, distributed_config: DistributedConfig): def setup(self, distributed: Distributed) -> None: assert not self._is_setup + for layer in self.get_layers(): + if layer is not self: + layer.setup(distributed) 
distributed.check_config(self._distributed_config) self._distributed = distributed self._is_setup = True + @abc.abstractmethod + def get_layers(self) -> list["Layer"]: + """ + The list of layers as meant to be seen by the Fast-LLM engine. + May differ from the module configuration seen by pytorch. + """ -class Layer(Module): - # Weight used to determine the stage size + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + out = 0 + for layer in self.get_layers(): + if layer is self: + raise NotImplementedError() + out += layer.get_compute_usage(input_, kwargs, config) + return out + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + losses = [] + for layer in self.get_layers(): + if layer is not self: + losses += layer.get_loss_definitions(count) + return losses + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + for layer in self.get_layers(): + if layer is not self: + layer.preprocess(batch, kwargs) + + +class Layer(LayerBase): + # Weight used to determine the stage size. layer_count: float = 1.0 + def get_layers(self) -> list["Layer"]: + # Return a breakdown of the layer into atomic ones, + # i.e. the list of layers from as seen from the Fast-LLM model. + return [self] + @abc.abstractmethod def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: pass - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - raise NotImplementedError() - + def unwrap(self) -> "Layer": + # Get the actual module contained in this layer, + # undoing any wrapping for the Fast-LLM engine (ex. `LayerWithNamespace`) + return self -class Sequential(Layer): - def __init__(self, distributed_config: DistributedConfig): - super().__init__(distributed_config) - self.layers = torch.nn.ModuleList(self.get_layers()) - def __getitem__(self, item): - return self.layers[item] +class LayerWithNamespace(Layer): + """ + A layer with its own namespace for preprocessing (kwargs), + so that it doesn't inadvertently interact with other layers. + TODO: Consider namespace for losses and metrics? + """ - def __iter__(self): - return iter(self.layers) + def __init__(self, layer: Layer, namespace: str = None): + super().__init__(layer._distributed_config) + self._layer = layer + self._namespace = namespace + self.layer_count = self._layer.layer_count + self.get_compute_usage = self._layer.get_compute_usage + self.module_name = self._layer.module_name - def __len__(self): - return len(self.layers) + def setup(self, distributed: Distributed) -> None: + self._layer.setup(distributed) + super().setup(distributed) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: - for layer in self.layers: - input_ = layer(input_, kwargs, losses, metrics) - return input_ + if self._namespace in kwargs: + kwargs = kwargs[self._namespace] + else: + # TODO: Forward meta doesn't go through preprocessing so doesn't have a namespace. + # Using kwargs as-is since it's generally unused. 
+ assert isinstance(input_, TensorMeta) + return self._layer.forward(input_, kwargs, losses, metrics) - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + assert self._namespace not in kwargs + kwargs[self._namespace] = kwargs.copy() + self._layer.preprocess(batch, kwargs[self._namespace]) - def setup(self, distributed: Distributed) -> None: - super().setup(distributed) - for layer in self.layers: - layer.setup(distributed) + def unwrap(self) -> "Layer": + return self._layer.unwrap() -class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential): +class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase): def __init__( self, @@ -85,27 +126,18 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) - for key, value in self.named_modules(): - value.module_name = key - for key, value in self.named_parameters(): - Assert.custom(isinstance, value, ParameterMeta) - # Rename to the parameter full name - value.tensor_name = key # Reference models # TODO: Add basic handling (preprocessor) in this class. self._reference_models: dict[str, "InferenceRunner"] = {} - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass - @abc.abstractmethod def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]: + # TODO Remove (Move batch splitting elsewhere) pass @abc.abstractmethod - def preprocess( + def preprocess_batch( self, batch: typing.Any, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, @@ -114,13 +146,19 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: + # TODO Move batch splitting elsewhere, align interface with LayerBase pass - def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: - # For each tied weight, return the weight and the tuple of layers sharing it. - # The weight should be defined in the first layer in the set. - # Warning: This may return buffers instead of metas after stage setup. - # The name (dict key) is used to insert the weight in the kwargs of the forward pass. + def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]: + """ + Return tuples of independently defined metas to tie together. + Metas should be compatible, i.e. have the same tensor dimensions. + Tied weights are named (dict keys) for convenience only. + Warning: Initialization and optimization properties are defined on the first appearance of the tied weight. + To prevent any confusion, the metas should be provided in the same order they appear in the model. + TODO: Improve? + Note: This may return buffers instead of metas after stage setup. 
+ """ return {} def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 78fafea34..f1eef47b9 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -4,14 +4,15 @@ from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.utils import compare_nested, log +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.utils import Assert, compare_nested, log if typing.TYPE_CHECKING: - import torch + from fast_llm.engine.base_model.base_model import BaseModel @config_class() -class BaseModelConfig(Config): +class ModuleConfig(Config): """ Abstract config class for a base model. # TODO: Find better name? @@ -43,7 +44,7 @@ def _get_architecture(self) -> dict[str, typing.Any]: return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: - if isinstance(value, BaseModelConfig): + if isinstance(value, ModuleConfig): # TODO: Make sure all nested configs have an architecture type hint? return value._get_architecture() elif isinstance(value, Config): @@ -57,12 +58,29 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: return self._serialize_value(value) -class Preprocessor(abc.ABC): - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - pass +@config_class() +class BaseModelConfig(ModuleConfig): + """ + Abstract config class for a base model. + """ + + def get_base_model(self, distributed_config: DistributedConfig) -> "BaseModel": + from fast_llm.tensor import ParameterMeta + + model = self.base_model_class(self, distributed_config) + # Storing the global name of each module and tensor. 
+ # Done here because it needs to run right after `model.__init__()` + for key, value in model.named_modules(): + value.module_name = key + for key, value in model.named_parameters(): + Assert.custom(isinstance, value, ParameterMeta) + # Rename to the parameter full name + value.tensor_name = key + return model + @property @abc.abstractmethod - def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + def base_model_class(self) -> type["BaseModel"]: pass diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index e5d14711d..afe381295 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -31,7 +31,7 @@ def export_config(cls, config: BaseModelConfig) -> dict: @classmethod @abc.abstractmethod - def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: + def get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]: pass @@ -39,6 +39,10 @@ class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, architecture: typing.ClassVar[str] base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]] + def __init__(self, model: "FastLLMModel"): + self._exported_config = self._export_config(model.config) + super().__init__(model) + @classmethod @abc.abstractmethod def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]: @@ -126,10 +130,8 @@ def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: Assert.eq(config["architecture"], cls.architecture) return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) - def _create_weight_converters( - self, - ) -> list[WeightConverter]: - return self.base_model_converter_class.get_converters(self._model.config.base_model) + def _create_weight_converters(self) -> list[WeightConverter]: + return self.base_model_converter_class.get_converters(self._model.config.base_model, self._exported_config) def _load_weights( self, config: CheckpointLoadConfig, device diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index 0929b7cb1..f4a2cfd6c 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -9,7 +9,7 @@ from triton import language as tl -class DataType(str, enum.Enum): +class DataType(enum.StrEnum): """ An enum to represent data types independently of third party libraries, so we can swap them more easily and allow for lazy imports. 
diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1737f4308..1849a2316 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -136,7 +136,7 @@ def __init__( self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0 ) config_dict = config.to_dict() - config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index d5202a90f..e055595bd 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -116,7 +116,7 @@ def setup( phase=PhaseType.validation, ) - self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() + self._loss_defs = self._multi_stage.base_model.get_loss_definitions() self._evaluation_iterator = None self._is_setup = True diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index aa18f5052..27c0e2b7b 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -137,11 +137,6 @@ class StageConfig(Config): desc="Check for tensor-parallel desyncs and log an error if a desync is found. High overhead", hint=FieldHint.logging, ) - compile_all: bool = Field( - default=False, - desc="Compile the whole model using torch.compile.", - hint=FieldHint.expert, - ) @config_class() diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 868cc2db4..827079f6e 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -56,11 +56,11 @@ def __init__( # The index range of the parameters in the buffer. 
self._parameter_begins_in_buffer = { parameter_meta.tensor_name: offset - for parameter_meta, offset in zip(parameter_metas, parameter_offsets[:-1]) + for parameter_meta, offset in zip(parameter_metas, parameter_offsets[:-1], strict=True) } self._parameter_ends_in_buffer = { parameter_meta.tensor_name: offset - for parameter_meta, offset in zip(parameter_metas, parameter_offsets[1:]) + for parameter_meta, offset in zip(parameter_metas, parameter_offsets[1:], strict=True) } # Shard properties @@ -377,7 +377,7 @@ def reduce_gradients( assert self._mode.support_backward if not self._requires_grad: return - for buffer, meta in zip(self._parameter_buffers.values(), self._parameter_metas.values()): + for buffer, meta in zip(self._parameter_buffers.values(), self._parameter_metas.values(), strict=True): if buffer.param_grad_is_zero: # noqa assert allow_no_grad or meta.allow_no_grad, meta triton_fill(buffer.grad_buffer, 0) # noqa diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index e48fdb88b..f45f93862 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -3,7 +3,6 @@ import typing import warnings -import numpy as np import torch from torch._C._distributed_c10d import ProcessGroup @@ -25,7 +24,6 @@ class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): - base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel _is_setup: bool = False _flat_shard: torch.Tensor _shards: dict[str, torch.Tensor] @@ -46,7 +44,8 @@ def __init__( stage_filter: set | None = None, ): super().__init__(config) - self._base_model = self.base_model_class(self._config.base_model, self._config.distributed) + self._base_model = self._config.base_model.get_base_model(self._config.distributed) + self._layers = self._base_model.get_layers() self._training = None self._verbose = verbose self._stage_filter = stage_filter @@ -62,44 +61,51 @@ def __init__( self._num_stages, self._config.distributed.pipeline_parallel * self._config.multi_stage.stages_per_pipeline_stage, ) + # Keep track of which stage each parameter belongs to. + self._parameter_stages: dict[str, int] = {} + for stage_index in range(self._num_stages): + for layer in self._layers[stage_splits[stage_index] : stage_splits[stage_index + 1]]: + for meta in layer.parameters(): + assert meta.tensor_name not in self._parameter_stages + self._parameter_stages[meta.tensor_name] = stage_index + + # Determine which stages belong to this pipeline rank. + self._stage_pipeline_ranks = { + stage_index: (stage_index // self._config.multi_stage.stages_per_pipeline_stage) + % self._config.distributed.pipeline_parallel + for stage_index in (range(self._num_stages)) + } + + # Set up tied weights. + self._tied_parameters = self._get_tied_parameters() + self._tied_parameter_duplicates = [{} for _ in range(self._num_stages)] + for tied_parameter in self._tied_parameters.values(): + for meta in tied_parameter.metas[1:]: + self._tied_parameter_duplicates[self._parameter_stages[meta.tensor_name]][ + meta.tensor_name + ] = tied_parameter # Create the stages. 
self._stages = [ Stage( config=self._config.multi_stage, - base_model=self._base_model, + layers=self._layers[stage_splits[stage_index] : stage_splits[stage_index + 1]], distributed_config=self._config.distributed, - begin=stage_splits[i], - end=stage_splits[i + 1], - index=i, + index=stage_index, + tied_parameter_duplicates=tied_parameter_duplicates_.keys(), ) - for i in (range(self._num_stages)) + for stage_index, tied_parameter_duplicates_ in enumerate(self._tied_parameter_duplicates) ] if self._verbose: log_main_rank(lambda: f" Total parameters: {sum(stage_.parameter_count for stage_ in self._stages):,} ") - # Keep track of which stage each parameter belongs to. - self._parameter_stages: dict[str, int] = {} - for stage_index, stage in enumerate(self._stages): - for parameter_name in stage.parameter_names: - assert parameter_name not in self._parameter_stages - self._parameter_stages[parameter_name] = stage_index - - # Determine which stages belong to this pipeline rank. - self._stage_pipeline_ranks = { - stage_index: (stage_index // self._config.multi_stage.stages_per_pipeline_stage) - % self._config.distributed.pipeline_parallel - for stage_index in (range(self._num_stages)) - } self._stages_owned = { stage_index: self._stages[stage_index] for stage_index, stage_rank in self._stage_pipeline_ranks.items() if stage_rank == self._config.distributed.pipeline_rank } - # Set up tied weights. - self._tied_parameters = self._get_tied_parameters(stage_splits[1:]) self._tied_weight_main_stages_on_device = { stage_index: self._stages[stage_index] for stage_index in sorted( @@ -320,6 +326,16 @@ def _setup_stages(self) -> None: if self._mode.support_forward and weight_buffer_index is not None else [] ) + tied_weight_duplicate_buffers = ( + { + parameter_name: self._stages[tied_parameter.main_stage].get_parameter_buffer( + tied_parameter.metas[0].tensor_name + ) + for parameter_name, tied_parameter in self._tied_parameter_duplicates[stage_index].items() + } + if self._mode.support_forward and stage_index in self._stages_on_device + else None + ) stage.setup( distributed=self._distributed, weight_shards=stage_weight_shards, @@ -328,6 +344,7 @@ def _setup_stages(self) -> None: grad_buffers=stage_grad_buffers, mode=self._mode if stage_index in self._stages_on_device else StageMode.off_device, is_tied_weight_copy=stage_index in self._stages_on_device and stage_index not in self._stages_owned, + tied_parameter_duplicate_buffers=tied_weight_duplicate_buffers, weight_buffer_shared_with=weight_buffer_shared_with, ) @@ -510,12 +527,9 @@ def _split_into_stages(self) -> list[int]: # Create stages (greedy split, could do better). 
stage_splits = [0] layer_counter, last_counter = 0, 0 - for i, layer in enumerate(self._base_model): + for i, layer in enumerate(self._layers): layer_counter += layer.layer_count # noqa - if ( - layer_counter >= last_counter + self._config.multi_stage.layers_per_stage - or i == len(self._base_model) - 1 - ): + if layer_counter >= last_counter + self._config.multi_stage.layers_per_stage or i == len(self._layers) - 1: stage_splits.append(i + 1) last_counter = layer_counter return stage_splits @@ -538,17 +552,43 @@ def _get_buffer_placement(self, num_shared_buffers: int | None) -> tuple[list[se } return buffer_contents, buffer_indices - def _get_tied_parameters(self, stage_ends) -> dict[str, "TiedParameter"]: + def _get_tied_parameters(self) -> dict[str, "TiedParameter"]: tied_parameters = {} - for name, (meta, layer_indexes) in self._base_model.get_tied_weights().items(): - Assert.eq(list(layer_indexes), sorted(layer_indexes)) - Assert.incl(meta, list(self._base_model[layer_indexes[0]].parameters())) - stage_indexes = sorted({np.searchsorted(stage_ends, i, side="right").item() for i in layer_indexes}) + for name, metas in self._base_model.get_tied_parameters().items(): + if len(metas) <= 1: + continue + stage_indexes = [self._parameter_stages[meta.tensor_name] for meta in metas] + # TODO: Ambiguous if multiple instances are on the same stage? + Assert.eq( + sorted(stage_indexes), + stage_indexes, + msg="Tied parameters should be provided in the order they appear in the model.", + ) + for meta in metas[1:]: + # TODO: Improve. Compare initializations? (Not currently possible) + if ( + len(meta.dims) != len(metas[0].dims) + or any(dim != dim_ for dim, dim_ in zip(meta.dims, metas[0].dims, strict=True)) + or meta.sequence_tensor_parallel != metas[0].sequence_tensor_parallel + ): + raise ValueError( + f"Tied parameter group `{name}` has incompatible tied parameters {metas[0]} and {meta}." + ) + if ( + meta.requires_grad != metas[0].requires_grad + or meta.lr_scale != metas[0].lr_scale + or meta.param_weight_decay != metas[0].param_weight_decay + ): + logger.warning( + f"Tied parameters `{metas[0]}` and `{meta}` in tied parameter group `{name}` have different optimization parameters." + f" Only those of `{metas[0].tensor_name}` will be used." + ) + all_ranks = {self._stage_pipeline_ranks[stage_index] for stage_index in stage_indexes} tied_parameters[name] = TiedParameter( name=name, - meta=meta, + metas=tuple(metas), all_ranks=all_ranks, on_device=self._config.distributed.pipeline_rank in all_ranks, main_stage=stage_indexes[0], @@ -560,11 +600,11 @@ def _get_tied_parameters(self, stage_ends) -> dict[str, "TiedParameter"]: class TiedParameter: name: str # Parameter definition. - meta: ParameterMeta + metas: tuple[ParameterMeta, ...] # Whether the local rank is involved at all. on_device: bool # Process group for reduction. - group: ProcessGroup | None = dataclasses.field(init=False) + group: ProcessGroup | None = dataclasses.field(repr=False, init=False) all_ranks: set[int] # The index of the main stage. 
main_stage: int diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 7829c243b..9f5543590 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -47,6 +47,7 @@ def setup( # noqa grad_buffers: list[torch.Tensor | None] | None = None, mode: StageMode = StageMode.training, is_tied_weight_copy: bool = False, + tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, weight_buffer_shared_with: collections.abc.Sequence["Stage"] = (), ) -> None: super().setup( @@ -56,6 +57,7 @@ def setup( # noqa weight_buffers=weight_buffers, grad_buffers=grad_buffers, mode=mode, + tied_parameter_duplicate_buffers=tied_parameter_duplicate_buffers, ) self._is_tied_weight_copy = is_tied_weight_copy if self._mode.support_forward: @@ -68,6 +70,9 @@ def setup( # noqa self._accumulators = [] with torch.enable_grad(): for meta in self._parameter_metas: + if meta.tensor_name in self._tied_parameter_duplicates: + # Already handled in the main stage. + continue buffer = self.get_parameter_buffer(meta.tensor_name) if not buffer.requires_grad: continue @@ -139,7 +144,8 @@ def forward( else: # TODO: Handle variable shape. output_global = output - kwargs["hidden_states"][self._layer_range[i]] = { + + kwargs["hidden_states"][self._layers[i].module_name] = { "layer_type": type(layer).__name__, "tensor": output_global, } @@ -223,9 +229,9 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] and self._distributed.tensor_group is not None and not self._meta_outputs[i].is_tensor_parallel ): - check_parallel_match(output, self._distributed.tensor_group, f"layer {self._layer_range[i]} fw") + check_parallel_match(output, self._distributed.tensor_group, f"layer {self._layers[i].module_name} fw") if self._config.debug_layer_outputs: - name = f"layer {self._layer_range[i]} fw" + name = f"{self._layers[i].module_name} fw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: @@ -242,7 +248,7 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] meta=self._meta_outputs[i], ) if self._config.debug_activation_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"layer {self._layer_range[i]} fw", str)) + log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"layer {self._layers[i].module_name} fw", str)) def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any], i: int) -> None: if not input_.requires_grad: @@ -254,11 +260,11 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any ): input_.register_hook( lambda grad: check_parallel_match( - grad, self._distributed.tensor_group, f"layer {self._layer_range[i]} bw" + grad, self._distributed.tensor_group, f"layer {self._layers[i].module_name} bw" ) ) if self._config.debug_layer_gradients: - name = f"layer {self._layer_range[i]} bw" + name = f"{self._layers[i].module_name} bw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: @@ -276,6 +282,6 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any if self._config.debug_activation_memory: input_.register_hook( lambda grad: log_pipeline_parallel_main_rank( - lambda: log_memory_usage(f"layer {self._layer_range[i]} bw", str) + lambda: log_memory_usage(f"layer 
{self._layers[i].module_name} bw", str) ) ) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index ded24e538..96d80ce06 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -6,7 +6,7 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import check_parallel_match -from fast_llm.engine.base_model.base_model import BaseModel, Layer +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -28,27 +28,22 @@ def __init__( self, *, config: StageConfig, - base_model: BaseModel | list[Layer], + layers: list[Layer], distributed_config: DistributedConfig, - begin: int, - end: int, index: int, + tied_parameter_duplicates: typing.Iterable[str] = (), ): super().__init__(config) self._distributed_config = distributed_config.validate() - Assert.in_range(begin, 0, end) - Assert.leq(end, len(base_model)) - self._fsdp_rank = self._distributed_config.data_rank self._fsdp_size = self._distributed_config.data_parallel self._is_setup = False self._index = index + self._layers = layers + self._tied_parameter_duplicates = set(tied_parameter_duplicates) - self._layers = [torch.compile(layer) if self._config.compile_all else layer for layer in base_model[begin:end]] - self._layer_range = list(range(begin, end)) - - parameter_metas, frozen_metas = self._get_parameter_metas() - self._parameter_metas = parameter_metas + frozen_metas + parameter_metas, frozen_metas, duplicate_metas = self._get_parameter_metas() + self._parameter_metas = parameter_metas + frozen_metas + duplicate_metas self._fsdps = [] if parameter_metas: self._fsdps.append( @@ -113,6 +108,7 @@ def setup( weight_buffers: list[torch.Tensor | None] | None, grad_buffers: list[torch.Tensor | None] | None, mode: StageMode = StageMode.training, + tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None, ) -> None: assert not self._is_setup distributed.check_config(self._distributed_config) @@ -149,7 +145,11 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) - module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) + if meta.tensor_name in self._tied_parameter_duplicates: + assert tied_parameter_duplicate_buffers is not None + module._parameters[key] = tied_parameter_duplicate_buffers.pop(meta.tensor_name) + else: + module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 i = 0 @@ -157,6 +157,7 @@ def _replace(module: torch.nn.Module): layer.apply(_replace) Assert.eq(i, len(self._parameter_metas)) + assert not tied_parameter_duplicate_buffers, tied_parameter_duplicate_buffers.keys() def initialize_weights(self) -> None: # TODO: Avoid all the _on_device checks @@ -179,6 +180,9 @@ def initialize_weights(self) -> None: ] for meta in metas: + if meta.tensor_name in self._tied_parameter_duplicates: + # Initialization is not managed by this stage. 
+ continue fsdp = self._fsdps[fsdp_index := self._fsdp_index[meta.tensor_name]] parameter = weight_shards_split[fsdp_index][meta.tensor_name] # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) @@ -316,24 +320,31 @@ def _export_shard( for fsdp, shard in zip(self._fsdps, shards, strict=True): yield from fsdp.export_shard(shard, data_type) - def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]]: + def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta], list[ParameterMeta]]: # Get all the stage parameters, # then separate the parameters with and without weight decay, # and squeeze the non-tensor parallel and sequence parallel ones in the middle. # This allows running the optimizer, grad norm and sequence_parallel reduction on contiguous buffers. parameter_metas: list[ParameterMeta] = [] frozen_metas: list[ParameterMeta] = [] + duplicate_metas: list[ParameterMeta] = [] meta: ParameterMeta for layer in self._layers: - for name, meta in layer.named_parameters(): + for meta in layer.parameters(): Assert.custom(isinstance, meta, ParameterMeta) Assert.eq(meta.dtype, self._distributed_config.optimization_dtype.torch) - if meta.requires_grad: + if meta.tensor_name in self._tied_parameter_duplicates: + duplicate_metas.append(meta) + elif meta.requires_grad: parameter_metas.append(meta) else: frozen_metas.append(meta) - return self._reorder_parameter_metas(parameter_metas), self._reorder_parameter_metas(frozen_metas) + return ( + self._reorder_parameter_metas(parameter_metas), + self._reorder_parameter_metas(frozen_metas), + self._reorder_parameter_metas(duplicate_metas), + ) @classmethod def _reorder_parameter_metas(cls, parameter_metas): diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index dbdd035a4..133b3206b 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -95,7 +95,7 @@ def __init__( self._num_stages = len(self._stages) self._loss_definitions = { loss_definition.name: loss_definition - for loss_definition in self._multi_stage.base_model.config.get_loss_definitions() + for loss_definition in self._multi_stage.base_model.get_loss_definitions() } def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: @@ -324,7 +324,7 @@ def _preprocess_data( for micro_batch in range(batch_config.sequential_micro_batches): micro_batch_data = next(data_iterator) if not preprocessed: - micro_batch_data = self._multi_stage.base_model.preprocess( + micro_batch_data = self._multi_stage.base_model.preprocess_batch( micro_batch_data, context.schedule.preprocessed_meta, phase=context.phase, @@ -339,11 +339,6 @@ def _preprocess_data( num_micro_batches=batch_config.sequential_micro_batches, micro_batch_splits=batch_config.micro_batch_splits, ) - for name, tied_parameter in self._tied_parameters.items(): - if tied_parameter.on_device: - kwargs[name] = self._stages[tied_parameter.main_stage].get_parameter_buffer( - tied_parameter.meta.tensor_name - ) data_index = context.schedule.get_data_index(micro_batch, micro_batch_split) if self._stages_owned[0]: context.inputs[context.schedule.get_step(StepType.forward, 0, data_index).global_index] = input_ diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a752bec28..aa4f2d570 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -149,7 +149,7 @@ def __init__(self, config: TrainerConfig): 
multi_stage=self._multi_stage, distributed_config=self._config.model.distributed, ) - self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() + self._loss_definitions = self._multi_stage.base_model.get_loss_definitions() if not self._is_evaluation_only: steps_per_split = { @@ -320,7 +320,7 @@ def _run_training(self) -> None: phase=PhaseType.test, num_iters=self._config.training.test_iters, ) - formatted_metrics = format_metrics(metrics[metrics_key], self._loss_defs, PhaseType.test) + formatted_metrics = format_metrics(metrics[metrics_key], self._loss_definitions, PhaseType.test) log_main_rank(formatted_metrics) self._wandb.alert("Testing results", formatted_metrics, "WARN") # TODO: This may erase some metrics. @@ -331,7 +331,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: advanced_iters = 0 skipped_iters = 0 nan_iters = 0 - total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} + total_losses = {loss_def.name: 0.0 for loss_def in self._loss_definitions} # Profiling profiler = self._config.profiling.get_profiler( @@ -435,7 +435,9 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: **get_and_reset_memory_usage_mib(), } - formatted_metrics = format_metrics(metrics[metrics_key], self._loss_defs, PhaseType.training) + formatted_metrics = format_metrics( + metrics[metrics_key], self._loss_definitions, PhaseType.training + ) logger.info(formatted_metrics) if self._config.training.wandb.alert.enabled(self._completed_steps): self._wandb.alert("Training results", formatted_metrics, "INFO") @@ -443,7 +445,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: advanced_iters = 0 skipped_iters = 0 nan_iters = 0 - total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} + total_losses = {loss_def.name: 0.0 for loss_def in self._loss_definitions} self._run.save_logged_tensors(f"train_{self._completed_steps}") diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 9a940f4cb..167184193 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -56,6 +56,11 @@ class Attention[ConfigType: AttentionConfig](BlockWithBias[ConfigType]): _config: ConfigType + # Preprocessing + _backup_attention_mask: torch.Tensor + _backup_attention_mask_value: torch.Tensor + _backup_attention_tensor_cache_max_sequence_length: int = -1 + def __init__( self, config: ConfigType, @@ -64,6 +69,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -71,6 +77,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) @@ -273,7 +280,7 @@ def _query_key_value_backward( input_grad.add_(self.key_value.backward(key_value_grad, context.pop("key_value"))) return input_grad - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], @@ -340,7 +347,7 @@ def forward( max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), dropout_p=self._config.dropout if self.training else 0.0, window_size=window_size, - causal=True, + causal=self._config.causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -350,7 +357,7 @@ def forward( value, window_size=window_size, dropout_p=self._config.dropout if self.training else 0.0, - causal=True, + causal=self._config.causal, 
softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) @@ -429,3 +436,115 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c self.dense.get_compute_usage(dense_input, config), ) ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._rotary.preprocess(batch, kwargs) + if not self._use_flash_attention: + self._preprocess_for_backup_attention(batch, kwargs) + elif AttentionKwargs.sequence_lengths in kwargs: + self._preprocess_for_varlen(batch, kwargs) + + def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if ( + sequence_length := kwargs[AttentionKwargs.sequence_length] + ) > self._backup_attention_tensor_cache_max_sequence_length: + # Create tensor cache. + self._backup_attention_tensor_cache_max_sequence_length = sequence_length + + self._backup_attention_mask = torch.ones( + (sequence_length, sequence_length), + dtype=torch.bool, + device=batch.device, + ).tril_() + + if self._config.window_size is not None: + self._backup_attention_mask.triu_(-self._config.window_size + 1) + self._backup_attention_mask_value = torch.full( + [], + torch.finfo(self._distributed_config.compute_dtype.torch).min, + dtype=self._distributed_config.compute_dtype.torch, + device=batch.device, + ) + + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + kwargs[AttentionKwargs.attention_mask] = self._backup_attention_mask[ + None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + ] + if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: + seq_ids = torch.stack( + [ + torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) + for sample_lens in sequence_lengths + ] + ) + document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) + kwargs[AttentionKwargs.attention_mask] = ( + kwargs[AttentionKwargs.attention_mask] + & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] + ) + kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value + + def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + """ + Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 + cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. + Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. + If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally + also contain previous tokens from the first document in micro-sequence. + We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. + """ + if AttentionKwargs.sequence_lengths not in kwargs: + return + sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + if sequence_q < kwargs[AttentionKwargs.sequence_length]: + cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] + # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents + # in the microsequence. 
We need to consider all keys computed so far from the first sample. We also store the offsets + # of the first documents so that we can index into their kv pairs + start_seq_idx = [ + torch.argmax((cu_seqlens >= sequence_k - sequence_q).to(torch.uint8), dim=0) for cu_seqlens in cumsums + ] + end_seq_idx = [torch.argmax((cu_seqlens >= sequence_k).to(torch.uint8), dim=0) for cu_seqlens in cumsums] + seqlens_q = [] + seqlens_k = [] + for idx, sample_seqlens in enumerate(sequence_lengths): + start_idx = start_seq_idx[idx] + end_idx = end_seq_idx[idx] + seqlens_q.extend([0] * start_idx) + n_attention_tokens = sample_seqlens[end_idx] - (cumsums[idx][end_idx] - sequence_k) + if start_idx == end_idx: + seqlens_q.append(sequence_q) + else: + start_q_tokens = cumsums[idx][start_idx] - (sequence_k - sequence_q) + seqlens_q.extend( + [ + start_q_tokens, + *(sample_seqlens[idx] for idx in range(start_idx + 1, end_idx)), + n_attention_tokens, + ] + ) + seqlens_k.extend(sample_seqlens[: end_idx + 1]) + seqlens_k[-1] = n_attention_tokens + seqlens_q = torch.tensor(seqlens_q, dtype=torch.int32) + seqlens_k = torch.tensor(seqlens_k, dtype=torch.int32) + else: + seqlens_q = torch.cat(sequence_lengths) + seqlens_k = torch.cat(sequence_lengths) + kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( + ( + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), + ) + ) + kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( + ( + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), + ) + ) + kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 2910c7c76..68b6dde91 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -3,9 +3,7 @@ import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig @@ -80,6 +78,11 @@ class AttentionConfig(MixerConfig): desc="Add biases to linear layers. May be overridden for individual layers.", hint=FieldHint.architecture, ) + causal: bool = Field( + default=True, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) dropout: float = Field( default=0.0, desc="Dropout applied to the attention intermediate states.", @@ -121,19 +124,3 @@ def layer_class(self) -> "type[Attention]": def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. - # TODO: Find a better solution. 
- preprocessors: list[Preprocessor] = [ - self.rotary.get_layer(TensorDim("head_size", self.head_size)), - ] - if self.do_use_flash_attention(distributed_config): - from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor - - preprocessors.append(FlashAttnVarlenPreprocessor(self, distributed_config)) - else: - from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor - - preprocessors.append(BackupAttentionPreprocessor(self, distributed_config)) - return preprocessors diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py deleted file mode 100644 index 204c08ad2..000000000 --- a/fast_llm/layers/attention/preprocessing.py +++ /dev/null @@ -1,153 +0,0 @@ -import logging -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs -from fast_llm.tensor import TensorMeta - -logger = logging.getLogger(__name__) - - -class BackupAttentionPreprocessor(Preprocessor): - _head_size_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor - _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): - self._config = config - self._distributed_config = distributed_config - assert not self._config.do_use_flash_attention(self._distributed_config) - - def _create_tensors(self, sequence_length: int, device: torch.device) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - self._mask = torch.ones( - (sequence_length, sequence_length), - dtype=torch.bool, - device=device, - ).tril_() - - if self._config.window_size is not None: - self._mask.triu_(-self._config.window_size + 1) - self._mask_value = torch.full( - [], - torch.finfo(self._distributed_config.compute_dtype.torch).min, - dtype=self._distributed_config.compute_dtype.torch, - device=device, - ) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - kwargs[AttentionKwargs.attention_mask] = self._mask[ - None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k - ] - if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: - seq_ids = torch.stack( - [ - torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) - for sample_lens in sequence_lengths - ] - ) - document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) - kwargs[AttentionKwargs.attention_mask] = ( - kwargs[AttentionKwargs.attention_mask] - & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] - ) - kwargs[AttentionKwargs.attention_mask_value] = self._mask_value - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( - ( - scalar_dim, - scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - scalar_dim, - kwargs[AttentionKwargs.sequence_k_dim], - ), - tensor_name=AttentionKwargs.attention_mask, - dtype=torch.bool, - ) - 
kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( - (scalar_dim,), - tensor_name=AttentionKwargs.attention_mask_value, - dtype=self._distributed_config.compute_dtype.torch, - ) - - -class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): - assert config.do_use_flash_attention(distributed_config) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - """ - Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 - cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. - Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. - If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally - also contain previous tokens from the first document in micro-sequence. - We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. - """ - if AttentionKwargs.sequence_lengths not in kwargs: - return - sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - if sequence_q < kwargs[AttentionKwargs.sequence_length]: - cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] - # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents - # in the microsequence. We need to consider all keys computed so far from the first sample. 
We also store the offsets - # of the first documents so that we can index into their kv pairs - start_seq_idx = [ - torch.argmax((cu_seqlens >= sequence_k - sequence_q).to(torch.uint8), dim=0) for cu_seqlens in cumsums - ] - end_seq_idx = [torch.argmax((cu_seqlens >= sequence_k).to(torch.uint8), dim=0) for cu_seqlens in cumsums] - seqlens_q = [] - seqlens_k = [] - for idx, sample_seqlens in enumerate(sequence_lengths): - start_idx = start_seq_idx[idx] - end_idx = end_seq_idx[idx] - seqlens_q.extend([0] * start_idx) - n_attention_tokens = sample_seqlens[end_idx] - (cumsums[idx][end_idx] - sequence_k) - if start_idx == end_idx: - seqlens_q.append(sequence_q) - else: - start_q_tokens = cumsums[idx][start_idx] - (sequence_k - sequence_q) - seqlens_q.extend( - [ - start_q_tokens, - *(sample_seqlens[idx] for idx in range(start_idx + 1, end_idx)), - n_attention_tokens, - ] - ) - seqlens_k.extend(sample_seqlens[: end_idx + 1]) - seqlens_k[-1] = n_attention_tokens - seqlens_q = torch.tensor(seqlens_q, dtype=torch.int32) - seqlens_k = torch.tensor(seqlens_k, dtype=torch.int32) - else: - seqlens_q = torch.cat(sequence_lengths) - seqlens_k = torch.cat(sequence_lengths) - kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( - ( - torch.zeros(1, dtype=torch.int32, device=batch.device), - torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), - ) - ) - kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( - ( - torch.zeros(1, dtype=torch.int32, device=batch.device), - torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), - ) - ) - kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 5bd7a9b87..26877ee0c 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -4,7 +4,7 @@ import warnings from fast_llm.config import Field, FieldHint, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert @@ -14,7 +14,7 @@ @config_class(registry=True) -class RotaryConfig(BaseModelConfig): +class RotaryConfig(ModuleConfig): # TODO: Move rotary to its own submodule. 
@classmethod diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 889711839..d57d72947 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -5,8 +5,7 @@ import torch from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.attention.rotary.config import ( @@ -16,7 +15,6 @@ RotaryConfig, YarnRotaryConfig, ) -from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -41,7 +39,7 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) -class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module, Preprocessor): +class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module): def __init__( self, config: ConfigType, @@ -56,6 +54,9 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: pass + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + pass + class NoRotary[ConfigType: NoRotaryConfig](Rotary[ConfigType]): def forward( @@ -63,12 +64,6 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: return query, key - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - pass - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - pass - class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[ConfigType]): _rotary_embedding_frequencies: torch.Tensor @@ -82,26 +77,6 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None ] kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( - ( - scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - scalar_dim, - self._head_size_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_q, - ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( - ( - scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - scalar_dim, - self._head_size_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_k, - ) - def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 773cce87e..ab6cb22b0 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -4,13 +4,13 @@ import torch -from fast_llm.config import Config, Configurable -from fast_llm.engine.base_model.base_model import Layer, Module -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.config import Configurable +from fast_llm.engine.base_model.base_model import Layer, LayerBase +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.block.config import BlockConfig, 
BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.logging import get_model_debug_level, log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -93,9 +93,9 @@ def __call__[ ) -class BaseBlock[ConfigType: Config](Configurable[ConfigType], Module): +class BlockBase[ConfigType: ModuleConfig](Configurable[ConfigType], LayerBase): """ - Base class for blocks and block-like layers (mlp, mixers, etc.). + Base class for blocks and block-like layers (mlp, mixers, block sequences, etc.). """ def __init__( @@ -115,24 +115,6 @@ def __init__( self._lr_scale = lr_scale self._peft = peft - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - raise NotImplementedError() - -class Block[ConfigType: Config](BaseBlock[ConfigType], Layer): - """ - Base class for actual blocks, i.e., base blocks that are also `Layers`. - """ - - def __init__( - self, - config: ConfigType, - distributed_config: DistributedConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - return_input: bool = False, - ): - super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) - self._return_input = return_input +class Block[ConfigType: BlockConfig](BlockBase[ConfigType], Layer): + pass diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index df5bd8181..f3e93edeb 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,11 +1,9 @@ -import abc -import collections import functools import typing import warnings from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -13,7 +11,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.block import Block + from fast_llm.layers.block.block import BlockBase + from fast_llm.layers.block.sequence import FixedBlockSequence, PatternBlockSequence class BlockDimNames: @@ -40,8 +39,8 @@ class BlockKwargs: grad_output = "grad_output" -@config_class() -class BaseBlockConfig(BaseModelConfig): +@config_class(registry=True) +class BlockConfig(ModuleConfig): """ Base configuration class for blocks and block-like layers (mlp, mixers, etc.). """ @@ -55,19 +54,6 @@ class BaseBlockConfig(BaseModelConfig): hint=FieldHint.feature, ) - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return [] - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return [] - - -@config_class(registry=True) -class BlockConfig(BaseBlockConfig): - """ - Base configuration class for actual blocks, i.e., base blocks that are also `Layers`. 
- """ - @classmethod def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockConfig and cls.get_subclass(default.get("type")) is None: @@ -78,29 +64,27 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return super()._from_dict(default, strict=strict) @property - def layer_class(self) -> "type[Block]": + def layer_class(self) -> "type[BlockBase]": raise NotImplementedError() - def get_block( + def get_layer( self, distributed_config: DistributedConfig, hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - return_input: bool = False, - ) -> "Block": + ) -> "BlockBase": return self.layer_class( self, distributed_config, hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, - return_input=return_input, ) @config_class(registry=True) -class BlockSequenceConfig(BaseModelConfig): +class BlockSequenceConfig(BlockConfig): @classmethod def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockSequenceConfig and cls.get_subclass(default.get("type")) is None: @@ -108,21 +92,6 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return FixedBlockSequenceConfig._from_dict(default, strict) return super()._from_dict(default, strict=strict) - @abc.abstractmethod - def __len__(self) -> int: - pass - - @abc.abstractmethod - def __getitem__(self, index: int) -> BlockConfig: - pass - - @abc.abstractmethod - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - pass - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return [] - @config_class(dynamic_type={BlockSequenceConfig: "fixed"}) class FixedBlockSequenceConfig(BlockSequenceConfig): @@ -138,18 +107,11 @@ class FixedBlockSequenceConfig(BlockSequenceConfig): valid=check_field(Assert.geq, 0), ) - def __len__(self) -> int: - return self.num_blocks - - def __getitem__(self, index: int) -> BlockConfig: - return self.block - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Prevent name conflicts in preprocessed kwargs. - return self.block.get_preprocessors(distributed_config) + @property + def layer_class(self) -> "type[FixedBlockSequence]": + from fast_llm.layers.block.sequence import FixedBlockSequence - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.block.get_loss_definitions(count=count * self.num_blocks) + return FixedBlockSequence @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) @@ -182,26 +144,18 @@ def _validate(self): super()._validate() - def __len__(self) -> int: - return self.num_blocks + @property + def layer_class(self) -> "type[PatternBlockSequence]": + from fast_llm.layers.block.sequence import PatternBlockSequence - def __getitem__(self, index: int) -> BlockConfig: - return self.blocks[self.expanded_pattern[index]] + return PatternBlockSequence @functools.cached_property def expanded_pattern(self) -> list[str]: + # The complete list of block names, expanded to `num_blocks` return (self.pattern * (self.num_blocks // len(self.pattern) + 1))[: self.num_blocks] - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Prevent name conflicts in preprocessed kwargs. 
- return sum((block.get_preprocessors(distributed_config) for block in self.blocks.values()), []) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - # TODO: Prevent name conflicts. - return sum( - ( - self.blocks[name].get_loss_definitions(count=count * count_) - for name, count_ in collections.Counter(self.expanded_pattern).items() - ), - [], - ) + @functools.cached_property + def preprocessing_layers(self) -> dict[str, int]: + # The index at which each block first appears. These blocks are used for preprocessing. + return {name: self.expanded_pattern.index(name) for name in set(self.expanded_pattern)} diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index e69de29bb..530df950e 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -0,0 +1,124 @@ +import collections +import functools +import typing + +import torch.nn + +from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.peft.config import PeftConfig + + +class FixedBlockSequence[ConfigType: FixedBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + + self.extend( + [ + self._config.block.get_layer( + distributed_config, + hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + for _ in range(self._config.num_blocks) + ] + ) + + @functools.cached_property + def _layers_with_namespace(self) -> list[Layer]: + # This needs to be in a property because `module_name` is set after `__init__`. + # Wrap all blocks in a namespace using the unique module name of the first one. 
+ namespace = self[0].module_name if self._config.num_blocks > 0 else "" + return [LayerWithNamespace(sublayer, namespace) for layer in self for sublayer in layer.get_layers()] + + def get_layers(self) -> list["Layer"]: + return self._layers_with_namespace + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + self._layers_with_namespace[0].preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return ( + self[0].get_loss_definitions(count=count * self._config.num_blocks) if self._config.num_blocks > 0 else [] + ) + + +class PatternBlockSequence[ConfigType: PatternBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + self.extend( + [ + self._config.blocks[name].get_layer( + distributed_config, + hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + for name in self._config.expanded_pattern + ] + ) + + @functools.cached_property + def _layers_with_namespace(self) -> list[Layer]: + # This needs to be in a property because `module_name` is set after `__init__`. + # Wrap each set of blocks with identical config in a namespace + # using the unique module name of the first such block. + return [ + LayerWithNamespace(sublayer, self[self._config.preprocessing_layers[name]].module_name) + for name, layer in zip(self._config.expanded_pattern, self) + for sublayer in layer.get_layers() + ] + + def get_layers(self) -> list[Layer]: + return self._layers_with_namespace + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + for _, index in self._config.preprocessing_layers.items(): + self._layers_with_namespace[index].preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # TODO: Prevent name conflicts. + return sum( + ( + self[self._config.preprocessing_layers[name]].get_loss_definitions(count=count * count_) + for name, count_ in collections.Counter(self._config.expanded_pattern).items() + ), + [], + ) diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index c1ced10df..a80a19280 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -3,7 +3,7 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import ParameterConfig, combine_lr_scales from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -26,7 +26,7 @@ class NormalizationImplementation(str, enum.Enum): @config_class(registry=True) -class NormalizationConfig(BaseModelConfig): +class NormalizationConfig(ModuleConfig): lr_scale: float | None = Field( default=None, desc="Scaling factor for the layer learning rate." 
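Illustration (not part of the patch): the expansion logic behind `PatternBlockSequenceConfig.expanded_pattern` and `preprocessing_layers` above, reproduced on plain lists. The block names "attn" and "ssm" are invented for this sketch.

# Sketch only, assuming a pattern of block names and a total block count.
pattern, num_blocks = ["attn", "attn", "ssm"], 5

# Repeat the pattern until it covers `num_blocks`, then truncate.
expanded_pattern = (pattern * (num_blocks // len(pattern) + 1))[:num_blocks]
assert expanded_pattern == ["attn", "attn", "ssm", "attn", "attn"]

# First index at which each block name appears; only these blocks run preprocessing.
preprocessing_layers = {name: expanded_pattern.index(name) for name in set(expanded_pattern)}
assert preprocessing_layers == {"attn": 0, "ssm": 2}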
diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index ba4c370c2..8b19db66a 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -4,33 +4,58 @@ import torch -from fast_llm.config import Config from fast_llm.core.distributed import set_generator -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.block import BaseBlock, Block +from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) -class BlockWithBias[ConfigType: Config](BaseBlock[ConfigType]): +class BlockWithBias[ConfigType: BlockWithBiasConfig](Block[ConfigType]): """ Base class for mixer and MLP modules. """ - @abc.abstractmethod + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) + self._return_bias = return_bias + def forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor: + output, bias = self._forward(input_, kwargs, losses, metrics) + if self._return_bias: + return output, bias + else: + return output if bias is None else output + bias + + @abc.abstractmethod + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: pass @@ -58,18 +83,16 @@ def __init__( peft=peft, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. - self._return_input: bool = return_input - # Note, layer_lr_scale does not impact the norms - # TODO: add a separate norm_lr_scale + self._return_input = return_input self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) - # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. 
self.mixer = self._config.mixer.get_layer( self._distributed_config, self._hidden_dim, self._lr_scale, peft=peft, + return_bias=True, ) self.mlp = self._config.mlp.get_layer( @@ -77,6 +100,7 @@ def __init__( self._hidden_dim, self._lr_scale, peft=peft, + return_bias=True, ) def setup(self, distributed: Distributed) -> None: @@ -150,3 +174,10 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c self.mlp.get_compute_usage(input_, kwargs, config), ) ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self.mixer.preprocess(batch, kwargs) + self.mlp.preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 5f8131b5c..403b204c8 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -1,11 +1,10 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import LossDef, Preprocessor from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BaseBlockConfig, BlockConfig +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -15,7 +14,7 @@ @config_class() -class BlockWithBiasConfig(BaseBlockConfig): +class BlockWithBiasConfig(BlockConfig): """ A common interface for various blocks and block layers. 
""" @@ -30,6 +29,7 @@ def get_layer( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = False, ) -> "BlockWithBias": return self.layer_class( self, @@ -37,6 +37,7 @@ def get_layer( hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, + return_bias=return_bias, ) @@ -94,8 +95,19 @@ def layer_class(self) -> "type[DecoderBlock]": return DecoderBlock - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) + def get_layer( + self, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_input: bool = False, + ) -> "DecoderBlock": + return self.layer_class( + self, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + return_input=return_input, + ) diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 100f53740..36841b45b 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -3,7 +3,6 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import LossDef from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig from fast_llm.layers.decoder.config import MLPBaseConfig @@ -152,23 +151,3 @@ def _validate(self) -> None: super()._validate() Assert.leq(self.shared_experts, self.experts) Assert.leq(self.shared_experts + self.experts_per_token, self.experts) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_definitions = [] - if self.routing == RoutingType.topk: - loss_definitions.append( - LossDef( - name=MLPLossNames.load_balancing_loss, - formatted_name="load balancing loss", - count=1, - ) - ) - if self.z_loss_coefficient: - loss_definitions.append( - LossDef( - name=MLPLossNames.router_z_loss, - formatted_name="router z loss", - count=1, - ) - ) - return loss_definitions diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 089fa2dc7..ffc9eadba 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -5,7 +5,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -46,6 +46,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): Assert.gt(config.experts, 1) # TODO: Implement? 
@@ -56,6 +57,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self.router = self._config.router.get_layer( self._hidden_dim, @@ -83,9 +85,9 @@ def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: @@ -261,6 +263,26 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c return super().get_compute_usage(moe_input, kwargs, config) + self.router.get_compute_usage(input_, config) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_definitions = [] + if self._config.routing == RoutingType.topk: + loss_definitions.append( + LossDef( + name=MLPLossNames.load_balancing_loss, + formatted_name="load balancing loss", + count=1, + ) + ) + if self._config.z_loss_coefficient: + loss_definitions.append( + LossDef( + name=MLPLossNames.router_z_loss, + formatted_name="router z loss", + count=1, + ) + ) + return loss_definitions + def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 9dd17d698..aaea94adb 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -28,6 +28,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -35,6 +36,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() @@ -102,7 +104,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _config: MLPConfig - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index f59b4cffd..d2fbc4909 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,7 +1,8 @@ +import abc import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -9,25 +10,13 @@ from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.layers.language_model.embedding import LanguageModelEmbedding - from 
fast_llm.layers.language_model.head import LanguageModelHead - - -class LanguageModelLossNames: - language_model_loss = "language_model_loss" - z_loss = "z_loss" - dpo_loss = "dpo_loss" - distil_lm_loss = "distillation_language_model_loss" # the next token perdiciton of combined distillation loss - distillation_loss = "distillation_loss" - - @staticmethod - def multi_token_prediction_loss(index: int) -> str: - if index == 0: - return LanguageModelLossNames.language_model_loss - return f"language_model_loss_{index}" + from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase + from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction class LanguageModelKwargs(BlockKwargs): @@ -100,17 +89,42 @@ def layer_class(self) -> "type[LanguageModelEmbedding]": return LanguageModelEmbedding - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - preprocessors = [] - if self.position_embeddings.enabled: - from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor - preprocessors.append(PositionEmbeddingPreprocessor(self, distributed_config)) - return preprocessors +@config_class(registry=True) +class LanguageModelHeadBaseConfig(BlockConfig): + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is LanguageModelHeadBaseConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return LanguageModelHeadConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + def get_layer( + self, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ) -> "LanguageModelHeadBase": + return self.layer_class( + self, + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + ) + + @property + @abc.abstractmethod + def max_prediction_distance(self) -> int: + pass -@config_class() -class LanguageModelHeadConfig(BlockConfig): +@config_class(dynamic_type={LanguageModelHeadBaseConfig: "language_model_head"}) +class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): _abstract = False normalization: NormalizationConfig = Field( desc="Configuration for the final normalization layer.", @@ -121,17 +135,6 @@ class LanguageModelHeadConfig(BlockConfig): desc="Configuration for the LM output layer (weight). 
Ignored for tied embeddings", hint=FieldHint.architecture, ) - tied_weight: bool = Field( - default=True, - desc="Tie the output weights (logits) with the vocabulary embedding.", - hint=FieldHint.architecture, - ) - prediction_heads: int = Field( - default=1, - desc="Number of multi-token prediction heads.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) cross_entropy_implementation: CrossEntropyImpl = Field( default=CrossEntropyImpl.auto, desc="Implementation for the cross-entropy computation.", @@ -173,12 +176,6 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - prediction_loss_coefficient: list[float] | None = Field( - default=None, - desc="Loss coefficient for each prediction head.", - doc="If not provided, all heads are equally weighted.", - hint=FieldHint.feature, - ) teacher_softmax_temperature: float = Field( default=1.0, desc="Divides distillation target logits by this factor.", @@ -186,9 +183,9 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - enable_dpo: bool | None = Field( - default=False, - desc="Whether to enable DPO loss", + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", hint=FieldHint.feature, ) dpo_beta: float | None = Field( @@ -196,11 +193,6 @@ class LanguageModelHeadConfig(BlockConfig): desc="Beta value for DPO loss.", hint=FieldHint.feature, ) - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) distillation_model: str | None = Field( default=None, desc="Name of the reference model to use for knowledge distillation." @@ -208,6 +200,30 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, ) + def get_layer( + self, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + prediction_distance: int = 0, + prediction_heads: int = 1, + loss_coefficient: float = 1.0, + ): + return self.layer_class( + self, + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + prediction_distance=prediction_distance, + prediction_heads=prediction_heads, + loss_coefficient=loss_coefficient, + ) + @property def layer_class(self) -> "type[LanguageModelHead]": from fast_llm.layers.language_model.head import LanguageModelHead @@ -222,125 +238,81 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 super()._validate() - if self.distillation_model is not None: - if self.prediction_heads > 1: - raise NotImplementedError("Multi-token prediction not supported with distillation.") - if isinstance(self.prediction_loss_coefficient, list): - Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) - for coeff in self.prediction_loss_coefficient: - Assert.geq(coeff, 0) - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - preprocessors: list[Preprocessor] = [] + assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both - if self.enable_dpo: # TODO better way to pass in? 
- from fast_llm.layers.language_model.preprocessing import PreferenceSpanPreprocessor + @property + def max_prediction_distance(self) -> int: + return 1 - preprocessors.append(PreferenceSpanPreprocessor()) + @property + def enable_dpo(self) -> bool: + return self.dpo_reference_model is not None - return preprocessors - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_defs = [] - if self.logit_z_loss: - LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=count) +@config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) +class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): + _abstract = False + # Needs to be `DecoderBlockConfig` for the `return_input` interface. + # TODO: Make a generic wrapper for returning input instead? + block: DecoderBlockConfig = Field( + desc="Configuration for the decoder block before each head.", + hint=FieldHint.architecture, + ) + # TODO: Generalize? (needs the extra initialization arguments) + head: LanguageModelHeadConfig = Field( + desc="Configuration for the multi-token-prediction heads.", + hint=FieldHint.architecture, + ) + prediction_heads: int = Field( + default=1, + desc="Prediction heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) - if self.enable_dpo: - loss_defs.append(LossDef(name=LanguageModelLossNames.dpo_loss, formatted_name="dpo loss", count=count)) + def _validate(self) -> None: + super()._validate() + if isinstance(self.prediction_loss_coefficient, list): + Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) + for coeff in self.prediction_loss_coefficient: + Assert.geq(coeff, 0) - if self.distillation_model is not None: - loss_defs.append( - LossDef(name=LanguageModelLossNames.distillation_loss, formatted_name="distillation loss", count=count) - ) - if self.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef( - name=LanguageModelLossNames.distil_lm_loss, formatted_name="distillation lm loss", count=count - ) - ) + @property + def layer_class(self) -> "type[MultiTokenPrediction]": + from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction - for i in range(self.prediction_heads): - loss_defs.append( - LossDef( - name=LanguageModelLossNames.multi_token_prediction_loss(i), - formatted_name=f"language model loss {i}", - count=count, - ) - ) - return loss_defs + return MultiTokenPrediction - def get_block( - self, - distributed_config: DistributedConfig, - embeddings_config: LanguageModelEmbeddingsConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - prediction_distance: int = 0, - ): - return self.layer_class( - self, - distributed_config, - embeddings_config, - hidden_dim=hidden_dim, - lr_scale=combine_lr_scales(lr_scale, self.lr_scale), - peft=peft, - prediction_distance=prediction_distance, - ) - - def get_blocks( - self, - distributed_config: DistributedConfig, - embeddings_config: LanguageModelEmbeddingsConfig, - mtp_block_config: BlockConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - ): - blocks = [] - for i in range(self.prediction_heads): - if i > 0: - blocks.append( - mtp_block_config.get_block( - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - 
peft=peft, - # The last block only returns the model output. - # The previous blocks return a stack of shared_hidden and transformer_output. - return_input=i < self.prediction_heads - 1, - ) - ) - blocks.append( - self.get_block( - distributed_config, - embeddings_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, - prediction_distance=i, - ) - ) - return blocks + @property + def max_prediction_distance(self) -> int: + return self.prediction_heads -# TODO: `BlockSequenceConfig`? (interface not fully compatible) @config_class() -class LanguageModelBaseConfig(BaseModelConfig): +class LanguageModelConfig(ModuleConfig): # TODO: block decoder: BlockSequenceConfig = Field( desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) - embeddings_layer: LanguageModelEmbeddingsConfig = Field() - output_layer: LanguageModelHeadConfig = Field() + embeddings: LanguageModelEmbeddingsConfig = Field() + head: LanguageModelHeadBaseConfig = Field() # TODO: Allow overriding in sub-models? peft: PeftConfig = Field( desc="Configuration for parameter-efficient fine tuning.", hint=FieldHint.architecture, ) + tied_embedding_weight: bool = Field( + default=False, + desc="Tie the output weights (logits) with the vocabulary embedding.", + hint=FieldHint.architecture, + ) sequence_first: bool | None = Field( default=None, desc="Override the default dimension ordering", @@ -349,62 +321,3 @@ class LanguageModelBaseConfig(BaseModelConfig): " Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error.", hint=FieldHint.testing, ) - - def __len__(self) -> int: - return len(self.decoder) + 2 * self.output_layer.prediction_heads - - def __getitem__(self, index: int) -> BlockConfig: - if index <= 0: - Assert.eq(index, 0) - return self.embeddings_layer - elif index <= len(self.decoder): - return self.decoder[index - 1] - else: - # Start at the last decoder layer so all MTP heads are treated similarly. - index - len(self.decoder) - return self.embeddings_layer - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return ( - self.embeddings_layer.get_preprocessors(distributed_config) - + self.decoder.get_preprocessors(distributed_config) - + self.output_layer.get_preprocessors(distributed_config) - ) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return ( - self.embeddings_layer.get_loss_definitions(count) - + self.decoder.get_loss_definitions(count) - + self.output_layer.get_loss_definitions(count) - ) - - def get_blocks(self, distributed_config: DistributedConfig): - hidden_dim = TensorDim("hidden", self.embeddings_layer.hidden_size) - return [ - self.embeddings_layer.get_block( - distributed_config, - hidden_dim=hidden_dim, - lr_scale=None, - peft=self.peft, - ), - *[ - self.decoder[i].get_block( - distributed_config, - hidden_dim, - lr_scale=None, - peft=self.peft, - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. 
- return_input=self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1, - ) - for i in range(len(self.decoder)) - ], - *self.output_layer.get_blocks( - distributed_config, - self.embeddings_layer, - self.decoder[len(self.decoder) - 1], - hidden_dim=hidden_dim, - lr_scale=None, - peft=self.peft, - ), - ] diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d1e13a5b..0ad3225c8 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -14,8 +14,6 @@ from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert -WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" - class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[ConfigType]): """ @@ -28,6 +26,10 @@ class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[Co layer_count: float = 1000.0 _config: ConfigType + # Position embedding preprocessing + _position_ids: torch.Tensor + _tensor_cache_max_sequence_length: int = -1 + def __init__( self, config: ConfigType, @@ -36,17 +38,13 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - return_input: bool = False, ): - if return_input: - raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, - return_input=return_input, ) self._residual_dtype = ( self._distributed_config.optimization_dtype @@ -121,7 +119,7 @@ def forward( if isinstance(input_, TensorMeta): return TensorMeta.from_dims( kwargs[LanguageModelKwargs.hidden_dims], - tensor_name="Embedding output", + tensor_name=f"{self.module_name} output", dtype=self._residual_dtype, ) return self._forward( @@ -131,3 +129,30 @@ def forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? 
(embeddings) return 0 + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if not self._config.position_embeddings.enabled: + return + self._create_position_embeddings(kwargs[LanguageModelKwargs.sequence_length], batch.device) + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: + position_ids = torch.stack( + [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + ).to(batch.device, dtype=torch.int64) + position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] + if kwargs[LanguageModelKwargs.sequence_first]: + position_ids = position_ids.transpose(0, 1) + kwargs[LanguageModelKwargs.position_ids] = position_ids + else: + kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ + sequence_k - sequence_q : sequence_k + ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) + + def _create_position_embeddings(self, sequence_length: int, device: torch.device) -> None: + if sequence_length <= self._tensor_cache_max_sequence_length: + return + self._tensor_cache_max_sequence_length = sequence_length + + Assert.leq(sequence_length, self._config.num_position_embeddings) + self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 0ab64cc9f..4b0e3d102 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,3 +1,5 @@ +import abc +import functools import logging import typing @@ -6,7 +8,7 @@ from torch.distributed import all_reduce from fast_llm.core.ops import gather_op, split_op -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -21,11 +23,10 @@ from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( LanguageModelEmbeddingsConfig, + LanguageModelHeadBaseConfig, LanguageModelHeadConfig, LanguageModelKwargs, - LanguageModelLossNames, ) -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -34,7 +35,13 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]): +class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](Block[ConfigType]): + @abc.abstractmethod + def get_output_weights(self) -> list[torch.Tensor]: + pass + + +class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType]): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). TODO: Cleanup (dynamic type? composition?) 
@@ -51,19 +58,28 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - prediction_distance: int, - return_input: bool = False, + prediction_distance: int = 0, + prediction_heads: int = 1, + loss_coefficient: float = 1.0, ): - if return_input: - raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, - return_input=return_input, ) + if prediction_distance > 0 and ( + self._config.distillation_model is not None or self._config.dpo_reference_model is not None + ): + raise NotImplementedError("Multi-token prediction not supported with distillation or dpo.") + + Assert.in_range(prediction_distance, 0, prediction_heads) + self._prediction_distance = prediction_distance + self._prediction_heads = prediction_heads + self._loss_coefficient = loss_coefficient + self._is_last_head = self._prediction_distance == self._prediction_heads - 1 + self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -71,20 +87,6 @@ def __init__( if self._config.cross_entropy_splits is not None and self._sequence_parallel: assert not self._vocab_parallel - self._loss_coefficient = ( - self._config.prediction_loss_coefficient[prediction_distance] - if self._config.prediction_loss_coefficient - else 1.0 - ) - self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - - # Distance of the target token prediction - # 0: next-token prediction - # >0: multi-token prediction (MTP) - Assert.geq(prediction_distance, 0) - self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 - if not self._config.enable_dpo: self._cross_entropy_impl = self._config.cross_entropy_implementation if self._cross_entropy_impl == CrossEntropyImpl.auto: @@ -104,15 +106,12 @@ def __init__( self._vocab_dim = TensorDim( "vocab", embeddings_config.vocab_size, self._parallel_dim if self._vocab_parallel else None ) - # Only the first head defines the output weights - if self._prediction_distance == 0 and not self._config.tied_weight: - # untie embedding weights - self.output_weights = self._config.output_weight.get_parameter( - (self._vocab_dim, self._hidden_dim), - default_initialization=init_normal_(std=self._hidden_size**-0.5), - lr_scale=self._lr_scale, - peft=self._peft, - ) + self.output_weights = self._config.output_weight.get_parameter( + (self._vocab_dim, self._hidden_dim), + default_initialization=init_normal_(std=self._hidden_size**-0.5), + lr_scale=self._lr_scale, + peft=self._peft, + ) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -121,7 +120,7 @@ def forward( if self._is_last_head: return TensorMeta.from_dims( (scalar_dim,), - tensor_name="Loss", + tensor_name=f"{self.module_name} output", reductions=( (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), ), @@ -190,7 +189,7 @@ def _forward_backward( self._parallel_dim.size if self._sequence_parallel_logits else 1 ) - output_weights = self._get_output_weights(kwargs) + output_weights = self.output_weights loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( ln_output.detach(), targets, output_weights, grad_output, kwargs, losses ) @@ -226,9 +225,7 @@ def _get_targets( if lm_target is not None: # MTP: Shift the labels 
lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) - + 1 - - self._config.prediction_heads + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads ) if LanguageModelKwargs.sequence_q_dim in kwargs: Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) @@ -251,12 +248,8 @@ def _get_targets( targets = None return targets - def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._config.tied_weight: - return kwargs[WORD_EMBEDDINGS_WEIGHT] - if self._prediction_distance > 0: - return kwargs[OUTPUT_WEIGHTS] - return self.output_weights + def get_output_weights(self) -> list[torch.Tensor]: + return [self.output_weights] def _logits_cross_entropy_forward_backward_split( self, @@ -348,7 +341,7 @@ def _logits_cross_entropy_forward_backward( self.training, grad_output, losses, - LanguageModelLossNames.z_loss, + self._z_loss_name, logits_scale_factor=self._config.logits_scale_factor, ) if self._debug.enabled and self._config.cross_entropy_splits is None: @@ -436,14 +429,83 @@ def _logits_cross_entropy_forward_backward( loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) if self.training and losses is not None: if dpo_loss is not None: - losses[LanguageModelLossNames.dpo_loss].append(dpo_loss.detach()) + losses[self._dpo_loss_name].append(dpo_loss.detach()) if self._config.distillation_model is not None and distillation_loss is not None: - losses[LanguageModelLossNames.distillation_loss].append(distillation_loss.detach()) + losses[self._distillation_loss_name].append(distillation_loss.detach()) if self._config.distillation_model is not None and lm_loss is not None: - losses[LanguageModelLossNames.distil_lm_loss].append(lm_loss.detach()) + losses[self._distillation_language_model_loss_name].append(lm_loss.detach()) return loss, output_parallel_linear_backward(grad, context) if self.training else None + @functools.cached_property + def _loss_name(self) -> str: + name = "language_model_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _z_loss_name(self) -> str: + name = "z_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _dpo_loss_name(self) -> str: + name = "dpo_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _distillation_language_model_loss_name(self) -> str: + name = "distillation_language_model_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _distillation_loss_name(self) -> str: + name = "distillation_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] + if self._config.logit_z_loss: + loss_defs.append( + LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + ) + if self._config.enable_dpo: + loss_defs.append( + LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) + ) + + if self._config.distillation_model is not None: + loss_defs.append( + LossDef( + name=self._distillation_loss_name, + 
formatted_name=_format_name(self._distillation_loss_name), + count=count, + ) + ) + if self._config.language_model_loss_factor > 0.0: + loss_defs.append( + LossDef( + name=self._distillation_language_model_loss_name, + formatted_name=_format_name(self._distillation_language_model_loss_name), + count=count, + ) + ) + + return loss_defs + + +def _format_name(name: str) -> str: + return name.replace("_", " ") + def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: tensors = [tensor for tensor in tensors if tensor is not None] diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py new file mode 100644 index 000000000..9a3bef195 --- /dev/null +++ b/fast_llm/layers/language_model/language_model.py @@ -0,0 +1,61 @@ +import logging +import typing + +from fast_llm.config import Configurable +from fast_llm.engine.base_model.base_model import Layer, LayerBase +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.language_model.config import LanguageModelConfig +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding + +logger = logging.getLogger(__name__) + + +class LanguageModel[ConfigType: LanguageModelConfig](Configurable[ConfigType], LayerBase): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + ): + super().__init__(config, distributed_config) + + self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size) + self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer( + distributed_config, + hidden_dim=self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.decoder = self._config.decoder.get_layer( + distributed_config, + self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.head = self._config.head.get_layer( + distributed_config, + self._config.embeddings, + hidden_dim=self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + + def get_layers(self) -> list["Layer"]: + return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + self.embeddings.preprocess(batch, kwargs) + self.decoder.preprocess(batch, kwargs) + self.head.preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
+ return ( + self.embeddings.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.head.get_loss_definitions(count) + ) diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py new file mode 100644 index 000000000..e0eb8175d --- /dev/null +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -0,0 +1,92 @@ +import functools +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, MultiTokenPredictionConfig + + +class MultiTokenPrediction[ConfigType: MultiTokenPredictionConfig](BlockBase[ConfigType]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + self.blocks = torch.nn.ModuleList( + [ + self._config.block.get_layer( + self._distributed_config, + self._hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + # The last block only returns the model output. + # The previous blocks return a stack of shared_hidden and transformer_output. + return_input=index < self._config.prediction_heads - 1, + ) + for index in range(self._config.prediction_heads) + ] + ) + self.heads = torch.nn.ModuleList( + [ + self._config.head.get_layer( + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + prediction_distance=index, + prediction_heads=self._config.prediction_heads, + loss_coefficient=( + 1.0 + if self._config.prediction_loss_coefficient is None + else self._config.prediction_loss_coefficient[index] + ), + ) + for index in range(self._config.prediction_heads) + ] + ) + + @functools.cached_property + def _layers_with_namespace(self) -> list[Layer]: + # Wrap all blocks in a namespace using the unique module name of the first one. + # This needs to be in a property because `module_name` is set after `__init__`. 
+ namespace = self.blocks[0].module_name + return [LayerWithNamespace(sublayer, namespace) for layer in self.blocks for sublayer in layer.get_layers()] + + def get_layers(self) -> list[Layer]: + return [ + module + for block, head in zip(self._layers_with_namespace, self.heads, strict=True) + for module in (block, head) + ] + + def get_output_weights(self) -> list[torch.Tensor]: + return sum((head.get_output_weights() for head in self.heads), []) + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + self._layers_with_namespace[0].preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self.blocks[0].get_loss_definitions(count=count * self._config.prediction_heads) + [ + loss_definition for head in self.heads for loss_definition in head.get_loss_definitions(count=count) + ] diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py deleted file mode 100644 index fc1dac299..000000000 --- a/fast_llm/layers/language_model/preprocessing.py +++ /dev/null @@ -1,107 +0,0 @@ -import logging -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import scalar_dim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs -from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert - -logger = logging.getLogger(__name__) - - -class PositionEmbeddingPreprocessor(Preprocessor): - _rotary_embedding_frequencies: torch.Tensor - _position_ids: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__(self, config: LanguageModelEmbeddingsConfig, distributed_config: DistributedConfig): - self._config = config - assert config.position_embeddings.enabled - self._distributed_config = distributed_config - - def _create_tensors(self, sequence_length: int, device: torch.device) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - Assert.leq(sequence_length, self._config.num_position_embeddings) - self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[LanguageModelKwargs.sequence_length], batch.device) - sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: - position_ids = torch.stack( - [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] - ).to(batch.device, dtype=torch.int64) - position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] - if kwargs[LanguageModelKwargs.sequence_first]: - position_ids = position_ids.transpose(0, 1) - kwargs[LanguageModelKwargs.position_ids] = position_ids - else: - kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ - sequence_k - sequence_q : sequence_k - ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - # Position embeddings will be broadcast. 
- sequence_q_dim = kwargs[LanguageModelKwargs.sequence_q_dim] - kwargs[LanguageModelKwargs.position_ids] = TensorMeta.from_dims( - ( - (sequence_q_dim, scalar_dim) - if kwargs[LanguageModelKwargs.sequence_first] - else (scalar_dim, sequence_q_dim) - ), - tensor_name=LanguageModelKwargs.position_ids, - dtype=torch.int64, - ) - - -class PreferenceSpanPreprocessor(Preprocessor): - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - return - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size - sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels - - if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: - raise ValueError("Expected chosen spans or rejected spans to be found within the batch.") - - chosen_spans = kwargs[LanguageModelKwargs.chosen_spans] - chosen_valid_spans = [] - for spans in chosen_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - chosen_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - - rejected_spans = kwargs[LanguageModelKwargs.rejected_spans] - rejected_valid_spans = [] - for spans in rejected_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - rejected_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index f014012b2..c9fc609b0 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -43,6 +43,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -50,6 +51,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) state_dim = TensorDim("state", self._config.state_size) v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) @@ -128,7 +130,7 @@ def __init__( peft=self._peft, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index e77a4468b..081aabe65 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -43,13 +43,10 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, 
return_bias=return_bias ) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for Mamba" @@ -120,7 +117,7 @@ def __init__( peft=self._peft, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b0657313d..4b0bd4366 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -41,13 +41,10 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) num_heads = div(self._config.d_inner, self._config.state_size) @@ -153,7 +150,7 @@ def __init__( BlockDimNames.sequence_q, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 8fbb99cad..1e57f3b8c 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,12 +4,13 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.conversion.config import ( AprielHybridSSMCheckpointFormat, AutoGPTHuggingfaceCheckpointFormat, @@ -26,7 +27,7 @@ if typing.TYPE_CHECKING: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM - from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel + from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.gpt.trainer import GPTTrainer logger = logging.getLogger(__name__) @@ -80,7 +81,7 @@ def micro_batch_splits(self) -> int: @config_class() -class GPTBaseModelConfig(LanguageModelBaseConfig): +class GPTBaseModelConfig(LanguageModelConfig, BaseModelConfig): _abstract = False # Debug, to get an exact match with megatron init. @@ -88,6 +89,12 @@ class GPTBaseModelConfig(LanguageModelBaseConfig): default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) + @property + def base_model_class(self) -> type["GPTBaseModel"]: + from fast_llm.models.gpt.model import GPTBaseModel + + return GPTBaseModel + @config_class(dynamic_type={FastLLMModelConfig: "gpt"}) class GPTModelConfig(FastLLMModelConfig): @@ -141,41 +148,42 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): def _validate(self) -> None: if self.batch.sequence_length is None: # TODO: Drop this. 
-            self.batch.sequence_length = self.model.base_model.embeddings_layer.num_position_embeddings
+            self.batch.sequence_length = self.model.base_model.embeddings.num_position_embeddings
         if self.model.base_model.use_megatron_initialization:
             set_megatron_distributed_seeds(self.model.distributed)
         super()._validate()
 
-        if self.model.base_model.embeddings_layer.position_embeddings.enabled:
-            Assert.geq(self.model.base_model.embeddings_layer.num_position_embeddings, self.batch.sequence_length)
+        if self.model.base_model.embeddings.position_embeddings.enabled:
+            Assert.geq(self.model.base_model.embeddings.num_position_embeddings, self.batch.sequence_length)
 
-        distillation_model = self.model.base_model.output_layer.distillation_model
-        dpo_reference_model = self.model.base_model.output_layer.dpo_reference_model
-
-        if self.model.base_model.output_layer.enable_dpo:
-            assert dpo_reference_model is not None
-            Assert.none(distillation_model)
+        # TODO: Avoid digging inside the model.
+        head = self.model.base_model.head
+        if isinstance(head, MultiTokenPredictionConfig):
+            prediction_heads = head.prediction_heads
+            head = head.head
         else:
-            Assert.none(dpo_reference_model)
+            prediction_heads = 1
 
-        if distillation_model is None and dpo_reference_model is None:
-            Assert.empty(self.reference_models)
-        else:
-            assert distillation_model is None or dpo_reference_model is None  # currently don't support both
-            expected_names = {name for name in (distillation_model, dpo_reference_model) if name is not None}
-            Assert.eq(self.reference_models.keys(), expected_names)
+        expected_names = {name for name in (head.distillation_model, head.dpo_reference_model) if name is not None}
+        Assert.eq(self.reference_models.keys(), expected_names)
 
         for reference_model in self.reference_models.values():
-            output_layer = reference_model.model.base_model.output_layer
-            Assert.none(output_layer.distillation_model)
-            Assert.none(output_layer.dpo_reference_model)
+            reference_head = reference_model.model.base_model.head
+            if isinstance(reference_head, MultiTokenPredictionConfig):
+                reference_prediction_heads = reference_head.prediction_heads
+                # The config class exposes the inner head as `head` (singular), matching the branch above;
+                # `heads` only exists on the built module, so accessing it here would raise AttributeError.
+                reference_head = reference_head.head
+            else:
+                reference_prediction_heads = 1
+            Assert.geq(reference_prediction_heads, prediction_heads)
+
+            Assert.none(reference_head.distillation_model)
+            Assert.none(reference_head.dpo_reference_model)
             # TODO: Support more LM head features.
- Assert.none(output_layer.cross_entropy_splits) + Assert.none(reference_head.cross_entropy_splits) Assert.eq( - reference_model.model.base_model.embeddings_layer.vocab_parallel, - self.model.base_model.embeddings_layer.vocab_parallel, + reference_model.model.base_model.embeddings.vocab_parallel, + self.model.base_model.embeddings.vocab_parallel, ) - Assert.geq(output_layer.prediction_heads, output_layer.prediction_heads) @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 5b32c481d..4b9849630 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -24,11 +24,11 @@ class AprielDiscreteMamba2Converter: @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { "type": "discrete_mamba_2", "state_size": config["ssm_cfg"]["d_state"], - "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "d_inner": config["ssm_cfg"].get("d_inner") or config["hidden_size"] * config["ssm_cfg"].get("expand", 1), "add_linear_biases": config["ssm_cfg"]["bias"], "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, "n_qk_heads": config["ssm_cfg"]["n_qk_heads"], @@ -117,17 +117,17 @@ def get_converters( class AprielMamba2Converter: @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { "type": "mamba_2", "state_size": config["ssm_cfg"]["d_state"], - "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "d_inner": config["ssm_cfg"].get("d_inner") or config["hidden_size"] * config["ssm_cfg"].get("expand", 1), "add_linear_biases": config["ssm_cfg"]["bias"], "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, - "d_xb": config["ssm_cfg"].get("d_xb") or hidden_size, + "d_xb": config["ssm_cfg"].get("d_xb") or config["hidden_size"], "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, "dt_rank": ( - math.ceil(hidden_size) + math.ceil(config["hidden_size"]) if config["ssm_cfg"].get("dt_rank", "auto") == "auto" else config["ssm_cfg"]["dt_rank"] ), @@ -246,8 +246,8 @@ class AprielBlockConverter: _config_classes = {value: key for key, value in layout_names.items()} @classmethod - def import_config(cls, config: dict, hidden_size: int, layout_name: str = "t") -> dict: - return cls._converter_classes[cls._config_classes[layout_name]].import_config(config, hidden_size) + def import_config(cls, config: dict, layout_name: str = "t") -> dict: + return cls._converter_classes[cls._config_classes[layout_name]].import_config(config) @classmethod def export_config(cls, config) -> dict: @@ -270,18 +270,18 @@ class AprielDecoderConverter(MistralDecoderConverter): block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: layout = config["hybrid_block_layout"] if len(layout) == 1: return { - "block": cls.block_converter_class.import_config(config, hidden_size, layout[0]), + "block": cls.block_converter_class.import_config(config, layout[0]), "num_blocks": config["num_hidden_layers"], } else: return { "type": "pattern", "blocks": { - layout_name: cls.block_converter_class.import_config(config, 
hidden_size, layout_name) + layout_name: cls.block_converter_class.import_config(config, layout_name) for layout_name in set(layout) }, "pattern": layout, @@ -317,14 +317,13 @@ def get_converters( fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, - fast_llm_layer_start: int = 1, ) -> list[WeightConverter]: converters = [] for block_index in range(config.num_blocks): block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] converters += cls.block_converter_class.get_converters( block_config, - f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", drop_on_export, ) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 629a3ceed..786d923f2 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -184,7 +184,7 @@ def import_weight( class LlamaAttentionConverter: @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: try: rope_type = config["rope_scaling"]["rope_type"] except (KeyError, TypeError): @@ -224,7 +224,7 @@ def import_config(cls, config: dict, hidden_size: int) -> dict: "dropout": config["attention_dropout"], } if out["head_size"] is None: - out["head_size"] = div(hidden_size, out["heads"]) + out["head_size"] = div(config["hidden_size"], out["heads"]) return out @@ -360,9 +360,9 @@ class LlamaBlockConverter: hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm" @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { - "mixer": cls.mixer_converter_class.import_config(config, hidden_size), + "mixer": cls.mixer_converter_class.import_config(config), "mlp": cls.mlp_converter_class.import_config(config), "normalization": cls.normalization_converter_class.import_config(config), } @@ -412,9 +412,9 @@ class LlamaDecoderConverter: block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { - "block": cls.block_converter_class.import_config(config, hidden_size), + "block": cls.block_converter_class.import_config(config), "num_blocks": config["num_hidden_layers"], } @@ -434,13 +434,12 @@ def get_converters( fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, - fast_llm_layer_start: int = 1, ) -> list[WeightConverter]: converters = [] for block_index in range(config.num_blocks): converters += cls.block_converter_class.get_converters( config.block, - f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", drop_on_export, ) @@ -477,47 +476,32 @@ class LlamaHeadConverter: @classmethod def import_config(cls, config: dict) -> dict: - return { - "tied_weight": config["tie_word_embeddings"], - "normalization": cls.normalization_converter_class.import_config(config), - } + return {"normalization": cls.normalization_converter_class.import_config(config)} @classmethod def export_config(cls, config: LanguageModelHeadConfig) -> dict: Assert.custom(isinstance, config, LanguageModelHeadConfig) - return safe_merge_dicts( - cls.normalization_converter_class.export_config(config.normalization), - {"tie_word_embeddings": config.tied_weight}, - ) + return 
cls.normalization_converter_class.export_config(config.normalization) @classmethod def get_converters( - cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + cls, + config: LanguageModelHeadConfig, + exported_config: dict, + fast_llm_prefix: str, ) -> list[WeightConverter]: - converters = [] - for prediction_distance in range(config.prediction_heads): - if prediction_distance > 0: - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", - "", - drop_on_export=True, - ) - converters += cls.normalization_converter_class.get_converters( + return [ + *cls.normalization_converter_class.get_converters( config.normalization, - f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + f"{fast_llm_prefix}.final_norm", f"model.norm", - drop_on_export=prediction_distance > 0, - ) - converters.append( + ), get_parameter_converter( - f"{fast_llm_prefix}.{start_index}.output_weights", + f"{fast_llm_prefix}.output_weights", "lm_head.weight", - drop_on_import=config.tied_weight, - ) - ) - - return converters + drop_on_import=exported_config["tie_word_embeddings"], + ), + ] class LlamaBaseModelConverter: @@ -529,41 +513,30 @@ class LlamaBaseModelConverter: @classmethod def import_config(cls, config: dict) -> dict: return { - "embeddings_layer": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config, config["hidden_size"]), - "output_layer": cls.head_converter_class.import_config(config), + "embeddings": cls.embeddings_converter_class.import_config(config), + "decoder": cls.decoder_converter_class.import_config(config), + "head": cls.head_converter_class.import_config(config), + "tied_embedding_weight": config["tie_word_embeddings"], } @classmethod def export_config(cls, config: GPTBaseModelConfig) -> dict: Assert.custom(isinstance, config, GPTBaseModelConfig) return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings_layer), + cls.embeddings_converter_class.export_config(config.embeddings), cls.decoder_converter_class.export_config(config.decoder), - cls.head_converter_class.export_config(config.output_layer), + cls.head_converter_class.export_config(config.head), + {"tie_word_embeddings": config.tied_embedding_weight}, ) @classmethod - def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: + def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: return [ - *cls.embeddings_converter_class.get_converters(config.embeddings_layer, "layers.0", "model"), - *cls.decoder_converter_class.get_converters(config.decoder, "layers", "model.layers"), - *cls.head_converter_class.get_converters( - config.output_layer, config.decoder[len(config.decoder) - 1], "layers", len(config.decoder) + 1 - ), + *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), + *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *cls.head_converter_class.get_converters(config.head, exported_config, "head"), ] - def _create_weight_converters( - self, - ) -> list[WeightConverter]: - base_model_config = self._model.config.base_model - self.embeddings_converter_class.get_converters(base_model_config.embeddings_layer, "layers.0", "model") - converters = self.decoder_converter_class.get_converters(base_model_config.decoder, "layers", "model.layers") - 
self.head_converter_class.get_converters( - base_model_config.decoder, base_model_config.decoder.block, "layers", len(base_model_config.decoder) + 1 - ) - return converters - class LlamaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: GPTModel diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index 4673f5b2c..bfc7d5569 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -16,8 +16,8 @@ class MistralAttentionConverter(LlamaAttentionConverter): @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: - return safe_merge_dicts(super().import_config(config, hidden_size), {"window_size": config["sliding_window"]}) + def import_config(cls, config: dict) -> dict: + return safe_merge_dicts(super().import_config(config), {"window_size": config["sliding_window"]}) @classmethod def export_config(cls, config: AttentionConfig) -> dict: diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 194c263f9..5b83fed69 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -4,63 +4,93 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import WeightConverter -from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.config import LanguageModelHeadConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, + LlamaBlockConverter, + LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, get_parameter_converter, ) -from fast_llm.utils import safe_merge_dicts +from fast_llm.utils import Assert, safe_merge_dicts class MTPLlamaHeadConverter(LlamaHeadConverter): @classmethod def import_config(cls, config: dict) -> dict: - return safe_merge_dicts( - super().import_config(config), - {"prediction_heads": config["prediction_heads"]}, - ) + return { + "type": "multi_token_prediction", + "block": LlamaBlockConverter.import_config(config), + "head": super().import_config(config), + "prediction_heads": config["prediction_heads"], + } @classmethod - def export_config(cls, config: LanguageModelHeadConfig) -> dict: + def export_config(cls, config: MultiTokenPredictionConfig) -> dict: + Assert.custom(isinstance, config, MultiTokenPredictionConfig) return safe_merge_dicts( - super().export_config(config), + super().export_config(config.head), {"prediction_heads": config.prediction_heads}, ) @classmethod def get_converters( - cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + cls, + config: LanguageModelHeadConfig, + exported_config: dict, + fast_llm_prefix: str, ) -> list[WeightConverter]: converters = [] for prediction_distance in range(config.prediction_heads): - if prediction_distance > 0: - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", - f"model.mtp_heads.{prediction_distance - 1}", - ) + converters += cls.block_converter_class.get_converters( + config.block, + 
f"{fast_llm_prefix}.blocks.{prediction_distance}", + ( + f"model.layers.{exported_config["num_hidden_layers"]-1}" + if prediction_distance == 0 + else f"model.mtp_heads.{prediction_distance - 1}" + ), + ) converters += cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + config.head.normalization, + f"{fast_llm_prefix}.heads.{prediction_distance}.final_norm", f"model.mtp_norms.{prediction_distance}", ) converters.append( get_parameter_converter( - f"{fast_llm_prefix}.{start_index}.output_weights", + f"{fast_llm_prefix}.heads.0.output_weights", "lm_head.weight", - drop_on_import=config.tied_weight, + drop_on_import=exported_config["tie_word_embeddings"], ) ) return converters +class MTPLlamaDecoderConverter(LlamaDecoderConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "block": cls.block_converter_class.import_config(config), + "num_blocks": config["num_hidden_layers"] - 1, + } + + @classmethod + def export_config(cls, config: FixedBlockSequenceConfig) -> dict: + # TODO: Support PatternBlockSequenceConfig with compatible configs. + Assert.custom(isinstance, config, FixedBlockSequenceConfig) + return safe_merge_dicts( + cls.block_converter_class.export_config(config.block), + {"num_hidden_layers": config.num_blocks + 1}, + ) + + class MTPLlamaBaseModelConverter(LlamaBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[MTPLlamaDecoderConverter]] = MTPLlamaDecoderConverter head_converter_class: typing.ClassVar[type[MTPLlamaHeadConverter]] = MTPLlamaHeadConverter diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 680d8bfb2..9215e6dc7 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -79,7 +79,7 @@ def inner_forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) - batch = self.fast_llm_base_model.preprocess( + batch = self.fast_llm_base_model.preprocess_batch( GPTBatch(input_ids, sequence_lengths=sequence_lenghts), phase=PhaseType.inference, iteration=iteration ) ((input_, kwargs),) = batch @@ -104,7 +104,7 @@ def inner_forward( self._inference_runner.forward(input_, kwargs, iteration=iteration) # TODO: Make a proper way of returning the model output. 
- if kwargs[TransformerKwargs.sequence_first]: + if kwargs[AttentionKwargs.sequence_first]: logits = kwargs["logits"].transpose(0, 1) else: logits = kwargs["logits"] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b7d751a61..2c1fb0e4a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -4,8 +4,7 @@ import torch from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import BaseModel, Layer -from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner @@ -13,8 +12,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -23,7 +21,7 @@ logger = logging.getLogger(__name__) -class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): +class GPTBaseModel[ConfigType: GPTBaseModelConfig](LanguageModel[ConfigType], BaseModel[ConfigType]): """ A transformer-based language model generalizing the GPT model architecture. """ @@ -35,24 +33,18 @@ def __init__( config: GPTBaseModelConfig, distributed_config: DistributedConfig, ): - self._hidden_dim = TensorDim("hidden", config.embeddings_layer.hidden_size) super().__init__(config, distributed_config) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron( - param, self._config.decoder.block, config.embeddings_layer.hidden_size + param, self._config.decoder.block, config.embeddings.hidden_size ) # Noqa - # `self._reference_models` is not populated at this point, so we pass a mutable dict. - self._preprocessors: list[Preprocessor] = self._config.get_preprocessors(distributed_config) - - def get_layers(self) -> list[Layer]: - return self._config.get_blocks(self._distributed_config) def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: - # TODO: How much of this is generalizable? 
+ # TODO Remove (Move batch splitting elsewhere) # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence if isinstance(batch_meta, GPTBatchConfig): @@ -63,7 +55,7 @@ def preprocess_meta( else: micro_batch_size, sequence_length = batch_meta.shape if phase != PhaseType.inference: - sequence_length -= self._config.output_layer.prediction_heads + sequence_length -= self._config.head.prediction_heads micro_sequence_length = sequence_length truncate_documents = True @@ -142,8 +134,6 @@ def preprocess_meta( kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( hidden_dims[:2], tensor_name="labels", dtype=torch.int64 ) - for preprocessor in self._preprocessors: - preprocessor.preprocess_meta(kwargs) reference_kwargs = {} for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] @@ -161,7 +151,7 @@ def preprocess_meta( return preprocessed_meta - def preprocess( + def preprocess_batch( self, batch: GPTBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, @@ -170,7 +160,7 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: - # TODO: How much of this is generalizable? + # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup if preprocessed_meta is None: @@ -179,7 +169,7 @@ def preprocess( _, common_kwargs = preprocessed_meta[0] sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size sequence_first = common_kwargs[AttentionKwargs.sequence_first] - prediction_heads: int = self._config.output_layer.prediction_heads + max_prediction_distance = self._config.head.max_prediction_distance batch.token_ids = batch.token_ids.to( device=self._distributed.device, @@ -193,7 +183,7 @@ def preprocess( (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta ] - reference_batch = reference_model.fast_llm_model.base_model.preprocess( + reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch( batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration ) @@ -203,19 +193,20 @@ def preprocess( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + token_ids = batch.token_ids if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. - batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() + token_ids = token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size if sequence_first: - tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] + tokens = token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? 
- tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() + tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: @@ -235,10 +226,10 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: - labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] + labels = token_ids[sequence_offset : sequence_k + max_prediction_distance] else: # TODO: Avoid multiple contiguous calls? - labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() + labels = token_ids[:, sequence_offset : sequence_k + max_prediction_distance].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config if batch.loss_masking_spans is not None: @@ -248,12 +239,13 @@ def preprocess( if not spans.numel(): continue valid_spans = spans[ - (spans[:, 0] <= sequence_k + prediction_heads - 1) & (spans[:, 1] >= sequence_offset) + (spans[:, 0] <= sequence_k + max_prediction_distance - 1) + & (spans[:, 1] >= sequence_offset) ] if valid_spans.numel(): # if span is partially within the sequence, truncate parts of spans that are outside of the sequence valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) + valid_spans[:, 1].clamp_(max=sequence_k + max_prediction_distance - 1) valid_spans -= sequence_offset loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: @@ -265,47 +257,55 @@ def preprocess( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels - kwargs.update(reference_logits[i]) + kwargs.update(reference_logits[i]) - for preprocessor in self._preprocessors: - preprocessor.preprocess(tokens, kwargs) - preprocessed.append((tokens, kwargs)) + if batch.chosen_spans is not None: + chosen_valid_spans = [] + for spans in batch.chosen_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset - return preprocessed + chosen_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - @property - def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] + rejected_valid_spans = [] + for spans in batch.rejected_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset - @property - def model_head(self) -> LanguageModelHead: - return self.layers[self.model_head_indices[0]] + rejected_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.rejected_spans] 
= rejected_valid_spans - @property - def model_head_indices(self) -> list[int]: - return sorted([len(self) - 1 - 2 * i for i in range(self._config.output_layer.prediction_heads)]) + self.preprocess(tokens, kwargs) + preprocessed.append((tokens, kwargs)) - def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: - if self._config.output_layer.tied_weight: - return { - WORD_EMBEDDINGS_WEIGHT: ( - self.embedding.word_embeddings_weight, - (0, *self.model_head_indices), - ) - } - elif self._config.output_layer.prediction_heads > 1: - return { - OUTPUT_WEIGHTS: ( - self.model_head.output_weights, - tuple(self.model_head_indices), - ) - } - else: - return {} + return preprocessed + + def get_tied_parameters(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: + # TODO: Integrate to the `LayerBase` interface, move to `LanguageModel`, `MultiTokenPrediction`? + output_weights = self.head.get_output_weights() + if self._config.tied_embedding_weight: + output_weights.insert(0, self.embeddings.word_embeddings_weight) + return {output_weights[0].tensor_name: output_weights} if len(output_weights) > 1 else {} class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): - base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel + # TODO: Can we drop class? + pass class GPTInferenceRunner(InferenceRunner): diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 4dbbfbb1c..54ea13dc4 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -22,13 +22,14 @@ def _get_sampling_parameters( parameters = super()._get_sampling_parameters(parameters, _return_dict=True) parameters.update( { - "vocab_size": self._config.model.base_model.embeddings_layer.vocab_size, + "vocab_size": self._config.model.base_model.embeddings.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - "use_preference_loss_spans": self._config.model.base_model.output_layer.enable_dpo, + # OK since DPO is not supported for MTP. + "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, - "extra_tokens": self._config.model.base_model.output_layer.prediction_heads, + "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 1f9feceb4..bbd69ae8a 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -316,7 +316,9 @@ def new_decorator(*args, **kwargs): return new_decorator -def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple = ()): +def compare_nested( + config_a, config_b, errors: list | None = None, prefix: tuple = (), ignore_missing: tuple[str, ...] = () +): if errors is None: errors = [] # Check for equality of both values and types. 
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index f14f028e1..0de823e2a 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,9 +7,8 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs +from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -21,9 +20,7 @@ def _reverse_kl_loss( loss_mask: torch.Tensor | None, teacher_softmax_temperature: float = 1.0, ): - scaled_target = target / teacher_softmax_temperature - - scaled_target = torch.clamp(target, min=-50, max=50) + scaled_target = torch.clamp(target / teacher_softmax_temperature, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) with torch.enable_grad(): @@ -101,56 +98,61 @@ def _lm_head( @pytest.mark.slow @pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) @pytest.mark.parametrize( - ("config_dict", "distributed_config_dict", "loss_masking"), + ("config_dict", "distributed_config_dict", "loss_masking", "prediction_heads"), ( - ({}, {}, False), - ({}, {"compute_dtype": DataType.bfloat16}, False), - ({"embeddings_layer": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False), - ({"sequence_first": True}, {}, False), - ({"output_layer": {"logit_z_loss": 1e-3}}, {}, False), - ({"output_layer": {"logits_scale_factor": 5.0}}, {}, False), - ({"output_layer": {"tied_weight": False}}, {}, False), - ({"output_layer": {"prediction_heads": 2}}, {}, False), - ({}, {}, True), + ({}, {}, False, 1), + ({}, {"compute_dtype": DataType.bfloat16}, False, 1), + ({"embeddings": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False, 1), + ({"sequence_first": True}, {}, False, 1), + ({"head": {"logit_z_loss": 1e-3}}, {}, False, 1), + ({"head": {"logits_scale_factor": 5.0}}, {}, False, 1), + ({"tied_embedding_weight": True}, {}, False, 1), + ({}, {}, False, 2), + ({}, {}, True, 1), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.cross_entropy, } }, {}, False, + 1, ), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.reverse_kl, } }, {}, False, + 1, ), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + "language_model_loss_factor": 1.0, } }, {}, True, + 1, ), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.reverse_kl, } }, {}, True, + 1, ), ), ) @@ -159,24 +161,37 @@ def test_lm_head( config_dict: dict[str, typing.Any], distributed_config_dict: dict[str, typing.Any], loss_masking: bool, + prediction_heads: int, ): + torch.cuda.manual_seed(0) + torch.manual_seed(0) + head_config = { + "cross_entropy_implementation": 
cross_entropy_impl, + "normalization": {"type": "rms_norm"}, + } config = GPTBaseModelConfig.from_dict( { "decoder": { "num_blocks": 0, }, - "embeddings_layer": { + "embeddings": { "vocab_size": VOCAB_SIZE, "hidden_size": HIDDEN_SIZE, }, - "output_layer": { - "cross_entropy_implementation": cross_entropy_impl, - "normalization": {"type": "rms_norm"}, - }, + "head": ( + head_config + if prediction_heads == 1 + else { + "type": "multi_token_prediction", + "head": head_config, + "prediction_heads": prediction_heads, + } + ), }, config_dict, update_type=UpdateType.update, ) + head_config: LanguageModelHeadConfig = config.head if prediction_heads == 1 else config.head.head model, distributed = get_base_model( GPTModelConfig.from_dict( @@ -188,22 +203,22 @@ def test_lm_head( ) sequence_first = config.sequence_first or ( - config.output_layer.cross_entropy_splits is not None and config.output_layer.cross_entropy_splits > 1 + head_config.cross_entropy_splits is not None and head_config.cross_entropy_splits > 1 ) input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( distributed.config.optimization_dtype.torch - if config.embeddings_layer.full_precision_residual + if config.embeddings.full_precision_residual else distributed.config.compute_dtype.torch ), device=distributed.device, requires_grad=True, ) label_shape = ( - (SEQUENCE_LENGTH + config.output_layer.prediction_heads - 1, BATCH_SIZE) + (SEQUENCE_LENGTH + config.head.max_prediction_distance - 1, BATCH_SIZE) if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.output_layer.prediction_heads - 1) + else (BATCH_SIZE, SEQUENCE_LENGTH + config.head.max_prediction_distance - 1) ) if loss_masking: loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) @@ -213,7 +228,7 @@ def test_lm_head( AttentionKwargs.sequence_first: sequence_first, AttentionKwargs.grad_output: 1.0, } - if config.output_layer.distillation_model is None: + if head_config.distillation_model is None: target = torch.randint( 0, VOCAB_SIZE, @@ -226,41 +241,43 @@ def test_lm_head( kwargs[LanguageModelKwargs.labels] = target else: - assert config.output_layer.prediction_heads == 1 + assert config.head.max_prediction_distance == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), dtype=input_.dtype, device=distributed.device, ) - kwargs[f"{config.output_layer.distillation_model}_logits"] = target + kwargs[f"{head_config.distillation_model}_logits"] = target if loss_mask is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask - if config.output_layer.tied_weight or config.output_layer.prediction_heads > 1: - logit_weight = ( + if config.tied_embedding_weight or config.head.max_prediction_distance > 1: + logit_weight = torch.nn.Parameter( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device - ) - .normal_(config.embeddings_layer.hidden_size**-0.5) - .requires_grad_(True) + ).normal_(config.embeddings.hidden_size**-0.5) ) - kwargs[WORD_EMBEDDINGS_WEIGHT if config.output_layer.tied_weight else OUTPUT_WEIGHTS] = logit_weight else: logit_weight = None - for prediction_distance, layer_index in enumerate(model.model_head_indices): + for prediction_distance, head in enumerate((model.head,) if prediction_heads == 1 else model.head.heads): # Prepare the LM head - head: LanguageModelHead = model[layer_index] Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, 
prediction_distance) - stage = get_stage([head], distributed) + is_duplicate = config.tied_embedding_weight or prediction_distance > 0 + stage = get_stage( + [head], + distributed, + tied_parameter_duplicates=[head.output_weights.tensor_name] if is_duplicate else [], + tied_parameter_duplicate_buffers={head.output_weights.tensor_name: logit_weight} if is_duplicate else {}, + ) # Get reference outputs and grads - if logit_weight is None: - logit_weight = head.output_weights - else: + if is_duplicate: logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan")) logit_weight.param_grad_is_zero = True + else: + logit_weight = head.output_weights ref_input = input_.detach().requires_grad_() ref_rms_weight = head.final_norm.weight.detach().requires_grad_() @@ -276,9 +293,9 @@ def test_lm_head( loss_mask, rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, - logit_scale_factor=config.output_layer.logits_scale_factor, - logit_z_loss=config.output_layer.logit_z_loss, - distillation_loss_implementation=config.output_layer.distillation_loss_implementation, + logit_scale_factor=head_config.logits_scale_factor, + logit_z_loss=head_config.logit_z_loss, + distillation_loss_implementation=head_config.distillation_loss_implementation, ) # Prepare LM head inputs @@ -291,13 +308,18 @@ def test_lm_head( output_grad = torch.randn_like(shared_hidden) loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss" - Assert.eq(head._loss_name, loss_name) loss_keys = {loss_name} if ref_z_loss is not None: - loss_keys.add("z_loss") - if config.output_layer.distillation_model is not None: + loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") + if head_config.distillation_model is not None: loss_keys.add("distillation_loss") - loss_keys.add("distil_lm_loss") + if head_config.language_model_loss_factor > 0: + loss_keys.add("distillation_language_model_loss") + + Assert.eq( + {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, + {loss_key: 1 for loss_key in loss_keys}, + ) losses = {key: [] for key in loss_keys} output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) @@ -305,7 +327,7 @@ def test_lm_head( threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 - ) * config.output_layer.logits_scale_factor + ) * head_config.logits_scale_factor Assert.eq(losses.keys(), loss_keys) Assert.eq(len(losses[loss_name]), 1) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 714abc130..3c3bfb833 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -328,7 +328,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) test_input = torch.randint( 0, - model_ref.config.fast_llm_config.base_model.embeddings_layer.vocab_size, + model_ref.config.fast_llm_config.base_model.embeddings.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda", diff --git a/tests/test_attention.py b/tests/test_attention.py index dceaa8282..a19cba8f0 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -2,13 +2,13 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.attention.attention import Attention from fast_llm.layers.attention.config import AttentionConfig, 
AttentionKwargs -from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert -def test_varlen_preprocessor(): +def test_varlen_preprocessing(): sequence_lengths = [torch.tensor([8, 13, 4, 11], dtype=torch.int32), torch.tensor([11, 16, 9], dtype=torch.int32)] # First micro-sequence: # [0...7,0...3] + [0...10,0] -> [0,8,12,23,24] @@ -28,8 +28,12 @@ def test_varlen_preprocessor(): ] micro_sequence_length = 12 sequence_length = 36 - varlen_preprocessor = FlashAttnVarlenPreprocessor( - AttentionConfig(head_size=64), DistributedConfig(compute_dtype="bfloat16") + attention = Attention( + AttentionConfig(head_size=64), + DistributedConfig(compute_dtype="bfloat16"), + hidden_dim=TensorDim("", 1), + lr_scale=None, + peft=None, ) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { @@ -40,6 +44,6 @@ def test_varlen_preprocessor(): AttentionKwargs.sequence_length: sequence_length, AttentionKwargs.sequence_lengths: sequence_lengths, } - varlen_preprocessor.preprocess(torch.empty(1, device="cpu"), kwargs) + attention.preprocess(torch.empty(1, device="cpu"), kwargs) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_config.py b/tests/test_config.py index 6d2583ba3..326200537 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -74,7 +74,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { - "embeddings_layer": { + "embeddings": { "hidden_size": 1024, # Default }, "decoder": { @@ -92,7 +92,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, # Default }, - "output_layer": {"tied_weight": False}, + "tied_embedding_weight": False, }, "multi_stage": {"zero_stage": 3}, "distributed": {"compute_dtype": "bfloat16"}, @@ -105,7 +105,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config.save_metadata(save_config) base_model_update = { - "embeddings_layer": {"hidden_size": 512, "vocab_size": 1000}, + "embeddings": {"hidden_size": 512, "vocab_size": 1000}, "decoder": { "block": { "mixer": { @@ -127,51 +127,50 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): } ) serialized_config = pretrained_config.model.to_dict() - expected_config = {"type": "gpt", "distributed": DistributedConfig().to_dict()} + expected_config = {"distributed": DistributedConfig().to_dict()} if load_config == ModelConfigType.fast_llm: expected_config["multi_stage"] = {"zero_stage": 3} expected_config["distributed"].update({"seed": 1234, "compute_dtype": "float16"}) if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { - "embeddings_layer": { + "embeddings": { "hidden_size": 512, "vocab_size": 1000, }, "decoder": { - "type": "fixed", "block": { - "type": "decoder", "mixer": { - "type": "attention", - "rotary": {"type": "default"}, "window_size": 32, "head_groups": 1, }, "mlp": { - "type": "mlp", "intermediate_size": 4096, # Implicit default, default value "activation": "silu", # Implicit default, non-default value }, - "normalization": {"type": "rms_norm", "implementation": "triton"}, + "normalization": {"implementation": "triton"}, }, "num_blocks": 12, }, - "output_layer": {"tied_weight": 
False, "normalization": {"type": "layer_norm"}}, - "peft": {"type": "lora", "freeze_others": False}, + "tied_embedding_weight": False, + "peft": {"freeze_others": False}, } else: - base_model_update["decoder"]["type"] = "fixed" - base_model_update["decoder"]["block"]["type"] = "decoder" - base_model_update["decoder"]["block"]["normalization"]["type"] = "layer_norm" - base_model_update["decoder"]["block"]["mixer"]["type"] = "attention" - base_model_update["decoder"]["block"]["mixer"]["rotary"] = {"type": "none"} - base_model_update["decoder"]["block"]["mlp"] = {"type": "mlp"} - base_model_update["output_layer"] = {"normalization": {"type": "layer_norm"}} - base_model_update["peft"] = {"type": "lora", "freeze_others": False} expected_config["base_model"] = base_model_update - check_equal_nested(serialized_config, expected_config) + check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) + + +def _trim_type(config: dict): + # Serialization inserts dynamic types, we ignore them during the comparison. + if "type" in config: + del config["type"] + for key in list(config): + if isinstance(value := config[key], dict): + _trim_type(value) + if not value: + del config[key] + return config def _check_dim(dim: DistributedDim, name: str, rank: int, size: int, global_rank: int): diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index cc5a60a8a..407b47767 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -5,7 +5,6 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.layers.decoder.block import DecoderBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup @@ -42,14 +41,14 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, DecoderBlock) else 0 - for layer in model_ref.base_model.layers + sum(p.numel() for p in layer.unwrap().mlp.parameters()) if layer.module_name.startswith("decoder") else 0 + for layer in model_ref.base_model.get_layers() ] # Make sure each layer has its own buffer so the check below works. Assert.eq( - num_stages := len(model_ref.base_model.layers), - len(model_frozen.base_model.layers), + num_stages := len(model_ref.base_model.get_layers()), + len(model_frozen.base_model.get_layers()), len(model_ref.stages), len(model_frozen.stages), ) diff --git a/tests/utils/compare_tensor_logs.py b/tests/utils/compare_tensor_logs.py index 51ee66d31..1c8ebd76a 100644 --- a/tests/utils/compare_tensor_logs.py +++ b/tests/utils/compare_tensor_logs.py @@ -79,7 +79,9 @@ def _compare_dict_keys(self, dict_ref, dict_test, errors, name): keys_test = set(dict_test) if keys_ref != keys_test: errors.append( - f">>>> {name} do not match. Missing = {keys_ref - keys_test}, extra = {keys_test - keys_ref}." + f">>>> {name} do not match." + f"\n Missing = \n{"\n * ".join(keys_ref - keys_test)}" + f"\n Extra = \n{"\n * ".join(keys_test - keys_ref)}" ) # Avoid set to preserve ordering. 
diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 863be2cae..fac595905 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -110,7 +110,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon DistributedTestingConfig( name="ce4", compare="simple", - config_args=["model.base_model.output_layer.cross_entropy_splits=4"], + config_args=["model.base_model.head.cross_entropy_splits=4"], num_gpus=1, compare_config=_compare_layer_mismatch, ), @@ -228,8 +228,8 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.embeddings_layer.vocab_parallel=False", - "model.base_model.output_layer.cross_entropy_splits=4", + "model.base_model.embeddings.vocab_parallel=False", + "model.base_model.head.cross_entropy_splits=4", ], num_gpus=2, compare_config=_compare_layer_match, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index aa8100126..6b313aa8a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -189,7 +189,7 @@ def _update_and_add_testing_config( }, "model": { "base_model": { - "embeddings_layer": { + "embeddings": { "word_embeddings": init_1, "position_embeddings": {"enabled": True, **init_1}, "hidden_size": 256, @@ -215,7 +215,8 @@ def _update_and_add_testing_config( }, "num_blocks": 2, }, - "output_layer": {"output_weight": init_1}, + "head": {"output_weight": init_1}, + "tied_embedding_weight": True, }, "multi_stage": { "debug_param_init": _LOG_LEVEL, @@ -324,7 +325,7 @@ def _update_and_add_testing_config( updates={ ("model", "base_model", "decoder", "block", "mixer", "head_groups"): 4, ("model", "base_model", "decoder", "block", "mixer", "rotary", "type"): "default", - ("model", "base_model", "embeddings_layer", "position_embeddings", "enabled"): False, + ("model", "base_model", "embeddings", "position_embeddings", "enabled"): False, }, megatron_args=[ "--group-query-attention", @@ -354,8 +355,8 @@ def _update_and_add_testing_config( ("model", "base_model", "decoder", "block", "mlp", "activation"): "silu", ("model", "base_model", "decoder", "block", "mlp", "add_linear_biases"): False, ("model", "base_model", "decoder", "block", "normalization", "type"): "rms_norm", - ("model", "base_model", "output_layer", "normalization", "type"): "rms_norm", - ("model", "base_model", "output_layer", "tied_weight"): False, + ("model", "base_model", "head", "normalization", "type"): "rms_norm", + ("model", "base_model", "tied_embedding_weight"): False, }, megatron_args=[ "--swiglu", @@ -436,12 +437,22 @@ def _update_and_add_testing_config( }, ) + +_llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] + + _update_and_add_testing_config( # Tests multi-token prediction, custom HF model and converter. "llama", "mtp_llama", updates={ - ("model", "base_model", "output_layer", "prediction_heads"): 2, + ("model", "base_model", "head"): { + "type": "multi_token_prediction", + "block": _llama_block, + "head": MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["head"], + "prediction_heads": 2, + }, + ("model", "base_model", "decoder", "num_blocks"): 1, }, # Megatron doesn't support multi-token prediction. 
megatron_args=None, @@ -456,6 +467,8 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, compare_factor=2.0, + # Arg update for cross-entropy splits doesn't work here. + skip_tests=("ce4", "ms"), ) _update_and_add_testing_config( @@ -549,8 +562,6 @@ def _update_and_add_testing_config( compare_factor=2.0, ) -_llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] - _update_and_add_testing_config( # Tests hybrid Mamba, llamba converter. diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 0dc3462eb..098f0240e 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -11,7 +11,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier -from fast_llm.engine.base_model.base_model import BaseModel, Layer +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig @@ -32,22 +32,26 @@ def result_path(): def get_base_model(config: FastLLMModelConfig): # Create a base model (and distributed). # Using a full model config so we have the model type and distributed config in the same argument. - base_model = config.get_model_class().base_model_class(config.base_model, config.distributed) + base_model = config.get_base_model_config_class().get_base_model(config.base_model, config.distributed) base_model.setup(distributed := Distributed(config.distributed)) return base_model, distributed -def get_stage(base_model: BaseModel | list[Layer], distributed: Distributed): +def get_stage( + layers: list[Layer], + distributed: Distributed, + tied_parameter_duplicates: typing.Iterable[str] = (), + tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, +): # Create a fast-llm stage which allocates and initializes meta tensors correctly. stage = Stage( config=StageConfig(), - base_model=base_model, + layers=layers, distributed_config=distributed.config, - begin=0, - end=1, index=0, + tied_parameter_duplicates=tied_parameter_duplicates, ) - stage.setup(distributed=distributed) + stage.setup(distributed=distributed, tied_parameter_duplicate_buffers=tied_parameter_duplicate_buffers) stage.initialize_weights() stage.restore_parameters() stage.reset_gradients() diff --git a/tools/generate_config_yaml_for_sharded_dst.py b/tools/generate_config_yaml_for_sharded_dst.py deleted file mode 100644 index c0b4fa24d..000000000 --- a/tools/generate_config_yaml_for_sharded_dst.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse -import pathlib - -import yaml - -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator, GPTMemmapDatasetPreparatorConfig - -""" -This script is intended to be used only for creation of fast_llm_config.yaml files for sharded datasets encoded with older version of the prepare command. -""" - - -def read_dataset_shard_config(shard_path): - """ - Read a dataset shard from the given path. 
- - Args: - shard_path: Path to the shard prefix (without .idx or .bin extension) - - Returns: - A GPTMemmapDataset instance - """ - # Convert to pathlib.Path if it's a string - path = pathlib.Path(shard_path) if isinstance(shard_path, str) else shard_path - - # Create a GPTMemmapDataset instance - # The name parameter is just for identification - dataset = GPTMemmapDataset(name=path.name, prefix=path) - - # Print basic information about the dataset - print(f"Dataset: {dataset.name}") - print(f"Number of documents: {dataset._num_documents}") - print(f"Number of tokens: {dataset.num_tokens}") - - return GPTMemmapDatasetConfig.from_dict( - { - "type": "memmap", - "path": path.name.replace(".bin", ""), - "num_documents": dataset._num_documents, - "num_tokens": dataset.num_tokens, - } - ) - - -def get_preparator(prepare_config: GPTMemmapDatasetPreparatorConfig) -> GPTMemmapDatasetPreparator: - config = GPTMemmapDatasetPreparatorConfig.from_dict( - { - "output_path": prepare_config.output_path, - "dataset": {"path": prepare_config.dataset.path}, - "tokenizer": {"path": prepare_config.tokenizer.path}, - }, - {}, - ) - return config.get_dataset_preparator_class()(config=config) - - -def main(config_dict): - prepare_config = GPTMemmapDatasetPreparatorConfig.from_dict(config_dict) - destination = pathlib.Path(prepare_config.output_path) - - shards = list(destination.glob("shard_*.bin")) - dataset_configs = [read_dataset_shard_config(shard) for shard in shards] - - preparator = get_preparator(prepare_config) - preparator.generate_config_yaml_for_sharded_dst(dataset_configs) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Generate config YAML for sharded datasets") - parser.add_argument( - "--prepare_config", - type=str, - required=False, - default=None, # "/home/toolkit/dev/Fast-LLM/.vscode/prepare_dst.yaml", - help="Path to the prepare config YAML file", - ) - parser.add_argument( - "--dataset_path", - type=str, - required=False, - default="/mnt/datasets/tokenized/Mistral-Nemo-Base-2407/FineWeb2/deu_Latn/", - help="Path to the dataset path", - ) - args = parser.parse_args() - - if args.prepare_config: - with open(args.prepare_config) as f: - config_dict = yaml.safe_load(f) - else: - assert args.dataset_path is not None, "Please provide a prepare config YAML file or dataset path" - config_dict = { - "output_path": args.dataset_path, - "dataset": {"path": "unknown"}, - "tokenizer": {"path": "no_tokenizer"}, - } - main(config_dict)