From a3dc89d979e41ede09bee5aad8d9d0f8fa8d8445 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 16:18:33 -0400 Subject: [PATCH 01/16] stuff --- fast_llm/engine/config_utils/data_type.py | 2 +- fast_llm/engine/config_utils/run.py | 2 +- fast_llm/layers/attention/attention.py | 8 +++-- fast_llm/layers/block/block.py | 13 ------- fast_llm/layers/block/config.py | 2 -- fast_llm/layers/decoder/block.py | 36 ++++++++++++++++--- fast_llm/layers/decoder/config.py | 23 ++++++++++-- .../layers/decoder/mlp/mixture_of_experts.py | 6 ++-- fast_llm/layers/decoder/mlp/mlp.py | 4 ++- fast_llm/layers/language_model/embedding.py | 4 --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++- fast_llm/layers/ssm/mamba.py | 9 ++--- fast_llm/layers/ssm/mamba2.py | 9 ++--- 13 files changed, 76 insertions(+), 46 deletions(-) diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index 0929b7cb1..f4a2cfd6c 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -9,7 +9,7 @@ from triton import language as tl -class DataType(str, enum.Enum): +class DataType(enum.StrEnum): """ An enum to represent data types independently of third party libraries, so we can swap them more easily and allow for lazy imports. diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1737f4308..1849a2316 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -136,7 +136,7 @@ def __init__( self._distributed.config.data_rank == 0 and self._distributed.config.tensor_rank == 0 ) config_dict = config.to_dict() - config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.debug) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 9a940f4cb..2d4f049f0 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -64,6 +64,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -71,6 +72,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) @@ -273,7 +275,7 @@ def _query_key_value_backward( input_grad.add_(self.key_value.backward(key_value_grad, context.pop("key_value"))) return input_grad - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], @@ -340,7 +342,7 @@ def forward( max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), dropout_p=self._config.dropout if self.training else 0.0, window_size=window_size, - causal=True, + causal=self._config.causal, softmax_scale=self._softmax_scale, ).view(*out_dims) else: @@ -350,7 +352,7 @@ def forward( value, window_size=window_size, dropout_p=self._config.dropout if self.training else 0.0, - causal=True, + causal=self._config.causal, softmax_scale=self._softmax_scale, ) input_ = input_.flatten(-2) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 773cce87e..0f975c9c5 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -123,16 +123,3 @@ class Block[ConfigType: Config](BaseBlock[ConfigType], Layer): """ Base class for actual 
blocks, i.e., base blocks that are also `Layers`. """ - - def __init__( - self, - config: ConfigType, - distributed_config: DistributedConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - return_input: bool = False, - ): - super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) - self._return_input = return_input diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index df5bd8181..47c1ab9b7 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -87,7 +87,6 @@ def get_block( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - return_input: bool = False, ) -> "Block": return self.layer_class( self, @@ -95,7 +94,6 @@ def get_block( hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, - return_input=return_input, ) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index ba4c370c2..ce9893ea4 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.block import BaseBlock, Block +from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -19,18 +19,44 @@ logger = logging.getLogger(__name__) -class BlockWithBias[ConfigType: Config](BaseBlock[ConfigType]): +class BlockWithBias[ConfigType: Config](Block[ConfigType]): """ Base class for mixer and MLP modules. """ - @abc.abstractmethod + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) + self._return_bias = return_bias + def forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor: + output, bias = self._forward(input_, kwargs, losses, metrics) + if self._return_bias: + return output, bias + else: + return output if bias is None else output + bias + + @abc.abstractmethod + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: pass @@ -58,7 +84,7 @@ def __init__( peft=peft, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
- self._return_input: bool = return_input + self._return_input = return_input # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft) @@ -70,6 +96,7 @@ def __init__( self._hidden_dim, self._lr_scale, peft=peft, + return_bias=True, ) self.mlp = self._config.mlp.get_layer( @@ -77,6 +104,7 @@ def __init__( self._hidden_dim, self._lr_scale, peft=peft, + return_bias=True, ) def setup(self, distributed: Distributed) -> None: diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 5f8131b5c..724b8d172 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BaseBlockConfig, BlockConfig +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert @@ -15,7 +15,7 @@ @config_class() -class BlockWithBiasConfig(BaseBlockConfig): +class BlockWithBiasConfig(BlockConfig): """ A common interface for various blocks and block layers. """ @@ -30,6 +30,7 @@ def get_layer( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = False, ) -> "BlockWithBias": return self.layer_class( self, @@ -37,6 +38,7 @@ def get_layer( hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, + return_bias=return_bias, ) @@ -94,6 +96,23 @@ def layer_class(self) -> "type[DecoderBlock]": return DecoderBlock + def get_block( + self, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_input: bool = False, + ) -> "DecoderBlock": + return self.layer_class( + self, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + return_input=return_input, + ) + def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 089fa2dc7..d4cb46dbf 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -46,6 +46,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): Assert.gt(config.experts, 1) # TODO: Implement? 
@@ -56,6 +57,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self.router = self._config.router.get_layer( self._hidden_dim, @@ -83,9 +85,9 @@ def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 9dd17d698..aaea94adb 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -28,6 +28,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -35,6 +36,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() @@ -102,7 +104,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _config: MLPConfig - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d1e13a5b..362ffaa22 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -36,17 +36,13 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - return_input: bool = False, ): - if return_input: - raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, - return_input=return_input, ) self._residual_dtype = ( self._distributed_config.optimization_dtype diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index f014012b2..c9fc609b0 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -43,6 +43,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( config, @@ -50,6 +51,7 @@ def __init__( hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, + return_bias=return_bias, ) state_dim = TensorDim("state", self._config.state_size) v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) @@ -128,7 +130,7 @@ def __init__( peft=self._peft, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index e77a4468b..081aabe65 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -43,13 +43,10 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not 
supported for Mamba" @@ -120,7 +117,7 @@ def __init__( peft=self._peft, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b0657313d..4b0bd4366 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -41,13 +41,10 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_bias: bool = True, ): super().__init__( - config, - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, + config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) num_heads = div(self._config.d_inner, self._config.state_size) @@ -153,7 +150,7 @@ def __init__( BlockDimNames.sequence_q, ) - def forward( + def _forward( self, input_: torch.Tensor, kwargs: dict[str, typing.Any], From 414f87edb881247b8351a1d09e41fe883de93247 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 16:19:07 -0400 Subject: [PATCH 02/16] stuff --- fast_llm/layers/language_model/head.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index ade1144d2..4b080b360 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -52,17 +52,13 @@ def __init__( lr_scale: float | None, peft: PeftConfig | None, prediction_distance: int, - return_input: bool = False, ): - if return_input: - raise NotImplementedError() super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, - return_input=return_input, ) self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) From 4a21360ed07277ba90e5bdbefe6cf3a43a83a9fd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 26 Sep 2025 16:25:34 -0400 Subject: [PATCH 03/16] stuff --- fast_llm/layers/language_model/config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index f59b4cffd..1af6bdc38 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -395,7 +395,12 @@ def get_blocks(self, distributed_config: DistributedConfig): peft=self.peft, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - return_input=self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1, + # TODO: Not all blocks support this argument. 
+ **( + {"return_input": True} + if self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1 + else {} + ), ) for i in range(len(self.decoder)) ], From 1ba88ca48a266582d46b68dfe162ed33b1929ba3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 30 Sep 2025 22:54:17 -0400 Subject: [PATCH 04/16] Model interface --- fast_llm/core/ops.py | 8 +- fast_llm/engine/base_model/base_model.py | 108 ++++--- fast_llm/engine/base_model/config.py | 34 +- fast_llm/engine/evaluation/evaluator.py | 2 +- fast_llm/engine/multi_stage/config.py | 5 - fast_llm/engine/multi_stage/multi_stage.py | 15 +- fast_llm/engine/multi_stage/stage.py | 16 +- fast_llm/engine/multi_stage/stage_base.py | 13 +- fast_llm/engine/schedule/runner.py | 2 +- fast_llm/engine/training/trainer.py | 12 +- fast_llm/layers/attention/attention.py | 117 +++++++ fast_llm/layers/attention/config.py | 23 +- fast_llm/layers/attention/preprocessing.py | 153 --------- fast_llm/layers/attention/rotary/config.py | 4 +- fast_llm/layers/attention/rotary/rotary.py | 35 +- fast_llm/layers/block/block.py | 21 +- fast_llm/layers/block/config.py | 88 ++--- fast_llm/layers/block/sequence.py | 108 +++++++ .../layers/common/normalization/config.py | 4 +- fast_llm/layers/decoder/block.py | 13 +- fast_llm/layers/decoder/config.py | 9 +- fast_llm/layers/decoder/mlp/config.py | 21 -- .../layers/decoder/mlp/mixture_of_experts.py | 22 +- fast_llm/layers/language_model/config.py | 302 ++++++------------ fast_llm/layers/language_model/embedding.py | 27 ++ fast_llm/layers/language_model/head.py | 112 +++++-- .../language_model/multi_token_prediction.py | 77 +++++ .../layers/language_model/preprocessing.py | 107 ------- fast_llm/models/gpt/config.py | 13 +- fast_llm/models/gpt/model.py | 262 ++++++++++++--- tests/utils/utils.py | 2 +- tools/generate_config_yaml_for_sharded_dst.py | 98 ------ 32 files changed, 934 insertions(+), 899 deletions(-) delete mode 100644 fast_llm/layers/attention/preprocessing.py create mode 100644 fast_llm/layers/language_model/multi_token_prediction.py delete mode 100644 fast_llm/layers/language_model/preprocessing.py delete mode 100644 tools/generate_config_yaml_for_sharded_dst.py diff --git a/fast_llm/core/ops.py b/fast_llm/core/ops.py index a7492daa5..bb61aadd0 100644 --- a/fast_llm/core/ops.py +++ b/fast_llm/core/ops.py @@ -26,7 +26,7 @@ def reduce_op( return (input_, handle) if async_op else input_ -def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: +def split_op(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: """Split the tensor along its last dimension and keep the corresponding slice.""" if group: @@ -139,11 +139,11 @@ class _Split(torch.autograd.Function): """Split the input and keep only the corresponding chuck to the rank.""" @staticmethod - def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa + def symbolic(graph, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa return split_op(input_, group, dim) @staticmethod - def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: # noqa + def forward(ctx, input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: # noqa ctx.group = group ctx.dim = dim return split_op(input_, group, dim) @@ -209,7 +209,7 @@ def reduce_backward(input_: torch.Tensor, group: ProcessGroup | None) -> torch.T @torch._dynamo.disable # noqa -def split(input_: torch.Tensor, 
group: ProcessGroup | None, dim: int) -> list[torch.Tensor]: +def split(input_: torch.Tensor, group: ProcessGroup | None, dim: int) -> torch.Tensor: return _Split.apply(input_, group, dim) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 0a3f8d1ce..ce7002c54 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -1,23 +1,19 @@ import abc import typing -import torch import torch.nn from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import BaseModelConfig, ResourceUsageConfig +from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, ResourceUsageConfig from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta -from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.engine.inference.runner import InferenceRunner -class Module(torch.nn.Module, abc.ABC): - """ """ - +class LayerBase(torch.nn.Module, abc.ABC): _is_setup: bool = False _distributed: Distributed @@ -27,57 +23,87 @@ def __init__(self, distributed_config: DistributedConfig): def setup(self, distributed: Distributed) -> None: assert not self._is_setup + for layer in self.get_layers(): + if layer is not self: + layer.setup(distributed) distributed.check_config(self._distributed_config) self._distributed = distributed self._is_setup = True + @abc.abstractmethod + def get_layers(self) -> list["Layer"]: + """ + The list of layers as meant to be seen by the Fast-LLM engine. + May differ from the module configuration seen by pytorch. + """ -class Layer(Module): - # Weight used to determine the stage size + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + out = 0 + for layer in self.get_layers(): + if layer is self: + raise NotImplementedError() + out += layer.get_compute_usage(input_, kwargs, config) + return out + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + losses = [] + for layer in self.get_layers(): + if layer is not self: + losses += layer.get_loss_definitions(count) + return losses + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + for layer in self.get_layers(): + if layer is not self: + layer.preprocess(batch, kwargs) + + +class Layer(LayerBase): + # Weight used to determine the stage size. layer_count: float = 1.0 + def get_layers(self) -> list["Layer"]: + # Return a breakdown of the layer into atomic ones, + # i.e. the list of layers from as seen from the Fast-LLM model. + return [self] + @abc.abstractmethod def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: pass - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - raise NotImplementedError() +class LayerWithNamespace(Layer): + """ + A layer with its own namespace for preprocessing (kwargs), + so that it doesn't inadvertently interact with other layers. + TODO: Consider namespace for losses and metrics? 
+ """ -class Sequential(Layer): - def __init__(self, distributed_config: DistributedConfig): - super().__init__(distributed_config) - self.layers = torch.nn.ModuleList(self.get_layers()) - - def __getitem__(self, item): - return self.layers[item] + def __init__(self, layer: Layer, namespace: str): + super().__init__(layer._distributed_config) + self._layer = layer + self._namespace = namespace + self.layer_count = self._layer.layer_count + self.get_compute_usage = self._layer.get_compute_usage - def __iter__(self): - return iter(self.layers) - - def __len__(self): - return len(self.layers) + def setup(self, distributed: Distributed) -> None: + self._layer.setup(distributed) + super().setup(distributed) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: - for layer in self.layers: - input_ = layer(input_, kwargs, losses, metrics) - return input_ + return self._layer.forward(input_, kwargs[self._namespace], losses, metrics) - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + assert self._namespace not in kwargs + kwargs[self._namespace] = kwargs.copy() + return self._layer.preprocess(batch, kwargs[self._namespace]) - def setup(self, distributed: Distributed) -> None: - super().setup(distributed) - for layer in self.layers: - layer.setup(distributed) - -class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential): +class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase): def __init__( self, @@ -85,23 +111,14 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) - for key, value in self.named_modules(): - value.module_name = key - for key, value in self.named_parameters(): - Assert.custom(isinstance, value, ParameterMeta) - # Rename to the parameter full name - value.tensor_name = key # Reference models # TODO: Add basic handling (preprocessor) in this class. self._reference_models: dict[str, "InferenceRunner"] = {} - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass - @abc.abstractmethod def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]: + # TODO ====== Remove (Move batch splitting elsewhere) ====== pass @abc.abstractmethod @@ -114,9 +131,12 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: + # TODO ====== Move batch splitting elsewhere, align interface with LayerBase ====== pass def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: + # TODO ====== Tied weights ====== + # Return tuples of independently defined metas to tie together. # For each tied weight, return the weight and the tuple of layers sharing it. # The weight should be defined in the first layer in the set. # Warning: This may return buffers instead of metas after stage setup. 
diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 78fafea34..f1eef47b9 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -4,14 +4,15 @@ from fast_llm.config import MISSING, Config, Field, FieldHint, FieldVerboseLevel, config_class from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.utils import compare_nested, log +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.utils import Assert, compare_nested, log if typing.TYPE_CHECKING: - import torch + from fast_llm.engine.base_model.base_model import BaseModel @config_class() -class BaseModelConfig(Config): +class ModuleConfig(Config): """ Abstract config class for a base model. # TODO: Find better name? @@ -43,7 +44,7 @@ def _get_architecture(self) -> dict[str, typing.Any]: return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: - if isinstance(value, BaseModelConfig): + if isinstance(value, ModuleConfig): # TODO: Make sure all nested configs have an architecture type hint? return value._get_architecture() elif isinstance(value, Config): @@ -57,12 +58,29 @@ def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: return self._serialize_value(value) -class Preprocessor(abc.ABC): - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - pass +@config_class() +class BaseModelConfig(ModuleConfig): + """ + Abstract config class for a base model. + """ + + def get_base_model(self, distributed_config: DistributedConfig) -> "BaseModel": + from fast_llm.tensor import ParameterMeta + + model = self.base_model_class(self, distributed_config) + # Storing the global name of each module and tensor. + # Done here because it needs to run right after `model.__init__()` + for key, value in model.named_modules(): + value.module_name = key + for key, value in model.named_parameters(): + Assert.custom(isinstance, value, ParameterMeta) + # Rename to the parameter full name + value.tensor_name = key + return model + @property @abc.abstractmethod - def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + def base_model_class(self) -> type["BaseModel"]: pass diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index d5202a90f..e055595bd 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -116,7 +116,7 @@ def setup( phase=PhaseType.validation, ) - self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() + self._loss_defs = self._multi_stage.base_model.get_loss_definitions() self._evaluation_iterator = None self._is_setup = True diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index aa18f5052..27c0e2b7b 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -137,11 +137,6 @@ class StageConfig(Config): desc="Check for tensor-parallel desyncs and log an error if a desync is found. 
High overhead", hint=FieldHint.logging, ) - compile_all: bool = Field( - default=False, - desc="Compile the whole model using torch.compile.", - hint=FieldHint.expert, - ) @config_class() diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index e48fdb88b..77dc4e7dd 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -25,7 +25,6 @@ class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): - base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel _is_setup: bool = False _flat_shard: torch.Tensor _shards: dict[str, torch.Tensor] @@ -46,7 +45,8 @@ def __init__( stage_filter: set | None = None, ): super().__init__(config) - self._base_model = self.base_model_class(self._config.base_model, self._config.distributed) + self._base_model = self._config.base_model.get_base_model(self._config.distributed) + self._layers = self._base_model.get_layers() self._training = None self._verbose = verbose self._stage_filter = stage_filter @@ -67,10 +67,8 @@ def __init__( self._stages = [ Stage( config=self._config.multi_stage, - base_model=self._base_model, + layers=self._layers[stage_splits[i] : stage_splits[i + 1]], distributed_config=self._config.distributed, - begin=stage_splits[i], - end=stage_splits[i + 1], index=i, ) for i in (range(self._num_stages)) @@ -510,12 +508,9 @@ def _split_into_stages(self) -> list[int]: # Create stages (greedy split, could do better). stage_splits = [0] layer_counter, last_counter = 0, 0 - for i, layer in enumerate(self._base_model): + for i, layer in enumerate(self._layers): layer_counter += layer.layer_count # noqa - if ( - layer_counter >= last_counter + self._config.multi_stage.layers_per_stage - or i == len(self._base_model) - 1 - ): + if layer_counter >= last_counter + self._config.multi_stage.layers_per_stage or i == len(self._layers) - 1: stage_splits.append(i + 1) last_counter = layer_counter return stage_splits diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 7829c243b..bb3133256 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -139,7 +139,9 @@ def forward( else: # TODO: Handle variable shape. 
output_global = output - kwargs["hidden_states"][self._layer_range[i]] = { + + # TODO ====== Use ====== + kwargs["hidden_states"][self._layers[i].module_name] = { "layer_type": type(layer).__name__, "tensor": output_global, } @@ -223,9 +225,9 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] and self._distributed.tensor_group is not None and not self._meta_outputs[i].is_tensor_parallel ): - check_parallel_match(output, self._distributed.tensor_group, f"layer {self._layer_range[i]} fw") + check_parallel_match(output, self._distributed.tensor_group, f"layer {self._layers[i].module_name} fw") if self._config.debug_layer_outputs: - name = f"layer {self._layer_range[i]} fw" + name = f"layer {self._layers[i].module_name} fw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: @@ -242,7 +244,7 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] meta=self._meta_outputs[i], ) if self._config.debug_activation_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"layer {self._layer_range[i]} fw", str)) + log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"layer {self._layers[i].module_name} fw", str)) def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any], i: int) -> None: if not input_.requires_grad: @@ -254,11 +256,11 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any ): input_.register_hook( lambda grad: check_parallel_match( - grad, self._distributed.tensor_group, f"layer {self._layer_range[i]} bw" + grad, self._distributed.tensor_group, f"layer {self._layers[i].module_name} bw" ) ) if self._config.debug_layer_gradients: - name = f"layer {self._layer_range[i]} bw" + name = f"layer {self._layers[i].module_name} bw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: @@ -276,6 +278,6 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any if self._config.debug_activation_memory: input_.register_hook( lambda grad: log_pipeline_parallel_main_rank( - lambda: log_memory_usage(f"layer {self._layer_range[i]} bw", str) + lambda: log_memory_usage(f"layer {self._layers[i].module_name} bw", str) ) ) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index ded24e538..4778780ee 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -6,7 +6,7 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import check_parallel_match -from fast_llm.engine.base_model.base_model import BaseModel, Layer +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -28,24 +28,17 @@ def __init__( self, *, config: StageConfig, - base_model: BaseModel | list[Layer], + layers: list[Layer], distributed_config: DistributedConfig, - begin: int, - end: int, index: int, ): super().__init__(config) self._distributed_config = distributed_config.validate() - Assert.in_range(begin, 0, end) - Assert.leq(end, len(base_model)) - self._fsdp_rank = self._distributed_config.data_rank self._fsdp_size = self._distributed_config.data_parallel self._is_setup 
= False self._index = index - - self._layers = [torch.compile(layer) if self._config.compile_all else layer for layer in base_model[begin:end]] - self._layer_range = list(range(begin, end)) + self._layers = layers parameter_metas, frozen_metas = self._get_parameter_metas() self._parameter_metas = parameter_metas + frozen_metas diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index dbdd035a4..58449f207 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -95,7 +95,7 @@ def __init__( self._num_stages = len(self._stages) self._loss_definitions = { loss_definition.name: loss_definition - for loss_definition in self._multi_stage.base_model.config.get_loss_definitions() + for loss_definition in self._multi_stage.base_model.get_loss_definitions() } def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a752bec28..aa4f2d570 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -149,7 +149,7 @@ def __init__(self, config: TrainerConfig): multi_stage=self._multi_stage, distributed_config=self._config.model.distributed, ) - self._loss_defs = self._multi_stage.base_model.config.get_loss_definitions() + self._loss_definitions = self._multi_stage.base_model.get_loss_definitions() if not self._is_evaluation_only: steps_per_split = { @@ -320,7 +320,7 @@ def _run_training(self) -> None: phase=PhaseType.test, num_iters=self._config.training.test_iters, ) - formatted_metrics = format_metrics(metrics[metrics_key], self._loss_defs, PhaseType.test) + formatted_metrics = format_metrics(metrics[metrics_key], self._loss_definitions, PhaseType.test) log_main_rank(formatted_metrics) self._wandb.alert("Testing results", formatted_metrics, "WARN") # TODO: This may erase some metrics. 
@@ -331,7 +331,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: advanced_iters = 0 skipped_iters = 0 nan_iters = 0 - total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} + total_losses = {loss_def.name: 0.0 for loss_def in self._loss_definitions} # Profiling profiler = self._config.profiling.get_profiler( @@ -435,7 +435,9 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: **get_and_reset_memory_usage_mib(), } - formatted_metrics = format_metrics(metrics[metrics_key], self._loss_defs, PhaseType.training) + formatted_metrics = format_metrics( + metrics[metrics_key], self._loss_definitions, PhaseType.training + ) logger.info(formatted_metrics) if self._config.training.wandb.alert.enabled(self._completed_steps): self._wandb.alert("Training results", formatted_metrics, "INFO") @@ -443,7 +445,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: advanced_iters = 0 skipped_iters = 0 nan_iters = 0 - total_losses = {loss_def.name: 0.0 for loss_def in self._loss_defs} + total_losses = {loss_def.name: 0.0 for loss_def in self._loss_definitions} self._run.save_logged_tensors(f"train_{self._completed_steps}") diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 2d4f049f0..167184193 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -56,6 +56,11 @@ class Attention[ConfigType: AttentionConfig](BlockWithBias[ConfigType]): _config: ConfigType + # Preprocessing + _backup_attention_mask: torch.Tensor + _backup_attention_mask_value: torch.Tensor + _backup_attention_tensor_cache_max_sequence_length: int = -1 + def __init__( self, config: ConfigType, @@ -431,3 +436,115 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c self.dense.get_compute_usage(dense_input, config), ) ) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._rotary.preprocess(batch, kwargs) + if not self._use_flash_attention: + self._preprocess_for_backup_attention(batch, kwargs) + elif AttentionKwargs.sequence_lengths in kwargs: + self._preprocess_for_varlen(batch, kwargs) + + def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if ( + sequence_length := kwargs[AttentionKwargs.sequence_length] + ) > self._backup_attention_tensor_cache_max_sequence_length: + # Create tensor cache. 
+ self._backup_attention_tensor_cache_max_sequence_length = sequence_length + + self._backup_attention_mask = torch.ones( + (sequence_length, sequence_length), + dtype=torch.bool, + device=batch.device, + ).tril_() + + if self._config.window_size is not None: + self._backup_attention_mask.triu_(-self._config.window_size + 1) + self._backup_attention_mask_value = torch.full( + [], + torch.finfo(self._distributed_config.compute_dtype.torch).min, + dtype=self._distributed_config.compute_dtype.torch, + device=batch.device, + ) + + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + kwargs[AttentionKwargs.attention_mask] = self._backup_attention_mask[ + None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k + ] + if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: + seq_ids = torch.stack( + [ + torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) + for sample_lens in sequence_lengths + ] + ) + document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) + kwargs[AttentionKwargs.attention_mask] = ( + kwargs[AttentionKwargs.attention_mask] + & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] + ) + kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value + + def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + """ + Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 + cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. + Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. + If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally + also contain previous tokens from the first document in micro-sequence. + We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. + """ + if AttentionKwargs.sequence_lengths not in kwargs: + return + sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + if sequence_q < kwargs[AttentionKwargs.sequence_length]: + cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] + # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents + # in the microsequence. We need to consider all keys computed so far from the first sample. 
We also store the offsets + # of the first documents so that we can index into their kv pairs + start_seq_idx = [ + torch.argmax((cu_seqlens >= sequence_k - sequence_q).to(torch.uint8), dim=0) for cu_seqlens in cumsums + ] + end_seq_idx = [torch.argmax((cu_seqlens >= sequence_k).to(torch.uint8), dim=0) for cu_seqlens in cumsums] + seqlens_q = [] + seqlens_k = [] + for idx, sample_seqlens in enumerate(sequence_lengths): + start_idx = start_seq_idx[idx] + end_idx = end_seq_idx[idx] + seqlens_q.extend([0] * start_idx) + n_attention_tokens = sample_seqlens[end_idx] - (cumsums[idx][end_idx] - sequence_k) + if start_idx == end_idx: + seqlens_q.append(sequence_q) + else: + start_q_tokens = cumsums[idx][start_idx] - (sequence_k - sequence_q) + seqlens_q.extend( + [ + start_q_tokens, + *(sample_seqlens[idx] for idx in range(start_idx + 1, end_idx)), + n_attention_tokens, + ] + ) + seqlens_k.extend(sample_seqlens[: end_idx + 1]) + seqlens_k[-1] = n_attention_tokens + seqlens_q = torch.tensor(seqlens_q, dtype=torch.int32) + seqlens_k = torch.tensor(seqlens_k, dtype=torch.int32) + else: + seqlens_q = torch.cat(sequence_lengths) + seqlens_k = torch.cat(sequence_lengths) + kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( + ( + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), + ) + ) + kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( + ( + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), + ) + ) + kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 2910c7c76..68b6dde91 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -3,9 +3,7 @@ import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig @@ -80,6 +78,11 @@ class AttentionConfig(MixerConfig): desc="Add biases to linear layers. May be overridden for individual layers.", hint=FieldHint.architecture, ) + causal: bool = Field( + default=True, + desc="Use causal attention. Turn this off only for bidirectional attention e.g., in Vision Transformer.", + hint=FieldHint.feature, + ) dropout: float = Field( default=0.0, desc="Dropout applied to the attention intermediate states.", @@ -121,19 +124,3 @@ def layer_class(self) -> "type[Attention]": def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. - # TODO: Find a better solution. 
- preprocessors: list[Preprocessor] = [ - self.rotary.get_layer(TensorDim("head_size", self.head_size)), - ] - if self.do_use_flash_attention(distributed_config): - from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor - - preprocessors.append(FlashAttnVarlenPreprocessor(self, distributed_config)) - else: - from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor - - preprocessors.append(BackupAttentionPreprocessor(self, distributed_config)) - return preprocessors diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py deleted file mode 100644 index 204c08ad2..000000000 --- a/fast_llm/layers/attention/preprocessing.py +++ /dev/null @@ -1,153 +0,0 @@ -import logging -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs -from fast_llm.tensor import TensorMeta - -logger = logging.getLogger(__name__) - - -class BackupAttentionPreprocessor(Preprocessor): - _head_size_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor - _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): - self._config = config - self._distributed_config = distributed_config - assert not self._config.do_use_flash_attention(self._distributed_config) - - def _create_tensors(self, sequence_length: int, device: torch.device) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - self._mask = torch.ones( - (sequence_length, sequence_length), - dtype=torch.bool, - device=device, - ).tril_() - - if self._config.window_size is not None: - self._mask.triu_(-self._config.window_size + 1) - self._mask_value = torch.full( - [], - torch.finfo(self._distributed_config.compute_dtype.torch).min, - dtype=self._distributed_config.compute_dtype.torch, - device=device, - ) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - kwargs[AttentionKwargs.attention_mask] = self._mask[ - None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k - ] - if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: - seq_ids = torch.stack( - [ - torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) - for sample_lens in sequence_lengths - ] - ) - document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) - kwargs[AttentionKwargs.attention_mask] = ( - kwargs[AttentionKwargs.attention_mask] - & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] - ) - kwargs[AttentionKwargs.attention_mask_value] = self._mask_value - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( - ( - scalar_dim, - scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - scalar_dim, - kwargs[AttentionKwargs.sequence_k_dim], - ), - tensor_name=AttentionKwargs.attention_mask, - dtype=torch.bool, - ) - 
kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( - (scalar_dim,), - tensor_name=AttentionKwargs.attention_mask_value, - dtype=self._distributed_config.compute_dtype.torch, - ) - - -class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): - assert config.do_use_flash_attention(distributed_config) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - """ - Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 - cu_seqlens_q and cu_seqlens_k are cumulative sequence lengths for the query and key/value tensors, respectively. - Assumes a flattened batch of documents. In absence of sequence_data_parallelism, cu_seqlens_q = cu_seqlens_k. - If sequence_data_parallelism > 1, query tensors contain tokens only from current micro-sequence, whereas key/value tensors additionally - also contain previous tokens from the first document in micro-sequence. - We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. - """ - if AttentionKwargs.sequence_lengths not in kwargs: - return - sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size - if sequence_q < kwargs[AttentionKwargs.sequence_length]: - cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] - # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents - # in the microsequence. We need to consider all keys computed so far from the first sample. 
We also store the offsets - # of the first documents so that we can index into their kv pairs - start_seq_idx = [ - torch.argmax((cu_seqlens >= sequence_k - sequence_q).to(torch.uint8), dim=0) for cu_seqlens in cumsums - ] - end_seq_idx = [torch.argmax((cu_seqlens >= sequence_k).to(torch.uint8), dim=0) for cu_seqlens in cumsums] - seqlens_q = [] - seqlens_k = [] - for idx, sample_seqlens in enumerate(sequence_lengths): - start_idx = start_seq_idx[idx] - end_idx = end_seq_idx[idx] - seqlens_q.extend([0] * start_idx) - n_attention_tokens = sample_seqlens[end_idx] - (cumsums[idx][end_idx] - sequence_k) - if start_idx == end_idx: - seqlens_q.append(sequence_q) - else: - start_q_tokens = cumsums[idx][start_idx] - (sequence_k - sequence_q) - seqlens_q.extend( - [ - start_q_tokens, - *(sample_seqlens[idx] for idx in range(start_idx + 1, end_idx)), - n_attention_tokens, - ] - ) - seqlens_k.extend(sample_seqlens[: end_idx + 1]) - seqlens_k[-1] = n_attention_tokens - seqlens_q = torch.tensor(seqlens_q, dtype=torch.int32) - seqlens_k = torch.tensor(seqlens_k, dtype=torch.int32) - else: - seqlens_q = torch.cat(sequence_lengths) - seqlens_k = torch.cat(sequence_lengths) - kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( - ( - torch.zeros(1, dtype=torch.int32, device=batch.device), - torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), - ) - ) - kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( - ( - torch.zeros(1, dtype=torch.int32, device=batch.device), - torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), - ) - ) - kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() - kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 5bd7a9b87..26877ee0c 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -4,7 +4,7 @@ import warnings from fast_llm.config import Field, FieldHint, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert @@ -14,7 +14,7 @@ @config_class(registry=True) -class RotaryConfig(BaseModelConfig): +class RotaryConfig(ModuleConfig): # TODO: Move rotary to its own submodule. 
@classmethod diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 889711839..d57d72947 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -5,8 +5,7 @@ import torch from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.attention.rotary.config import ( @@ -16,7 +15,6 @@ RotaryConfig, YarnRotaryConfig, ) -from fast_llm.tensor import TensorMeta from fast_llm.utils import div @@ -41,7 +39,7 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) -class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module, Preprocessor): +class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module): def __init__( self, config: ConfigType, @@ -56,6 +54,9 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: pass + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + pass + class NoRotary[ConfigType: NoRotaryConfig](Rotary[ConfigType]): def forward( @@ -63,12 +64,6 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: return query, key - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - pass - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - pass - class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[ConfigType]): _rotary_embedding_frequencies: torch.Tensor @@ -82,26 +77,6 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None ] kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( - ( - scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - scalar_dim, - self._head_size_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_q, - ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( - ( - scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - scalar_dim, - self._head_size_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_k, - ) - def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 0f975c9c5..ab6cb22b0 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -4,13 +4,13 @@ import torch -from fast_llm.config import Config, Configurable -from fast_llm.engine.base_model.base_model import Layer, Module -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.config import Configurable +from fast_llm.engine.base_model.base_model import Layer, LayerBase +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.block.config import BlockConfig, 
BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.logging import get_model_debug_level, log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -93,9 +93,9 @@ def __call__[ ) -class BaseBlock[ConfigType: Config](Configurable[ConfigType], Module): +class BlockBase[ConfigType: ModuleConfig](Configurable[ConfigType], LayerBase): """ - Base class for blocks and block-like layers (mlp, mixers, etc.). + Base class for blocks and block-like layers (mlp, mixers, block sequences, etc.). """ def __init__( @@ -115,11 +115,6 @@ def __init__( self._lr_scale = lr_scale self._peft = peft - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - raise NotImplementedError() - -class Block[ConfigType: Config](BaseBlock[ConfigType], Layer): - """ - Base class for actual blocks, i.e., base blocks that are also `Layers`. - """ +class Block[ConfigType: BlockConfig](BlockBase[ConfigType], Layer): + pass diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 47c1ab9b7..f3e93edeb 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,11 +1,9 @@ -import abc -import collections import functools import typing import warnings from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -13,7 +11,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.block.block import Block + from fast_llm.layers.block.block import BlockBase + from fast_llm.layers.block.sequence import FixedBlockSequence, PatternBlockSequence class BlockDimNames: @@ -40,8 +39,8 @@ class BlockKwargs: grad_output = "grad_output" -@config_class() -class BaseBlockConfig(BaseModelConfig): +@config_class(registry=True) +class BlockConfig(ModuleConfig): """ Base configuration class for blocks and block-like layers (mlp, mixers, etc.). """ @@ -55,19 +54,6 @@ class BaseBlockConfig(BaseModelConfig): hint=FieldHint.feature, ) - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return [] - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return [] - - -@config_class(registry=True) -class BlockConfig(BaseBlockConfig): - """ - Base configuration class for actual blocks, i.e., base blocks that are also `Layers`. 
- """ - @classmethod def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockConfig and cls.get_subclass(default.get("type")) is None: @@ -78,16 +64,16 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return super()._from_dict(default, strict=strict) @property - def layer_class(self) -> "type[Block]": + def layer_class(self) -> "type[BlockBase]": raise NotImplementedError() - def get_block( + def get_layer( self, distributed_config: DistributedConfig, hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - ) -> "Block": + ) -> "BlockBase": return self.layer_class( self, distributed_config, @@ -98,7 +84,7 @@ def get_block( @config_class(registry=True) -class BlockSequenceConfig(BaseModelConfig): +class BlockSequenceConfig(BlockConfig): @classmethod def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: if cls is BlockSequenceConfig and cls.get_subclass(default.get("type")) is None: @@ -106,21 +92,6 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return FixedBlockSequenceConfig._from_dict(default, strict) return super()._from_dict(default, strict=strict) - @abc.abstractmethod - def __len__(self) -> int: - pass - - @abc.abstractmethod - def __getitem__(self, index: int) -> BlockConfig: - pass - - @abc.abstractmethod - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - pass - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return [] - @config_class(dynamic_type={BlockSequenceConfig: "fixed"}) class FixedBlockSequenceConfig(BlockSequenceConfig): @@ -136,18 +107,11 @@ class FixedBlockSequenceConfig(BlockSequenceConfig): valid=check_field(Assert.geq, 0), ) - def __len__(self) -> int: - return self.num_blocks - - def __getitem__(self, index: int) -> BlockConfig: - return self.block - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Prevent name conflicts in preprocessed kwargs. - return self.block.get_preprocessors(distributed_config) + @property + def layer_class(self) -> "type[FixedBlockSequence]": + from fast_llm.layers.block.sequence import FixedBlockSequence - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.block.get_loss_definitions(count=count * self.num_blocks) + return FixedBlockSequence @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) @@ -180,26 +144,18 @@ def _validate(self): super()._validate() - def __len__(self) -> int: - return self.num_blocks + @property + def layer_class(self) -> "type[PatternBlockSequence]": + from fast_llm.layers.block.sequence import PatternBlockSequence - def __getitem__(self, index: int) -> BlockConfig: - return self.blocks[self.expanded_pattern[index]] + return PatternBlockSequence @functools.cached_property def expanded_pattern(self) -> list[str]: + # The complete list of block names, expanded to `num_blocks` return (self.pattern * (self.num_blocks // len(self.pattern) + 1))[: self.num_blocks] - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - # TODO: Prevent name conflicts in preprocessed kwargs. - return sum((block.get_preprocessors(distributed_config) for block in self.blocks.values()), []) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - # TODO: Prevent name conflicts. 
- return sum( - ( - self.blocks[name].get_loss_definitions(count=count * count_) - for name, count_ in collections.Counter(self.expanded_pattern).items() - ), - [], - ) + @functools.cached_property + def preprocessing_layers(self) -> dict[str, int]: + # The index at which each block first appears. These blocks are used for preprocessing. + return {name: self.expanded_pattern.index(name) for name in set(self.expanded_pattern)} diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index e69de29bb..57621a848 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -0,0 +1,108 @@ +import collections + +import torch.nn + +from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig +from fast_llm.layers.common.peft.config import PeftConfig + + +class FixedBlockSequence[ConfigType: FixedBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + self.extend( + layers := [ + self._config.block.get_layer( + distributed_config, + hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + for _ in range(self._config.num_blocks) + ] + ) + # Wrap all blocks in a namespace using the unique module name of the first one. + namespace = layers[0].module_name if self._config.num_blocks > 0 else "" + # Note: Pytorch won't redundantly register modules because it doesn't look into lists. + self._layers_with_namespace = [ + LayerWithNamespace(sublayer, namespace) for layer in layers for sublayer in layer.get_layers() + ] + + def get_layers(self) -> list["Layer"]: + return self._layers_with_namespace + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self[0].get_loss_definitions(count=count * self.num_blocks) if self._config.num_blocks > 0 else [] + + +class PatternBlockSequence[ConfigType: PatternBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + self.extend( + layers := [ + self._config.blocks[name].get_layer( + distributed_config, + hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + for name in self._config.expanded_pattern + ] + ) + # Wrap each set of blocks with identical config in a namespace + # using the unique module name of the first such block. + # Note: Pytorch won't redundantly register modules because it doesn't look into lists. 
+        self._layers_with_namespace = [
+            LayerWithNamespace(sublayer, layers[self._config.preprocessing_layers[name]].module_name)
+            for name, layer in zip(self._config.expanded_pattern, layers)
+            for sublayer in layer.get_layers()
+        ]
+
+    def get_layers(self) -> list["Layer"]:
+        return self._layers_with_namespace
+
+    def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
+        # TODO: Prevent name conflicts.
+        return sum(
+            (
+                self[self._config.preprocessing_layers[name]].get_loss_definitions(count=count * count_)
+                for name, count_ in collections.Counter(self._config.expanded_pattern).items()
+            ),
+            [],
+        )
diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py
index c1ced10df..a80a19280 100644
--- a/fast_llm/layers/common/normalization/config.py
+++ b/fast_llm/layers/common/normalization/config.py
@@ -3,7 +3,7 @@
 import typing
 
 from fast_llm.config import Field, FieldHint, check_field, config_class
-from fast_llm.engine.base_model.config import BaseModelConfig
+from fast_llm.engine.base_model.config import ModuleConfig
 from fast_llm.engine.config_utils.parameter import ParameterConfig, combine_lr_scales
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.utils import Assert
@@ -26,7 +26,7 @@ class NormalizationImplementation(str, enum.Enum):
 
 
 @config_class(registry=True)
-class NormalizationConfig(BaseModelConfig):
+class NormalizationConfig(ModuleConfig):
     lr_scale: float | None = Field(
         default=None,
         desc="Scaling factor for the layer learning rate."
diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py
index ce9893ea4..08dd5a815 100644
--- a/fast_llm/layers/decoder/block.py
+++ b/fast_llm/layers/decoder/block.py
@@ -4,22 +4,21 @@
 import torch
 
-from fast_llm.config import Config
 from fast_llm.core.distributed import set_generator
-from fast_llm.engine.base_model.config import ResourceUsageConfig
+from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.engine.distributed.distributed import Distributed
 from fast_llm.layers.block.block import Block
 from fast_llm.layers.block.config import BlockKwargs
 from fast_llm.layers.common.peft.config import PeftConfig
-from fast_llm.layers.decoder.config import DecoderBlockConfig
+from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig
 from fast_llm.tensor import TensorMeta
 
 
 logger = logging.getLogger(__name__)
 
 
-class BlockWithBias[ConfigType: Config](Block[ConfigType]):
+class BlockWithBias[ConfigType: BlockWithBiasConfig](Block[ConfigType]):
     """
     Base class for mixer and MLP modules.
     """
@@ -85,12 +84,9 @@ def __init__(
         )
         # For multi-token prediction, return a stack of shared_hidden and transformer_output.
         self._return_input = return_input
-        # Note, layer_lr_scale does not impact the norms
-        # TODO: add a separate norm_lr_scale
        self.norm_1 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft)
        self.norm_2 = self._config.normalization.get_layer(self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft)
-        # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix.
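To illustrate the two cached properties added to `PatternBlockSequenceConfig` and used by `PatternBlockSequence` above, here is a minimal standalone sketch of the same expansion logic; the pattern and block count are hypothetical values, and plain lists and dicts stand in for the actual config class.

# Standalone mimic of PatternBlockSequenceConfig.expanded_pattern / preprocessing_layers
# (hypothetical pattern and num_blocks; the real logic lives on the config class).
pattern = ["attention", "mamba"]
num_blocks = 5

# Repeat the pattern until it covers num_blocks entries, then truncate.
expanded_pattern = (pattern * (num_blocks // len(pattern) + 1))[:num_blocks]
assert expanded_pattern == ["attention", "mamba", "attention", "mamba", "attention"]

# Index at which each block name first appears; those blocks drive preprocessing and
# provide the shared namespace for identically-configured blocks.
preprocessing_layers = {name: expanded_pattern.index(name) for name in set(expanded_pattern)}
assert preprocessing_layers == {"attention": 0, "mamba": 1}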
self.mixer = self._config.mixer.get_layer( self._distributed_config, self._hidden_dim, @@ -178,3 +174,6 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c self.mlp.get_compute_usage(input_, kwargs, config), ) ) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 724b8d172..403b204c8 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -1,7 +1,6 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import LossDef, Preprocessor from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -96,7 +95,7 @@ def layer_class(self) -> "type[DecoderBlock]": return DecoderBlock - def get_block( + def get_layer( self, distributed_config: DistributedConfig, hidden_dim: TensorDim, @@ -112,9 +111,3 @@ def get_block( peft=peft, return_input=return_input, ) - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return self.mixer.get_preprocessors(distributed_config) + self.mlp.get_preprocessors(distributed_config) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) diff --git a/fast_llm/layers/decoder/mlp/config.py b/fast_llm/layers/decoder/mlp/config.py index 100f53740..36841b45b 100644 --- a/fast_llm/layers/decoder/mlp/config.py +++ b/fast_llm/layers/decoder/mlp/config.py @@ -3,7 +3,6 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import LossDef from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig from fast_llm.layers.decoder.config import MLPBaseConfig @@ -152,23 +151,3 @@ def _validate(self) -> None: super()._validate() Assert.leq(self.shared_experts, self.experts) Assert.leq(self.shared_experts + self.experts_per_token, self.experts) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_definitions = [] - if self.routing == RoutingType.topk: - loss_definitions.append( - LossDef( - name=MLPLossNames.load_balancing_loss, - formatted_name="load balancing loss", - count=1, - ) - ) - if self.z_loss_coefficient: - loss_definitions.append( - LossDef( - name=MLPLossNames.router_z_loss, - formatted_name="router z loss", - count=1, - ) - ) - return loss_definitions diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index d4cb46dbf..ffc9eadba 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -5,7 +5,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -263,6 
+263,26 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c return super().get_compute_usage(moe_input, kwargs, config) + self.router.get_compute_usage(input_, config) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_definitions = [] + if self._config.routing == RoutingType.topk: + loss_definitions.append( + LossDef( + name=MLPLossNames.load_balancing_loss, + formatted_name="load balancing loss", + count=1, + ) + ) + if self._config.z_loss_coefficient: + loss_definitions.append( + LossDef( + name=MLPLossNames.router_z_loss, + formatted_name="router z loss", + count=1, + ) + ) + return loss_definitions + def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 1af6bdc38..85a84f508 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,7 +1,7 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import BaseModelConfig, LossDef, Preprocessor +from fast_llm.engine.base_model.config import LossDef, ModuleConfig from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -9,25 +9,13 @@ from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead - - -class LanguageModelLossNames: - language_model_loss = "language_model_loss" - z_loss = "z_loss" - dpo_loss = "dpo_loss" - distil_lm_loss = "distillation_language_model_loss" # the next token perdiciton of combined distillation loss - distillation_loss = "distillation_loss" - - @staticmethod - def multi_token_prediction_loss(index: int) -> str: - if index == 0: - return LanguageModelLossNames.language_model_loss - return f"language_model_loss_{index}" + from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction class LanguageModelKwargs(BlockKwargs): @@ -100,17 +88,37 @@ def layer_class(self) -> "type[LanguageModelEmbedding]": return LanguageModelEmbedding - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - preprocessors = [] - if self.position_embeddings.enabled: - from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor - preprocessors.append(PositionEmbeddingPreprocessor(self, distributed_config)) - return preprocessors +@config_class(registry=True) +class LanguageModelHeadBaseConfig(BlockConfig): + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is LanguageModelHeadBaseConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return LanguageModelHeadConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + def get_layer( + self, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + return self.layer_class( + self, + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + ) -@config_class() -class LanguageModelHeadConfig(BlockConfig): + +@config_class(dynamic_type={LanguageModelHeadBaseConfig: "default"}) +class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): _abstract = False normalization: NormalizationConfig = Field( desc="Configuration for the final normalization layer.", @@ -121,17 +129,6 @@ class LanguageModelHeadConfig(BlockConfig): desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", hint=FieldHint.architecture, ) - tied_weight: bool = Field( - default=True, - desc="Tie the output weights (logits) with the vocabulary embedding.", - hint=FieldHint.architecture, - ) - prediction_heads: int = Field( - default=1, - desc="Number of multi-token prediction heads.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) cross_entropy_implementation: CrossEntropyImpl = Field( default=CrossEntropyImpl.auto, desc="Implementation for the cross-entropy computation.", @@ -173,12 +170,6 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - prediction_loss_coefficient: list[float] | None = Field( - default=None, - desc="Loss coefficient for each prediction head.", - doc="If not provided, all heads are equally weighted.", - hint=FieldHint.feature, - ) teacher_softmax_temperature: float = Field( default=1.0, desc="Divides distillation target logits by this factor.", @@ -208,6 +199,30 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, ) + def get_layer( + self, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + prediction_distance: int = 0, + prediction_heads: int = 1, + loss_coefficient: float = 1.0, + ): + return self.layer_class( + self, + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + prediction_distance=prediction_distance, + prediction_heads=prediction_heads, + loss_coefficient=loss_coefficient, + ) + @property def layer_class(self) -> "type[LanguageModelHead]": from fast_llm.layers.language_model.head import LanguageModelHead @@ -222,125 +237,76 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 super()._validate() - if self.distillation_model is not None: - if self.prediction_heads > 1: - raise NotImplementedError("Multi-token prediction not supported with distillation.") + + +@config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) +class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): + _abstract = False + # Needs to be `DecoderBlockConfig` for the `return_input` interface. + # TODO: Make a generic wrapper for returning input instead? + # TODO ====== Tied weight ====== + block: DecoderBlockConfig = Field( + desc="Configuration for the decoder block before each head.", + hint=FieldHint.architecture, + ) + # TODO: Generalize? 
(needs the extra initialization arguments) + head: LanguageModelHeadConfig = Field( + desc="Configuration for the multi-token-prediction heads.", + hint=FieldHint.architecture, + ) + prediction_heads: int = Field( + default=1, + desc="Prediction heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + # TODO ====== Adjust ====== + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) + + def _validate(self) -> None: + super()._validate() if isinstance(self.prediction_loss_coefficient, list): Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) for coeff in self.prediction_loss_coefficient: Assert.geq(coeff, 0) - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - preprocessors: list[Preprocessor] = [] - - if self.enable_dpo: # TODO better way to pass in? - from fast_llm.layers.language_model.preprocessing import PreferenceSpanPreprocessor - - preprocessors.append(PreferenceSpanPreprocessor()) + @property + def layer_class(self) -> "type[MultiTokenPrediction]": + from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction - return preprocessors + return MultiTokenPrediction def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_defs = [] - if self.logit_z_loss: - LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=count) - - if self.enable_dpo: - loss_defs.append(LossDef(name=LanguageModelLossNames.dpo_loss, formatted_name="dpo loss", count=count)) - - if self.distillation_model is not None: - loss_defs.append( - LossDef(name=LanguageModelLossNames.distillation_loss, formatted_name="distillation loss", count=count) - ) - if self.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef( - name=LanguageModelLossNames.distil_lm_loss, formatted_name="distillation lm loss", count=count - ) - ) - - for i in range(self.prediction_heads): - loss_defs.append( - LossDef( - name=LanguageModelLossNames.multi_token_prediction_loss(i), - formatted_name=f"language model loss {i}", - count=count, - ) - ) - return loss_defs - - def get_block( - self, - distributed_config: DistributedConfig, - embeddings_config: LanguageModelEmbeddingsConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - prediction_distance: int = 0, - ): - return self.layer_class( - self, - distributed_config, - embeddings_config, - hidden_dim=hidden_dim, - lr_scale=combine_lr_scales(lr_scale, self.lr_scale), - peft=peft, - prediction_distance=prediction_distance, + # TODO ====== Wrong ====== + return self.block.get_loss_definitions(count=count * self.prediction_heads) + self.head.get_loss_definitions( + count=count * self.prediction_heads ) - def get_blocks( - self, - distributed_config: DistributedConfig, - embeddings_config: LanguageModelEmbeddingsConfig, - mtp_block_config: BlockConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - ): - blocks = [] - for i in range(self.prediction_heads): - if i > 0: - blocks.append( - mtp_block_config.get_block( - distributed_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, - # The last block only returns the model output. - # The previous blocks return a stack of shared_hidden and transformer_output. 
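Since `MultiTokenPredictionConfig` registers itself on `LanguageModelHeadBaseConfig` under the "multi_token_prediction" dynamic type, an `output_layer` section can select it by name. A hedged sketch of the shape such a fragment might take, written as a plain Python dict: the field names come from the patch, the values are placeholders, and whether empty `block`/`head` sub-configs are acceptable defaults is an assumption.

# Hypothetical `output_layer` fragment; values are placeholders, not taken from the patch.
output_layer = {
    "type": "multi_token_prediction",
    "block": {},  # DecoderBlockConfig for the block inserted before each head (assumed defaults)
    "head": {},  # LanguageModelHeadConfig shared by all heads (assumed defaults)
    "prediction_heads": 2,
    "prediction_loss_coefficient": [1.0, 0.5],  # must match prediction_heads in length
}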
- return_input=i < self.prediction_heads - 1, - ) - ) - blocks.append( - self.get_block( - distributed_config, - embeddings_config, - hidden_dim=hidden_dim, - lr_scale=lr_scale, - peft=peft, - prediction_distance=i, - ) - ) - return blocks - -# TODO: `BlockSequenceConfig`? (interface not fully compatible) @config_class() -class LanguageModelBaseConfig(BaseModelConfig): +class LanguageModelConfig(ModuleConfig): # TODO: block decoder: BlockSequenceConfig = Field( desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) embeddings_layer: LanguageModelEmbeddingsConfig = Field() - output_layer: LanguageModelHeadConfig = Field() + output_layer: LanguageModelHeadBaseConfig = Field() # TODO: Allow overriding in sub-models? peft: PeftConfig = Field( desc="Configuration for parameter-efficient fine tuning.", hint=FieldHint.architecture, ) + tied_embedding_weight: bool = Field( + default=False, + desc="Tie the output weights (logits) with the vocabulary embedding.", + hint=FieldHint.architecture, + ) sequence_first: bool | None = Field( default=None, desc="Override the default dimension ordering", @@ -349,67 +315,3 @@ class LanguageModelBaseConfig(BaseModelConfig): " Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error.", hint=FieldHint.testing, ) - - def __len__(self) -> int: - return len(self.decoder) + 2 * self.output_layer.prediction_heads - - def __getitem__(self, index: int) -> BlockConfig: - if index <= 0: - Assert.eq(index, 0) - return self.embeddings_layer - elif index <= len(self.decoder): - return self.decoder[index - 1] - else: - # Start at the last decoder layer so all MTP heads are treated similarly. - index - len(self.decoder) - return self.embeddings_layer - - def get_preprocessors(self, distributed_config: DistributedConfig) -> list[Preprocessor]: - return ( - self.embeddings_layer.get_preprocessors(distributed_config) - + self.decoder.get_preprocessors(distributed_config) - + self.output_layer.get_preprocessors(distributed_config) - ) - - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return ( - self.embeddings_layer.get_loss_definitions(count) - + self.decoder.get_loss_definitions(count) - + self.output_layer.get_loss_definitions(count) - ) - - def get_blocks(self, distributed_config: DistributedConfig): - hidden_dim = TensorDim("hidden", self.embeddings_layer.hidden_size) - return [ - self.embeddings_layer.get_block( - distributed_config, - hidden_dim=hidden_dim, - lr_scale=None, - peft=self.peft, - ), - *[ - self.decoder[i].get_block( - distributed_config, - hidden_dim, - lr_scale=None, - peft=self.peft, - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. - # TODO: Not all blocks support this argument. 
- **( - {"return_input": True} - if self.output_layer.prediction_heads > 1 and i == len(self.decoder) - 1 - else {} - ), - ) - for i in range(len(self.decoder)) - ], - *self.output_layer.get_blocks( - distributed_config, - self.embeddings_layer, - self.decoder[len(self.decoder) - 1], - hidden_dim=hidden_dim, - lr_scale=None, - peft=self.peft, - ), - ] diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 362ffaa22..6e3bbc901 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -127,3 +127,30 @@ def forward( def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? (embeddings) return 0 + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + if not self._config.position_embeddings.enabled: + return + self._create_position_embeddings(kwargs[LanguageModelKwargs.sequence_length], batch.device) + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: + position_ids = torch.stack( + [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + ).to(batch.device, dtype=torch.int64) + position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] + if kwargs[LanguageModelKwargs.sequence_first]: + position_ids = position_ids.transpose(0, 1) + kwargs[LanguageModelKwargs.position_ids] = position_ids + else: + kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ + sequence_k - sequence_q : sequence_k + ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) + + def _create_position_embeddings(self, sequence_length: int, device: torch.device) -> None: + if sequence_length <= self._tensor_cache_max_sequence_length: + return + self._tensor_cache_max_sequence_length = sequence_length + + Assert.leq(sequence_length, self._config.num_position_embeddings) + self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 4b080b360..42b7e3d6c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,3 +1,4 @@ +import functools import logging import typing @@ -6,7 +7,7 @@ from torch.distributed import all_reduce from fast_llm.core.ops import split_op -from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -23,7 +24,6 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadConfig, LanguageModelKwargs, - LanguageModelLossNames, ) from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.tensor import TensorMeta @@ -51,7 +51,9 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - prediction_distance: int, + prediction_distance: int = 0, + prediction_heads: int = 1, + loss_coefficient: float = 1.0, ): super().__init__( config, @@ -60,6 +62,17 @@ def __init__( lr_scale=lr_scale, peft=peft, ) + if prediction_distance > 0 and ( + 
self._config.distillation_model is not None or self._config.dpo_reference_model is not None + ): + raise NotImplementedError("Multi-token prediction not supported with distillation or dpo.") + + Assert.in_range(prediction_distance, 0, prediction_heads) + self._prediction_distance = prediction_distance + self._prediction_heads = prediction_heads + self._loss_coefficient = loss_coefficient + self._is_last_head = self._prediction_distance == self._prediction_heads - 1 + self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -67,20 +80,6 @@ def __init__( if self._config.cross_entropy_splits is not None and self._sequence_parallel: assert not self._vocab_parallel - self._loss_coefficient = ( - self._config.prediction_loss_coefficient[prediction_distance] - if self._config.prediction_loss_coefficient - else 1.0 - ) - self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - - # Distance of the target token prediction - # 0: next-token prediction - # >0: multi-token prediction (MTP) - Assert.geq(prediction_distance, 0) - self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 - if not self._config.enable_dpo: self._cross_entropy_impl = self._config.cross_entropy_implementation if self._cross_entropy_impl == CrossEntropyImpl.auto: @@ -222,9 +221,7 @@ def _get_targets( if lm_target is not None: # MTP: Shift the labels lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) - + 1 - - self._config.prediction_heads + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads ) if LanguageModelKwargs.sequence_q_dim in kwargs: Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) @@ -336,7 +333,7 @@ def _logits_cross_entropy_forward_backward( self.training, grad_output, losses, - LanguageModelLossNames.z_loss, + self._z_loss_name, logits_scale_factor=self._config.logits_scale_factor, ) if self._debug.enabled and self._config.cross_entropy_splits is None: @@ -424,14 +421,81 @@ def _logits_cross_entropy_forward_backward( loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) if self.training and losses is not None: if dpo_loss is not None: - losses[LanguageModelLossNames.dpo_loss].append(dpo_loss.detach()) + losses[self._dpo_loss_name].append(dpo_loss.detach()) if self._config.distillation_model is not None and distillation_loss is not None: - losses[LanguageModelLossNames.distillation_loss].append(distillation_loss.detach()) + losses[self._distillation_language_model_loss_name].append(distillation_loss.detach()) if self._config.distillation_model is not None and lm_loss is not None: - losses[LanguageModelLossNames.distil_lm_loss].append(lm_loss.detach()) + losses[self._distillation_loss_name].append(lm_loss.detach()) return loss, output_parallel_linear_backward(grad, context) if self.training else None + @functools.cached_property + def _loss_name(self) -> str: + name = "language_model_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _z_loss_name(self) -> str: + name = "z_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _dpo_loss_name(self) -> str: + name = 
"dpo_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _distillation_language_model_loss_name(self) -> str: + name = "distillation_language_model_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _distillation_loss_name(self) -> str: + name = "distillation_loss" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] + if self._config.logit_z_loss: + LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + if self._config.enable_dpo: + loss_defs.append( + LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) + ) + + if self._config.distillation_model is not None: + loss_defs.append( + LossDef( + name=self._distillation_loss_name, + formatted_name=_format_name(self._distillation_loss_name), + count=count, + ) + ) + if self._config.language_model_loss_factor > 0.0: + loss_defs.append( + LossDef( + name=self._distillation_language_model_loss_name, + formatted_name=_format_name(self._distillation_language_model_loss_name), + count=count, + ) + ) + + return loss_defs + + +def _format_name(name: str) -> str: + return name.replace("_", " ") + def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: tensors = [tensor for tensor in tensors if tensor is not None] diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py new file mode 100644 index 000000000..79555d866 --- /dev/null +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -0,0 +1,77 @@ +import torch + +from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, MultiTokenPredictionConfig + + +class MultiTokenPrediction[ConfigType: MultiTokenPredictionConfig](BlockBase[ConfigType]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + embeddings_config: LanguageModelEmbeddingsConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + self.blocks = torch.nn.ModuleList( + [ + self._config.block.get_layer( + self._distributed_config, + self._hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + # The last block only returns the model output. + # The previous blocks return a stack of shared_hidden and transformer_output. 
+ return_input=index < self._config.prediction_heads - 1, + ) + for index in range(self._config.prediction_heads) + ] + ) + self.heads = torch.nn.ModuleList( + [ + self._config.head.get_layer( + distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + prediction_distance=index, + prediction_heads=self._config.prediction_heads, + loss_coefficient=( + 1.0 + if self._config.prediction_loss_coefficient is None + else self._config.prediction_loss_coefficient[index] + ), + ) + for index in range(self._config.prediction_heads) + ] + ) + + # Wrap all blocks in a namespace using the unique module name of the first one. + namespace = self.blocks[0].module_name + # Note: Pytorch won't redundantly register modules because it doesn't look into lists. + self._blocks_with_namespace = [ + LayerWithNamespace(sublayer, namespace) for layer in self.blocks for sublayer in layer.get_layers() + ] + + def get_layers(self) -> list[Layer]: + return [ + module + for block, head in zip(self._blocks_with_namespace, self.heads, strict=True) + for module in (block, head) + ] diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py deleted file mode 100644 index fc1dac299..000000000 --- a/fast_llm/layers/language_model/preprocessing.py +++ /dev/null @@ -1,107 +0,0 @@ -import logging -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_dim import scalar_dim -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs -from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert - -logger = logging.getLogger(__name__) - - -class PositionEmbeddingPreprocessor(Preprocessor): - _rotary_embedding_frequencies: torch.Tensor - _position_ids: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__(self, config: LanguageModelEmbeddingsConfig, distributed_config: DistributedConfig): - self._config = config - assert config.position_embeddings.enabled - self._distributed_config = distributed_config - - def _create_tensors(self, sequence_length: int, device: torch.device) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - Assert.leq(sequence_length, self._config.num_position_embeddings) - self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) - - def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[LanguageModelKwargs.sequence_length], batch.device) - sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: - position_ids = torch.stack( - [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] - ).to(batch.device, dtype=torch.int64) - position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] - if kwargs[LanguageModelKwargs.sequence_first]: - position_ids = position_ids.transpose(0, 1) - kwargs[LanguageModelKwargs.position_ids] = position_ids - else: - kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ - sequence_k - sequence_q : sequence_k - ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) - - def preprocess_meta(self, 
kwargs: dict[str, typing.Any]) -> None: - # Position embeddings will be broadcast. - sequence_q_dim = kwargs[LanguageModelKwargs.sequence_q_dim] - kwargs[LanguageModelKwargs.position_ids] = TensorMeta.from_dims( - ( - (sequence_q_dim, scalar_dim) - if kwargs[LanguageModelKwargs.sequence_first] - else (scalar_dim, sequence_q_dim) - ), - tensor_name=LanguageModelKwargs.position_ids, - dtype=torch.int64, - ) - - -class PreferenceSpanPreprocessor(Preprocessor): - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - return - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size - sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels - - if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: - raise ValueError("Expected chosen spans or rejected spans to be found within the batch.") - - chosen_spans = kwargs[LanguageModelKwargs.chosen_spans] - chosen_valid_spans = [] - for spans in chosen_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - chosen_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - - rejected_spans = kwargs[LanguageModelKwargs.rejected_spans] - rejected_valid_spans = [] - for spans in rejected_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - rejected_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 8fbb99cad..6721daea2 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,12 +4,13 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.conversion.config import ( AprielHybridSSMCheckpointFormat, AutoGPTHuggingfaceCheckpointFormat, @@ -26,7 +27,7 @@ if typing.TYPE_CHECKING: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM - from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel + from 
fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.gpt.trainer import GPTTrainer logger = logging.getLogger(__name__) @@ -80,7 +81,7 @@ def micro_batch_splits(self) -> int: @config_class() -class GPTBaseModelConfig(LanguageModelBaseConfig): +class GPTBaseModelConfig(LanguageModelConfig, BaseModelConfig): _abstract = False # Debug, to get an exact match with megatron init. @@ -88,6 +89,12 @@ class GPTBaseModelConfig(LanguageModelBaseConfig): default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) + @property + def base_model_class(self) -> type["GPTBaseModel"]: + from fast_llm.models.gpt.model import GPTBaseModel + + return GPTBaseModel + @config_class(dynamic_type={FastLLMModelConfig: "gpt"}) class GPTModelConfig(FastLLMModelConfig): diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b7d751a61..9ddcf6300 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -5,7 +5,7 @@ from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer -from fast_llm.engine.base_model.config import Preprocessor +from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner @@ -13,8 +13,6 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -37,22 +35,60 @@ def __init__( ): self._hidden_dim = TensorDim("hidden", config.embeddings_layer.hidden_size) super().__init__(config, distributed_config) + + hidden_dim = TensorDim("hidden", self.embeddings_layer.hidden_size) + self.embedding = self._config.embeddings_layer.get_layer( + distributed_config, + hidden_dim=hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.decoder = self._config.decoder.get_layer( + distributed_config, + hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.head = self._config.output_layer.get_layer( + distributed_config, + self._config.embeddings_layer, + hidden_dim=hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron( param, self._config.decoder.block, config.embeddings_layer.hidden_size ) # Noqa - # `self._reference_models` is not populated at this point, so we pass a mutable dict. 
- self._preprocessors: list[Preprocessor] = self._config.get_preprocessors(distributed_config) - def get_layers(self) -> list[Layer]: - return self._config.get_blocks(self._distributed_config) + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) + # self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) + + def get_layers(self) -> list["Layer"]: + return self.embedding.get_layers() + self.decoder.get_layers() + self.head.get_layers() + + # TODO ====== Vision ====== + # def get_vision_layers(self) -> list[Layer]: + # vit_layers = [ + # VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) + # for idx in range(self._config.vision_encoder.transformer.num_layers) + # ] + # return [ + # PatchConv(self._config.vision_encoder, self._tensor_space), + # *vit_layers, + # VisionAdapter(self._config.vision_encoder, self._tensor_space), + # MultiModalEmbedding(self._config, self._tensor_space), + # ] def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: - # TODO: How much of this is generalizable? + # TODO ====== Remove (Move batch splitting elsewhere) ====== # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence if isinstance(batch_meta, GPTBatchConfig): @@ -113,6 +149,33 @@ def preprocess_meta( LanguageModelKwargs.mask_inputs: not truncate_documents, } + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # try: + # max_image_size = batch_meta.max_image_size + # except AttributeError: + # max_image_size = 256 + # logger.warning("Inference mode: max_image_size not provided, defaulting to 256") + # vision_kwargs = { + # VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, + # VisionEncoderKwargs.max_image_size: max_image_size, + # VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + # VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, + # VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, + # } + # vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] + # vision_hidden_dims = ( + # (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + # if sequence_first + # else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + # ) + # vision_kwargs.update( + # { + # VisionTransformerKwargs.hidden_dims: vision_hidden_dims, + # } + # ) + # common_kwargs.update(vision_kwargs) + sequence_k_pasts = range( sequence_q_dim.size * self._distributed_config.sequence_data_rank, sequence_length, @@ -142,8 +205,6 @@ def preprocess_meta( kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( hidden_dims[:2], tensor_name="labels", dtype=torch.int64 ) - for preprocessor in self._preprocessors: - preprocessor.preprocess_meta(kwargs) reference_kwargs = {} for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] @@ -157,6 +218,13 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + # 
preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + # else: + # preprocessed_meta.append((tokens, kwargs)) + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -170,7 +238,7 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: - # TODO: How much of this is generalizable? + # TODO ====== Move batch splitting elsewhere, align interface with LayerBase ====== assert self._is_setup if preprocessed_meta is None: @@ -203,19 +271,20 @@ def preprocess( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + token_ids = batch.token_ids if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. - batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() + token_ids = token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size if sequence_first: - tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] + tokens = token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? - tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() + tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: @@ -235,10 +304,10 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: - labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] + labels = token_ids[sequence_offset : sequence_k + prediction_heads] else: # TODO: Avoid multiple contiguous calls? 
- labels = batch.token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() + labels = token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config if batch.loss_masking_spans is not None: @@ -255,57 +324,150 @@ def preprocess( valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) valid_spans -= sequence_offset - loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - loss_mask[start : end + 1, idx] = False + labels[start : end + 1, idx] = -100 else: - loss_mask[idx, start : end + 1] = False - if self._config.output_layer.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - labels = torch.where(loss_mask, labels, -100) + labels[idx, start : end + 1] = -100 + + # TODO ====== Preference spans ====== + if batch.chosen_spans is not None: + chosen_valid_spans = [] + for spans in batch.chosen_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + chosen_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans + + rejected_valid_spans = [] + for spans in batch.rejected_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + rejected_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans + + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # if self._config.vision_encoder.image_break_token is not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + # if self._config.vision_encoder.image_end_token is not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + # Loss-masking for distillation losses + if self._config.distillation_model is not None: + loss_mask = torch.ones_like(labels, dtype=torch.bool) + loss_mask = torch.where(labels == -100, False, loss_mask) + kwargs[LanguageModelKwargs.loss_mask] = loss_mask kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - for preprocessor in self._preprocessors: - preprocessor.preprocess(tokens, kwargs) + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # batch_images = ( + # batch.images if batch.images is not None else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[VisionEncoderKwargs.images] = [ + # [ + # 
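With the span handling above, masked label positions are written to -100 (the `ignore_index` convention of `torch.nn.CrossEntropyLoss`), and the distillation loss mask is then derived from those positions. A small standalone sketch with hypothetical values:

import torch

# Hypothetical batch-first labels and one masked span (inclusive start/end), mirroring the
# clamped spans computed above.
labels = torch.arange(10, 20).unsqueeze(0)
valid_spans = torch.tensor([[2, 4]])
for start, end in valid_spans:
    labels[0, start : end + 1] = -100

# Distillation losses reuse the masked positions as a boolean loss mask.
loss_mask = labels != -100
assert labels[0, :6].tolist() == [10, 11, -100, -100, -100, 15]
assert loss_mask[0, 2:5].tolist() == [False, False, False]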
img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + # for img in images + # ] + # for images in batch_images + # ] + # kwargs[VisionEncoderKwargs.image_positions] = ( + # batch.image_positions + # if batch.image_positions is not None + # else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[LanguageModelKwargs.tokens] = tokens + + # TODO ====== Turn into super() call ====== + self.embedding.preprocess(tokens, kwargs) + self.decoder.preprocess(tokens, kwargs) + self.head.preprocess(tokens, kwargs) + + # TODO ====== Vision ====== + # image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + # if image_patches is not None: + # preprocessed.append((image_patches, kwargs)) + # else: + # preprocessed.append((tokens, kwargs)) + preprocessed.append((tokens, kwargs)) return preprocessed - @property - def embedding(self) -> LanguageModelEmbedding: - return self.layers[0] + # TODO ====== Vision ====== + # @property + # def embedding(self) -> LanguageModelEmbedding: + # return self.layers[self.embedding_layer_index] - @property - def model_head(self) -> LanguageModelHead: - return self.layers[self.model_head_indices[0]] + # @property + # def transformer_layers(self) -> list[TransformerBlock]: + # return self.layers[self.embedding_layer_index + 1 : -1] - @property - def model_head_indices(self) -> list[int]: - return sorted([len(self) - 1 - 2 * i for i in range(self._config.output_layer.prediction_heads)]) + # @property + # def embedding_layer_index(self) -> int: + # if self._config.vision_encoder.enabled: + # return self._config.vision_encoder.transformer.num_layers + 2 + # else: + # return 0 def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: - if self._config.output_layer.tied_weight: - return { - WORD_EMBEDDINGS_WEIGHT: ( - self.embedding.word_embeddings_weight, - (0, *self.model_head_indices), - ) - } - elif self._config.output_layer.prediction_heads > 1: - return { - OUTPUT_WEIGHTS: ( - self.model_head.output_weights, - tuple(self.model_head_indices), - ) - } - else: - return {} + # TODO ====== Tied weights ====== + if self._config.tied_embedding_weight: + raise NotImplementedError() + return {} + # if self._config.output_layer.tied_weight: + # return { + # WORD_EMBEDDINGS_WEIGHT: ( + # self.embedding.word_embeddings_weight, + # # TODO ====== Vision ====== + # # (self.embedding_layer_index, *self.model_head_indices), + # (0, *self.model_head_indices), + # ) + # } + # elif self._config.output_layer.prediction_heads > 1: + # return { + # OUTPUT_WEIGHTS: ( + # self.model_head.output_weights, + # tuple(self.model_head_indices), + # ) + # } + # else: + # return {} + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return ( + self.embeddings_layer.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.output_layer.get_loss_definitions(count) + ) class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): - base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel + # TODO: Can we drop class? + pass class GPTInferenceRunner(InferenceRunner): diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 0dc3462eb..b086c291f 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -41,7 +41,7 @@ def get_stage(base_model: BaseModel | list[Layer], distributed: Distributed): # Create a fast-llm stage which allocates and initializes meta tensors correctly. 
stage = Stage( config=StageConfig(), - base_model=base_model, + layers=base_model, distributed_config=distributed.config, begin=0, end=1, diff --git a/tools/generate_config_yaml_for_sharded_dst.py b/tools/generate_config_yaml_for_sharded_dst.py deleted file mode 100644 index c0b4fa24d..000000000 --- a/tools/generate_config_yaml_for_sharded_dst.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse -import pathlib - -import yaml - -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator, GPTMemmapDatasetPreparatorConfig - -""" -This script is intended to be used only for creation of fast_llm_config.yaml files for sharded datasets encoded with older version of the prepare command. -""" - - -def read_dataset_shard_config(shard_path): - """ - Read a dataset shard from the given path. - - Args: - shard_path: Path to the shard prefix (without .idx or .bin extension) - - Returns: - A GPTMemmapDataset instance - """ - # Convert to pathlib.Path if it's a string - path = pathlib.Path(shard_path) if isinstance(shard_path, str) else shard_path - - # Create a GPTMemmapDataset instance - # The name parameter is just for identification - dataset = GPTMemmapDataset(name=path.name, prefix=path) - - # Print basic information about the dataset - print(f"Dataset: {dataset.name}") - print(f"Number of documents: {dataset._num_documents}") - print(f"Number of tokens: {dataset.num_tokens}") - - return GPTMemmapDatasetConfig.from_dict( - { - "type": "memmap", - "path": path.name.replace(".bin", ""), - "num_documents": dataset._num_documents, - "num_tokens": dataset.num_tokens, - } - ) - - -def get_preparator(prepare_config: GPTMemmapDatasetPreparatorConfig) -> GPTMemmapDatasetPreparator: - config = GPTMemmapDatasetPreparatorConfig.from_dict( - { - "output_path": prepare_config.output_path, - "dataset": {"path": prepare_config.dataset.path}, - "tokenizer": {"path": prepare_config.tokenizer.path}, - }, - {}, - ) - return config.get_dataset_preparator_class()(config=config) - - -def main(config_dict): - prepare_config = GPTMemmapDatasetPreparatorConfig.from_dict(config_dict) - destination = pathlib.Path(prepare_config.output_path) - - shards = list(destination.glob("shard_*.bin")) - dataset_configs = [read_dataset_shard_config(shard) for shard in shards] - - preparator = get_preparator(prepare_config) - preparator.generate_config_yaml_for_sharded_dst(dataset_configs) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Generate config YAML for sharded datasets") - parser.add_argument( - "--prepare_config", - type=str, - required=False, - default=None, # "/home/toolkit/dev/Fast-LLM/.vscode/prepare_dst.yaml", - help="Path to the prepare config YAML file", - ) - parser.add_argument( - "--dataset_path", - type=str, - required=False, - default="/mnt/datasets/tokenized/Mistral-Nemo-Base-2407/FineWeb2/deu_Latn/", - help="Path to the dataset path", - ) - args = parser.parse_args() - - if args.prepare_config: - with open(args.prepare_config) as f: - config_dict = yaml.safe_load(f) - else: - assert args.dataset_path is not None, "Please provide a prepare config YAML file or dataset path" - config_dict = { - "output_path": args.dataset_path, - "dataset": {"path": "unknown"}, - "tokenizer": {"path": "no_tokenizer"}, - } - main(config_dict) From e23c68da2ee9e6cea249d62f6b265764cad71d0d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 1 Oct 
2025 16:51:18 -0400 Subject: [PATCH 05/16] stuff --- examples/mistral.yaml | 2 +- fast_llm/engine/base_model/base_model.py | 8 +- fast_llm/layers/block/sequence.py | 34 ++++--- fast_llm/layers/language_model/config.py | 20 +++- fast_llm/layers/language_model/head.py | 9 +- fast_llm/models/gpt/config.py | 18 ++-- fast_llm/models/gpt/conversion/llama.py | 10 +- fast_llm/models/gpt/model.py | 114 ++++------------------- fast_llm/models/gpt/trainer.py | 6 +- tests/layers/test_lm_head.py | 44 ++++----- tests/models/test_checkpoint.py | 2 +- tests/test_config.py | 13 +-- tests/utils/model_configs.py | 13 +-- 13 files changed, 124 insertions(+), 169 deletions(-) diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 88655954f..987801892 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -55,10 +55,10 @@ model: dropout: 0.0 num_blocks: 32 output_layer: - tied_weight: false normalization: type: rms_norm epsilon: 1.0e-05 + tied_embedding_weight: false multi_stage: zero_stage: 2 distributed: diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index ce7002c54..5b1180a13 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -95,7 +95,13 @@ def setup(self, distributed: Distributed) -> None: def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: - return self._layer.forward(input_, kwargs[self._namespace], losses, metrics) + if self._namespace in kwargs: + kwargs = kwargs[self._namespace] + else: + # TODO: Forward meta doesn't go through preprocessing so doesn't have a namespace. + # Using kwargs as-is since it's generally unused. + assert isinstance(input_, TensorMeta) + return self._layer.forward(input_, kwargs.get(self._namespace, kwargs), losses, metrics) def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: assert self._namespace not in kwargs diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index 57621a848..33b884fdf 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -1,4 +1,5 @@ import collections +import functools import torch.nn @@ -30,8 +31,9 @@ def __init__( lr_scale=lr_scale, peft=peft, ) + self.extend( - layers := [ + [ self._config.block.get_layer( distributed_config, hidden_dim, @@ -41,18 +43,21 @@ def __init__( for _ in range(self._config.num_blocks) ] ) + + @functools.cached_property + def _layers_with_namespace(self) -> list[Layer]: + # This needs to be in a property because `module_name` is set after `__init__`. # Wrap all blocks in a namespace using the unique module name of the first one. - namespace = layers[0].module_name if self._config.num_blocks > 0 else "" - # Note: Pytorch won't redundantly register modules because it doesn't look into lists. 
- self._layers_with_namespace = [ - LayerWithNamespace(sublayer, namespace) for layer in layers for sublayer in layer.get_layers() - ] + namespace = self[0].module_name if self._config.num_blocks > 0 else "" + return [LayerWithNamespace(sublayer, namespace) for layer in self for sublayer in layer.get_layers()] def get_layers(self) -> list["Layer"]: return self._layers_with_namespace def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self[0].get_loss_definitions(count=count * self.num_blocks) if self._config.num_blocks > 0 else [] + return ( + self[0].get_loss_definitions(count=count * self._config.num_blocks) if self._config.num_blocks > 0 else [] + ) class PatternBlockSequence[ConfigType: PatternBlockSequenceConfig](BlockBase[ConfigType], torch.nn.ModuleList): @@ -75,7 +80,7 @@ def __init__( peft=peft, ) self.extend( - layers := [ + [ self._config.blocks[name].get_layer( distributed_config, hidden_dim, @@ -85,16 +90,19 @@ def __init__( for name in self._config.expanded_pattern ] ) + + @functools.cached_property + def _layers_with_namespace(self) -> list[Layer]: + # This needs to be in a property because `module_name` is set after `__init__`. # Wrap each set of blocks with identical config in a namespace # using the unique module name of the first such block. - # Note: Pytorch won't redundantly register modules because it doesn't look into lists. - self._layers_with_namespace = [ - LayerWithNamespace(sublayer, layers[self._config.preprocessing_layers[name]].module_name) - for name, layer in zip(self._config.expanded_pattern, layers) + return [ + LayerWithNamespace(sublayer, self[self._config.preprocessing_layers[name]].module_name) + for name, layer in zip(self._config.expanded_pattern, self) for sublayer in layer.get_layers() ] - def get_layers(self) -> list["Layer"]: + def get_layers(self) -> list[Layer]: return self._layers_with_namespace def get_loss_definitions(self, count: int = 1) -> list[LossDef]: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 85a84f508..9bdc12c16 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,3 +1,4 @@ +import abc import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -116,8 +117,13 @@ def get_layer( peft=peft, ) + @property + @abc.abstractmethod + def max_prediction_distance(self) -> int: + pass + -@config_class(dynamic_type={LanguageModelHeadBaseConfig: "default"}) +@config_class(dynamic_type={LanguageModelHeadBaseConfig: "language_model_head"}) class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): _abstract = False normalization: NormalizationConfig = Field( @@ -238,6 +244,10 @@ def _validate(self) -> None: self.language_model_loss_factor = 0.0 super()._validate() + @property + def max_prediction_distance(self) -> int: + return 1 + @config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): @@ -287,6 +297,10 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count * self.prediction_heads ) + @property + def max_prediction_distance(self) -> int: + return self.prediction_heads + @config_class() class LanguageModelConfig(ModuleConfig): @@ -295,8 +309,8 @@ class LanguageModelConfig(ModuleConfig): desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) - embeddings_layer: LanguageModelEmbeddingsConfig = Field() - output_layer: 
LanguageModelHeadBaseConfig = Field() + embeddings: LanguageModelEmbeddingsConfig = Field() + head: LanguageModelHeadBaseConfig = Field() # TODO: Allow overriding in sub-models? peft: PeftConfig = Field( desc="Configuration for parameter-efficient fine tuning.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 5e770bf23..b4d502b10 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -25,7 +25,6 @@ LanguageModelHeadConfig, LanguageModelKwargs, ) -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -100,7 +99,8 @@ def __init__( "vocab", embeddings_config.vocab_size, self._parallel_dim if self._vocab_parallel else None ) # Only the first head defines the output weights - if self._prediction_distance == 0 and not self._config.tied_weight: + # TODO ====== tied_weight ====== + if self._prediction_distance == 0: # and not self._config.tied_weight: # untie embedding weights self.output_weights = self._config.output_weight.get_parameter( (self._vocab_dim, self._hidden_dim), @@ -245,8 +245,9 @@ def _get_targets( return targets def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._config.tied_weight: - return kwargs[WORD_EMBEDDINGS_WEIGHT] + # TODO ====== tied_weight ====== + # if self._config.tied_weight: + # return kwargs[WORD_EMBEDDINGS_WEIGHT] if self._prediction_distance > 0: return kwargs[OUTPUT_WEIGHTS] return self.output_weights diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 6721daea2..9c8e90d44 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -148,18 +148,18 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): def _validate(self) -> None: if self.batch.sequence_length is None: # TODO: Drop this. - self.batch.sequence_length = self.model.base_model.embeddings_layer.num_position_embeddings + self.batch.sequence_length = self.model.base_model.embeddings.num_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() - if self.model.base_model.embeddings_layer.position_embeddings.enabled: - Assert.geq(self.model.base_model.embeddings_layer.num_position_embeddings, self.batch.sequence_length) + if self.model.base_model.embeddings.position_embeddings.enabled: + Assert.geq(self.model.base_model.embeddings.num_position_embeddings, self.batch.sequence_length) - distillation_model = self.model.base_model.output_layer.distillation_model - dpo_reference_model = self.model.base_model.output_layer.dpo_reference_model + distillation_model = self.model.base_model.head.distillation_model + dpo_reference_model = self.model.base_model.head.dpo_reference_model - if self.model.base_model.output_layer.enable_dpo: + if self.model.base_model.head.enable_dpo: assert dpo_reference_model is not None Assert.none(distillation_model) else: @@ -173,14 +173,14 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), expected_names) for reference_model in self.reference_models.values(): - output_layer = reference_model.model.base_model.output_layer + output_layer = reference_model.model.base_model.head Assert.none(output_layer.distillation_model) Assert.none(output_layer.dpo_reference_model) # TODO: Support more LM head features. 
Assert.none(output_layer.cross_entropy_splits) Assert.eq( - reference_model.model.base_model.embeddings_layer.vocab_parallel, - self.model.base_model.embeddings_layer.vocab_parallel, + reference_model.model.base_model.embeddings.vocab_parallel, + self.model.base_model.embeddings.vocab_parallel, ) Assert.geq(output_layer.prediction_heads, output_layer.prediction_heads) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 629a3ceed..c22d1ebcb 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -538,18 +538,18 @@ def import_config(cls, config: dict) -> dict: def export_config(cls, config: GPTBaseModelConfig) -> dict: Assert.custom(isinstance, config, GPTBaseModelConfig) return safe_merge_dicts( - cls.embeddings_converter_class.export_config(config.embeddings_layer), + cls.embeddings_converter_class.export_config(config.embeddings), cls.decoder_converter_class.export_config(config.decoder), - cls.head_converter_class.export_config(config.output_layer), + cls.head_converter_class.export_config(config.head), ) @classmethod def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: return [ - *cls.embeddings_converter_class.get_converters(config.embeddings_layer, "layers.0", "model"), + *cls.embeddings_converter_class.get_converters(config.embeddings, "layers.0", "model"), *cls.decoder_converter_class.get_converters(config.decoder, "layers", "model.layers"), *cls.head_converter_class.get_converters( - config.output_layer, config.decoder[len(config.decoder) - 1], "layers", len(config.decoder) + 1 + config.head, config.decoder[len(config.decoder) - 1], "layers", len(config.decoder) + 1 ), ] @@ -557,7 +557,7 @@ def _create_weight_converters( self, ) -> list[WeightConverter]: base_model_config = self._model.config.base_model - self.embeddings_converter_class.get_converters(base_model_config.embeddings_layer, "layers.0", "model") + self.embeddings_converter_class.get_converters(base_model_config.embeddings, "layers.0", "model") converters = self.decoder_converter_class.get_converters(base_model_config.decoder, "layers", "model.layers") self.head_converter_class.get_converters( base_model_config.decoder, base_model_config.decoder.block, "layers", len(base_model_config.decoder) + 1 diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9ddcf6300..3d481a04a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -33,26 +33,25 @@ def __init__( config: GPTBaseModelConfig, distributed_config: DistributedConfig, ): - self._hidden_dim = TensorDim("hidden", config.embeddings_layer.hidden_size) super().__init__(config, distributed_config) - hidden_dim = TensorDim("hidden", self.embeddings_layer.hidden_size) - self.embedding = self._config.embeddings_layer.get_layer( + self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size) + self.embeddings = self._config.embeddings.get_layer( distributed_config, - hidden_dim=hidden_dim, + hidden_dim=self._hidden_dim, lr_scale=None, peft=self._config.peft, ) self.decoder = self._config.decoder.get_layer( distributed_config, - hidden_dim, + self._hidden_dim, lr_scale=None, peft=self._config.peft, ) - self.head = self._config.output_layer.get_layer( + self.head = self._config.head.get_layer( distributed_config, - self._config.embeddings_layer, - hidden_dim=hidden_dim, + self._config.embeddings, + hidden_dim=self._hidden_dim, lr_scale=None, peft=self._config.peft, ) @@ -61,29 +60,11 @@ def __init__( for 
param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron( - param, self._config.decoder.block, config.embeddings_layer.hidden_size + param, self._config.decoder.block, config.embeddings.hidden_size ) # Noqa - # TODO ====== Vision ====== - # if self._config.vision_encoder.enabled: - # self._preprocessors.append(VisionPreprocessor(self._config.vision_encoder, self._tensor_space)) - # self._preprocessors.append(self._config.vision_encoder.transformer.rotary.build(self._tensor_space)) - def get_layers(self) -> list["Layer"]: - return self.embedding.get_layers() + self.decoder.get_layers() + self.head.get_layers() - - # TODO ====== Vision ====== - # def get_vision_layers(self) -> list[Layer]: - # vit_layers = [ - # VisionTransformerBlock(self._config.vision_encoder.transformer, self._tensor_space, block_index=idx + 1) - # for idx in range(self._config.vision_encoder.transformer.num_layers) - # ] - # return [ - # PatchConv(self._config.vision_encoder, self._tensor_space), - # *vit_layers, - # VisionAdapter(self._config.vision_encoder, self._tensor_space), - # MultiModalEmbedding(self._config, self._tensor_space), - # ] + return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType @@ -99,7 +80,7 @@ def preprocess_meta( else: micro_batch_size, sequence_length = batch_meta.shape if phase != PhaseType.inference: - sequence_length -= self._config.output_layer.prediction_heads + sequence_length -= self._config.head.prediction_heads micro_sequence_length = sequence_length truncate_documents = True @@ -247,7 +228,7 @@ def preprocess( _, common_kwargs = preprocessed_meta[0] sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size sequence_first = common_kwargs[AttentionKwargs.sequence_first] - prediction_heads: int = self._config.output_layer.prediction_heads + max_prediction_distance = self._config.head.max_prediction_distance batch.token_ids = batch.token_ids.to( device=self._distributed.device, @@ -304,10 +285,10 @@ def preprocess( if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: - labels = token_ids[sequence_offset : sequence_k + prediction_heads] + labels = token_ids[sequence_offset : sequence_k + max_prediction_distance] else: # TODO: Avoid multiple contiguous calls? 
- labels = token_ids[:, sequence_offset : sequence_k + prediction_heads].contiguous() + labels = token_ids[:, sequence_offset : sequence_k + max_prediction_distance].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config if batch.loss_masking_spans is not None: @@ -317,12 +298,13 @@ def preprocess( if not spans.numel(): continue valid_spans = spans[ - (spans[:, 0] <= sequence_k + prediction_heads - 1) & (spans[:, 1] >= sequence_offset) + (spans[:, 0] <= sequence_k + max_prediction_distance - 1) + & (spans[:, 1] >= sequence_offset) ] if valid_spans.numel(): # if span is partially within the sequence, truncate parts of spans that are outside of the sequence valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) + valid_spans[:, 1].clamp_(max=sequence_k + max_prediction_distance - 1) valid_spans -= sequence_offset for start, end in valid_spans: if sequence_first: @@ -362,19 +344,6 @@ def preprocess( rejected_valid_spans.append(valid_spans) kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans - # TODO ====== Vision ====== - # if self._config.vision_encoder.enabled: - # if self._config.vision_encoder.image_break_token is not None: - # if not labels_cloned: - # labels = labels.clone() - # labels_cloned = True - # labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) - # if self._config.vision_encoder.image_end_token is not None: - # if not labels_cloned: - # labels = labels.clone() - # labels_cloned = True - # labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) - # Loss-masking for distillation losses if self._config.distillation_model is not None: loss_mask = torch.ones_like(labels, dtype=torch.bool) loss_mask = torch.where(labels == -100, False, loss_mask) @@ -382,57 +351,15 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - # TODO ====== Vision ====== - # if self._config.vision_encoder.enabled: - # batch_images = ( - # batch.images if batch.images is not None else [[]] * kwargs[AttentionKwargs.micro_batch_size] - # ) - # kwargs[VisionEncoderKwargs.images] = [ - # [ - # img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) - # for img in images - # ] - # for images in batch_images - # ] - # kwargs[VisionEncoderKwargs.image_positions] = ( - # batch.image_positions - # if batch.image_positions is not None - # else [[]] * kwargs[AttentionKwargs.micro_batch_size] - # ) - # kwargs[LanguageModelKwargs.tokens] = tokens - # TODO ====== Turn into super() call ====== - self.embedding.preprocess(tokens, kwargs) + self.embeddings.preprocess(tokens, kwargs) self.decoder.preprocess(tokens, kwargs) self.head.preprocess(tokens, kwargs) - # TODO ====== Vision ====== - # image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) - # if image_patches is not None: - # preprocessed.append((image_patches, kwargs)) - # else: - # preprocessed.append((tokens, kwargs)) - preprocessed.append((tokens, kwargs)) return preprocessed - # TODO ====== Vision ====== - # @property - # def embedding(self) -> LanguageModelEmbedding: - # return self.layers[self.embedding_layer_index] - - # @property - # def transformer_layers(self) -> list[TransformerBlock]: - # return self.layers[self.embedding_layer_index + 1 : -1] - - # @property - # def embedding_layer_index(self) -> int: - # if 
self._config.vision_encoder.enabled: - # return self._config.vision_encoder.transformer.num_layers + 2 - # else: - # return 0 - def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: # TODO ====== Tied weights ====== if self._config.tied_embedding_weight: @@ -442,9 +369,6 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: # return { # WORD_EMBEDDINGS_WEIGHT: ( # self.embedding.word_embeddings_weight, - # # TODO ====== Vision ====== - # # (self.embedding_layer_index, *self.model_head_indices), - # (0, *self.model_head_indices), # ) # } # elif self._config.output_layer.prediction_heads > 1: @@ -459,9 +383,9 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: def get_loss_definitions(self, count: int = 1) -> list[LossDef]: return ( - self.embeddings_layer.get_loss_definitions(count) + self.embeddings.get_loss_definitions(count) + self.decoder.get_loss_definitions(count) - + self.output_layer.get_loss_definitions(count) + + self.head.get_loss_definitions(count) ) diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 4dbbfbb1c..26500212d 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -22,13 +22,13 @@ def _get_sampling_parameters( parameters = super()._get_sampling_parameters(parameters, _return_dict=True) parameters.update( { - "vocab_size": self._config.model.base_model.embeddings_layer.vocab_size, + "vocab_size": self._config.model.base_model.embeddings.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - "use_preference_loss_spans": self._config.model.base_model.output_layer.enable_dpo, + "use_preference_loss_spans": self._config.model.base_model.head.enable_dpo, "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, - "extra_tokens": self._config.model.base_model.output_layer.prediction_heads, + "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } ) return parameters if _return_dict else GPTSamplingParameters(**parameters) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index f14f028e1..cf01cd707 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -107,14 +107,14 @@ def _lm_head( ({}, {"compute_dtype": DataType.bfloat16}, False), ({"embeddings_layer": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False), ({"sequence_first": True}, {}, False), - ({"output_layer": {"logit_z_loss": 1e-3}}, {}, False), - ({"output_layer": {"logits_scale_factor": 5.0}}, {}, False), - ({"output_layer": {"tied_weight": False}}, {}, False), - ({"output_layer": {"prediction_heads": 2}}, {}, False), + ({"head": {"logit_z_loss": 1e-3}}, {}, False), + ({"head": {"logits_scale_factor": 5.0}}, {}, False), + ({"head": {"tied_weight": False}}, {}, False), + ({"head": {"prediction_heads": 2}}, {}, False), ({}, {}, True), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.cross_entropy, } @@ -124,7 +124,7 @@ def _lm_head( ), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.reverse_kl, } @@ -134,7 +134,7 @@ def _lm_head( ), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": 
DistillationLossImpl.cross_entropy, } @@ -144,7 +144,7 @@ def _lm_head( ), ( { - "output_layer": { + "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.reverse_kl, } @@ -169,7 +169,7 @@ def test_lm_head( "vocab_size": VOCAB_SIZE, "hidden_size": HIDDEN_SIZE, }, - "output_layer": { + "head": { "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, }, @@ -188,7 +188,7 @@ def test_lm_head( ) sequence_first = config.sequence_first or ( - config.output_layer.cross_entropy_splits is not None and config.output_layer.cross_entropy_splits > 1 + config.head.cross_entropy_splits is not None and config.head.cross_entropy_splits > 1 ) input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), @@ -201,9 +201,9 @@ def test_lm_head( requires_grad=True, ) label_shape = ( - (SEQUENCE_LENGTH + config.output_layer.prediction_heads - 1, BATCH_SIZE) + (SEQUENCE_LENGTH + config.head.prediction_heads - 1, BATCH_SIZE) if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.output_layer.prediction_heads - 1) + else (BATCH_SIZE, SEQUENCE_LENGTH + config.head.prediction_heads - 1) ) if loss_masking: loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) @@ -213,7 +213,7 @@ def test_lm_head( AttentionKwargs.sequence_first: sequence_first, AttentionKwargs.grad_output: 1.0, } - if config.output_layer.distillation_model is None: + if config.head.distillation_model is None: target = torch.randint( 0, VOCAB_SIZE, @@ -226,17 +226,17 @@ def test_lm_head( kwargs[LanguageModelKwargs.labels] = target else: - assert config.output_layer.prediction_heads == 1 + assert config.head.prediction_heads == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), dtype=input_.dtype, device=distributed.device, ) - kwargs[f"{config.output_layer.distillation_model}_logits"] = target + kwargs[f"{config.head.distillation_model}_logits"] = target if loss_mask is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask - if config.output_layer.tied_weight or config.output_layer.prediction_heads > 1: + if config.head.tied_weight or config.head.prediction_heads > 1: logit_weight = ( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device @@ -244,7 +244,7 @@ def test_lm_head( .normal_(config.embeddings_layer.hidden_size**-0.5) .requires_grad_(True) ) - kwargs[WORD_EMBEDDINGS_WEIGHT if config.output_layer.tied_weight else OUTPUT_WEIGHTS] = logit_weight + kwargs[WORD_EMBEDDINGS_WEIGHT if config.head.tied_weight else OUTPUT_WEIGHTS] = logit_weight else: logit_weight = None @@ -276,9 +276,9 @@ def test_lm_head( loss_mask, rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, - logit_scale_factor=config.output_layer.logits_scale_factor, - logit_z_loss=config.output_layer.logit_z_loss, - distillation_loss_implementation=config.output_layer.distillation_loss_implementation, + logit_scale_factor=config.head.logits_scale_factor, + logit_z_loss=config.head.logit_z_loss, + distillation_loss_implementation=config.head.distillation_loss_implementation, ) # Prepare LM head inputs @@ -295,7 +295,7 @@ def test_lm_head( loss_keys = {loss_name} if ref_z_loss is not None: loss_keys.add("z_loss") - if config.output_layer.distillation_model is not None: + if config.head.distillation_model is not None: loss_keys.add("distillation_loss") loss_keys.add("distil_lm_loss") losses = {key: [] for key in loss_keys} 
@@ -305,7 +305,7 @@ def test_lm_head( threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 - ) * config.output_layer.logits_scale_factor + ) * config.head.logits_scale_factor Assert.eq(losses.keys(), loss_keys) Assert.eq(len(losses[loss_name]), 1) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 714abc130..3c3bfb833 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -328,7 +328,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) test_input = torch.randint( 0, - model_ref.config.fast_llm_config.base_model.embeddings_layer.vocab_size, + model_ref.config.fast_llm_config.base_model.embeddings.vocab_size, size=(4, 100), dtype=torch.int64, device="cuda", diff --git a/tests/test_config.py b/tests/test_config.py index 6d2583ba3..d74cbcc80 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -74,7 +74,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { - "embeddings_layer": { + "embeddings": { "hidden_size": 1024, # Default }, "decoder": { @@ -92,7 +92,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, # Default }, - "output_layer": {"tied_weight": False}, + "tied_embedding_weight": False, }, "multi_stage": {"zero_stage": 3}, "distributed": {"compute_dtype": "bfloat16"}, @@ -105,7 +105,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config.save_metadata(save_config) base_model_update = { - "embeddings_layer": {"hidden_size": 512, "vocab_size": 1000}, + "embeddings": {"hidden_size": 512, "vocab_size": 1000}, "decoder": { "block": { "mixer": { @@ -134,7 +134,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): expected_config["distributed"].update({"seed": 1234, "compute_dtype": "float16"}) if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { - "embeddings_layer": { + "embeddings": { "hidden_size": 512, "vocab_size": 1000, }, @@ -157,7 +157,8 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, - "output_layer": {"tied_weight": False, "normalization": {"type": "layer_norm"}}, + "head": {"normalization": {"type": "layer_norm"}}, + "tied_embedding_weight": False, "peft": {"type": "lora", "freeze_others": False}, } else: @@ -167,7 +168,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): base_model_update["decoder"]["block"]["mixer"]["type"] = "attention" base_model_update["decoder"]["block"]["mixer"]["rotary"] = {"type": "none"} base_model_update["decoder"]["block"]["mlp"] = {"type": "mlp"} - base_model_update["output_layer"] = {"normalization": {"type": "layer_norm"}} + base_model_update["head"] = {"type": "language_model_head", "normalization": {"type": "layer_norm"}} base_model_update["peft"] = {"type": "lora", "freeze_others": False} expected_config["base_model"] = base_model_update diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index aa8100126..1e303b9f1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -189,7 +189,7 @@ def _update_and_add_testing_config( }, "model": { "base_model": { - "embeddings_layer": { + "embeddings": { "word_embeddings": init_1, "position_embeddings": {"enabled": True, **init_1}, 
"hidden_size": 256, @@ -215,7 +215,8 @@ def _update_and_add_testing_config( }, "num_blocks": 2, }, - "output_layer": {"output_weight": init_1}, + "head": {"output_weight": init_1}, + "tied_embedding_weight": True, }, "multi_stage": { "debug_param_init": _LOG_LEVEL, @@ -324,7 +325,7 @@ def _update_and_add_testing_config( updates={ ("model", "base_model", "decoder", "block", "mixer", "head_groups"): 4, ("model", "base_model", "decoder", "block", "mixer", "rotary", "type"): "default", - ("model", "base_model", "embeddings_layer", "position_embeddings", "enabled"): False, + ("model", "base_model", "embeddings", "position_embeddings", "enabled"): False, }, megatron_args=[ "--group-query-attention", @@ -354,8 +355,8 @@ def _update_and_add_testing_config( ("model", "base_model", "decoder", "block", "mlp", "activation"): "silu", ("model", "base_model", "decoder", "block", "mlp", "add_linear_biases"): False, ("model", "base_model", "decoder", "block", "normalization", "type"): "rms_norm", - ("model", "base_model", "output_layer", "normalization", "type"): "rms_norm", - ("model", "base_model", "output_layer", "tied_weight"): False, + ("model", "base_model", "head", "normalization", "type"): "rms_norm", + ("model", "base_model", "tied_embedding_weight"): False, }, megatron_args=[ "--swiglu", @@ -441,7 +442,7 @@ def _update_and_add_testing_config( "llama", "mtp_llama", updates={ - ("model", "base_model", "output_layer", "prediction_heads"): 2, + ("model", "base_model", "head", "prediction_heads"): 2, }, # Megatron doesn't support multi-token prediction. megatron_args=None, From a9fc8726c7a32e911df999e19586ada718409290 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 1 Oct 2025 16:56:43 -0400 Subject: [PATCH 06/16] stuff --- fast_llm/models/gpt/model.py | 47 +++++------------------------------- 1 file changed, 6 insertions(+), 41 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 3d481a04a..99a89a345 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -130,33 +130,6 @@ def preprocess_meta( LanguageModelKwargs.mask_inputs: not truncate_documents, } - # TODO ====== Vision ====== - # if self._config.vision_encoder.enabled: - # try: - # max_image_size = batch_meta.max_image_size - # except AttributeError: - # max_image_size = 256 - # logger.warning("Inference mode: max_image_size not provided, defaulting to 256") - # vision_kwargs = { - # VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, - # VisionEncoderKwargs.max_image_size: max_image_size, - # VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, - # VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, - # VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, - # } - # vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] - # vision_hidden_dims = ( - # (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) - # if sequence_first - # else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) - # ) - # vision_kwargs.update( - # { - # VisionTransformerKwargs.hidden_dims: vision_hidden_dims, - # } - # ) - # common_kwargs.update(vision_kwargs) - sequence_k_pasts = range( sequence_q_dim.size * self._distributed_config.sequence_data_rank, sequence_length, @@ -199,13 +172,6 @@ def preprocess_meta( reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs - # TODO ====== Vision ====== - # if 
self._config.vision_encoder.enabled: - # # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size - # preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) - # else: - # preprocessed_meta.append((tokens, kwargs)) - preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -306,11 +272,15 @@ def preprocess( valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k + max_prediction_distance - 1) valid_spans -= sequence_offset + loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - labels[start : end + 1, idx] = -100 + loss_mask[start : end + 1, idx] = False else: - labels[idx, start : end + 1] = -100 + loss_mask[idx, start : end + 1] = False + if self._config.output_layer.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + labels = torch.where(loss_mask, labels, -100) # TODO ====== Preference spans ====== if batch.chosen_spans is not None: @@ -344,11 +314,6 @@ def preprocess( rejected_valid_spans.append(valid_spans) kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans - if self._config.distillation_model is not None: - loss_mask = torch.ones_like(labels, dtype=torch.bool) - loss_mask = torch.where(labels == -100, False, loss_mask) - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) # TODO ====== Turn into super() call ====== From d266c8773f19023cfd3fb4c814a3d9fb035dbeeb Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 1 Oct 2025 16:57:52 -0400 Subject: [PATCH 07/16] stuff --- fast_llm/models/gpt/model.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 99a89a345..c7742e371 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -282,40 +282,6 @@ def preprocess( kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) - # TODO ====== Preference spans ====== - if batch.chosen_spans is not None: - chosen_valid_spans = [] - for spans in batch.chosen_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - chosen_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - - rejected_valid_spans = [] - for spans in batch.rejected_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - rejected_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans - - kwargs.update(reference_logits[i]) - # TODO ====== Turn into super() call ====== self.embeddings.preprocess(tokens, kwargs) self.decoder.preprocess(tokens, 
kwargs) From 35cf10c0e788a1785d76187f1246ba6ab743fb07 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 1 Oct 2025 18:33:10 -0400 Subject: [PATCH 08/16] stuff --- fast_llm/engine/base_model/base_model.py | 7 ++--- fast_llm/layers/block/sequence.py | 10 ++++++- fast_llm/layers/decoder/block.py | 4 +++ fast_llm/models/gpt/model.py | 34 ++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 5b1180a13..4de115bd5 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -81,12 +81,13 @@ class LayerWithNamespace(Layer): TODO: Consider namespace for losses and metrics? """ - def __init__(self, layer: Layer, namespace: str): + def __init__(self, layer: Layer, namespace: str = None): super().__init__(layer._distributed_config) self._layer = layer self._namespace = namespace self.layer_count = self._layer.layer_count self.get_compute_usage = self._layer.get_compute_usage + self.module_name = self._layer.module_name def setup(self, distributed: Distributed) -> None: self._layer.setup(distributed) @@ -101,12 +102,12 @@ def forward( # TODO: Forward meta doesn't go through preprocessing so doesn't have a namespace. # Using kwargs as-is since it's generally unused. assert isinstance(input_, TensorMeta) - return self._layer.forward(input_, kwargs.get(self._namespace, kwargs), losses, metrics) + return self._layer.forward(input_, kwargs, losses, metrics) def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: assert self._namespace not in kwargs kwargs[self._namespace] = kwargs.copy() - return self._layer.preprocess(batch, kwargs[self._namespace]) + self._layer.preprocess(batch, kwargs[self._namespace]) class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase): diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index 33b884fdf..530df950e 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -1,5 +1,6 @@ import collections import functools +import typing import torch.nn @@ -54,6 +55,9 @@ def _layers_with_namespace(self) -> list[Layer]: def get_layers(self) -> list["Layer"]: return self._layers_with_namespace + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + self._layers_with_namespace[0].preprocess(batch, kwargs) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: return ( self[0].get_loss_definitions(count=count * self._config.num_blocks) if self._config.num_blocks > 0 else [] @@ -105,12 +109,16 @@ def _layers_with_namespace(self) -> list[Layer]: def get_layers(self) -> list[Layer]: return self._layers_with_namespace + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + for _, index in self._config.preprocessing_layers.items(): + self._layers_with_namespace[index].preprocess(batch, kwargs) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: # TODO: Prevent name conflicts. 
return sum( ( self[self._config.preprocessing_layers[name]].get_loss_definitions(count=count * count_) - for name, count_ in collections.Counter(self.expanded_pattern).items() + for name, count_ in collections.Counter(self._config.expanded_pattern).items() ), [], ) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 08dd5a815..8b19db66a 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -175,5 +175,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c ) ) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self.mixer.preprocess(batch, kwargs) + self.mlp.preprocess(batch, kwargs) + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: return self.mixer.get_loss_definitions(count=count) + self.mlp.get_loss_definitions(count=count) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c7742e371..aa3a3d24b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -281,6 +281,40 @@ def preprocess( if self._config.output_layer.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) + kwargs[LanguageModelKwargs.labels] = labels + kwargs.update(reference_logits[i]) + + # TODO ====== Preference spans ====== + if batch.chosen_spans is not None: + chosen_valid_spans = [] + for spans in batch.chosen_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + chosen_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans + + rejected_valid_spans = [] + for spans in batch.rejected_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + rejected_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans # TODO ====== Turn into super() call ====== self.embeddings.preprocess(tokens, kwargs) From 78817e2e0c7d8725ab6f0a0a4fb01a38989e688c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 2 Oct 2025 00:08:21 -0400 Subject: [PATCH 09/16] fixes --- examples/mistral.yaml | 4 +- fast_llm/engine/base_model/base_model.py | 8 ++ fast_llm/layers/language_model/head.py | 8 +- .../language_model/multi_token_prediction.py | 20 +++- fast_llm/models/gpt/conversion/apriel.py | 3 +- fast_llm/models/gpt/conversion/llama.py | 82 +++++++---------- fast_llm/models/gpt/huggingface.py | 2 +- fast_llm/models/gpt/model.py | 2 +- fast_llm/utils.py | 4 +- tests/layers/test_lm_head.py | 92 +++++++++++-------- tests/test_attention.py | 14 ++- tests/test_config.py | 34 ++++--- tests/test_multi_stage.py | 8 +- tests/utils/distributed_configs.py | 2 +- tests/utils/utils.py | 10 +- 15 files changed, 159 insertions(+), 134 
deletions(-) diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 987801892..2e4a57de7 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -27,7 +27,7 @@ optimizer: beta_2: 0.95 model: base_model: - embeddings_layer: + embeddings: hidden_size: 4096 vocab_size: 32000 dropout: 0.0 @@ -54,7 +54,7 @@ model: epsilon: 1.0e-05 dropout: 0.0 num_blocks: 32 - output_layer: + head: normalization: type: rms_norm epsilon: 1.0e-05 diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 4de115bd5..0a2b83425 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -73,6 +73,11 @@ def forward( ) -> torch.Tensor: pass + def unwrap(self) -> "Layer": + # Get the actual module contained in this layer, + # undoing any wrapping for the Fast-LLM engine (ex. `LayerWithNamespace`) + return self + class LayerWithNamespace(Layer): """ @@ -109,6 +114,9 @@ def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> No kwargs[self._namespace] = kwargs.copy() self._layer.preprocess(batch, kwargs[self._namespace]) + def unwrap(self) -> "Layer": + return self._layer.unwrap() + class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], LayerBase): diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b4d502b10..c284d269a 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -432,9 +432,9 @@ def _logits_cross_entropy_forward_backward( if dpo_loss is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) if self._config.distillation_model is not None and distillation_loss is not None: - losses[self._distillation_language_model_loss_name].append(distillation_loss.detach()) + losses[self._distillation_loss_name].append(distillation_loss.detach()) if self._config.distillation_model is not None and lm_loss is not None: - losses[self._distillation_loss_name].append(lm_loss.detach()) + losses[self._distillation_language_model_loss_name].append(lm_loss.detach()) return loss, output_parallel_linear_backward(grad, context) if self.training else None @@ -476,7 +476,9 @@ def _distillation_loss_name(self) -> str: def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] if self._config.logit_z_loss: - LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + loss_defs.append( + LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + ) if self._config.enable_dpo: loss_defs.append( LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index 79555d866..8a27cd3d6 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -1,6 +1,10 @@ +import functools +import typing + import torch from fast_llm.engine.base_model.base_model import Layer, LayerWithNamespace +from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import BlockBase @@ -62,16 +66,22 @@ def __init__( ] ) + @functools.cached_property + def 
_layers_with_namespace(self) -> list[Layer]: # Wrap all blocks in a namespace using the unique module name of the first one. + # This needs to be in a property because `module_name` is set after `__init__`. namespace = self.blocks[0].module_name - # Note: Pytorch won't redundantly register modules because it doesn't look into lists. - self._blocks_with_namespace = [ - LayerWithNamespace(sublayer, namespace) for layer in self.blocks for sublayer in layer.get_layers() - ] + return [LayerWithNamespace(sublayer, namespace) for layer in self.blocks for sublayer in layer.get_layers()] def get_layers(self) -> list[Layer]: return [ module - for block, head in zip(self._blocks_with_namespace, self.heads, strict=True) + for block, head in zip(self._layers_with_namespace, self.heads, strict=True) for module in (block, head) ] + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + self._layers_with_namespace[0].preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + return self[0].get_loss_definitions(count=count * self._config.prediction_heads) diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 5b32c481d..9434b9116 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -317,14 +317,13 @@ def get_converters( fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, - fast_llm_layer_start: int = 1, ) -> list[WeightConverter]: converters = [] for block_index in range(config.num_blocks): block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] converters += cls.block_converter_class.get_converters( block_config, - f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", drop_on_export, ) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index c22d1ebcb..8a82235b4 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -434,13 +434,12 @@ def get_converters( fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False, - fast_llm_layer_start: int = 1, ) -> list[WeightConverter]: converters = [] for block_index in range(config.num_blocks): converters += cls.block_converter_class.get_converters( config.block, - f"{fast_llm_prefix}.{block_index+fast_llm_layer_start}", + f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", drop_on_export, ) @@ -477,47 +476,43 @@ class LlamaHeadConverter: @classmethod def import_config(cls, config: dict) -> dict: - return { - "tied_weight": config["tie_word_embeddings"], - "normalization": cls.normalization_converter_class.import_config(config), - } + return {"normalization": cls.normalization_converter_class.import_config(config)} @classmethod def export_config(cls, config: LanguageModelHeadConfig) -> dict: Assert.custom(isinstance, config, LanguageModelHeadConfig) - return safe_merge_dicts( - cls.normalization_converter_class.export_config(config.normalization), - {"tie_word_embeddings": config.tied_weight}, - ) + return cls.normalization_converter_class.export_config(config.normalization) @classmethod def get_converters( - cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + cls, config: LanguageModelHeadConfig, fast_llm_prefix: str, tied_embedding_weight: bool ) -> list[WeightConverter]: - converters = [] - for prediction_distance in 
range(config.prediction_heads): - if prediction_distance > 0: - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", - "", - drop_on_export=True, - ) - converters += cls.normalization_converter_class.get_converters( + # for prediction_distance in range(config.prediction_heads): + # if prediction_distance > 0: + # converters += cls.block_converter_class.get_converters( + # block_config, + # f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", + # "", + # drop_on_export=True, + # ) + # converters += cls.normalization_converter_class.get_converters( + # config.normalization, + # f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + # f"model.norm", + # drop_on_export=prediction_distance > 0, + # ) + return [ + *cls.normalization_converter_class.get_converters( config.normalization, - f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + f"{fast_llm_prefix}.final_norm", f"model.norm", - drop_on_export=prediction_distance > 0, - ) - converters.append( + ), get_parameter_converter( - f"{fast_llm_prefix}.{start_index}.output_weights", + f"{fast_llm_prefix}.output_weights", "lm_head.weight", - drop_on_import=config.tied_weight, - ) - ) - - return converters + drop_on_import=tied_embedding_weight, + ), + ] class LlamaBaseModelConverter: @@ -529,9 +524,10 @@ class LlamaBaseModelConverter: @classmethod def import_config(cls, config: dict) -> dict: return { - "embeddings_layer": cls.embeddings_converter_class.import_config(config), + "embeddings": cls.embeddings_converter_class.import_config(config), "decoder": cls.decoder_converter_class.import_config(config, config["hidden_size"]), - "output_layer": cls.head_converter_class.import_config(config), + "head": cls.head_converter_class.import_config(config), + "tied_embedding_weight": config["tie_word_embeddings"], } @classmethod @@ -541,29 +537,17 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict: cls.embeddings_converter_class.export_config(config.embeddings), cls.decoder_converter_class.export_config(config.decoder), cls.head_converter_class.export_config(config.head), + {"tie_word_embeddings": config.tied_embedding_weight}, ) @classmethod def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: return [ - *cls.embeddings_converter_class.get_converters(config.embeddings, "layers.0", "model"), - *cls.decoder_converter_class.get_converters(config.decoder, "layers", "model.layers"), - *cls.head_converter_class.get_converters( - config.head, config.decoder[len(config.decoder) - 1], "layers", len(config.decoder) + 1 - ), + *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), + *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *cls.head_converter_class.get_converters(config.head, "head", config.tied_embedding_weight), ] - def _create_weight_converters( - self, - ) -> list[WeightConverter]: - base_model_config = self._model.config.base_model - self.embeddings_converter_class.get_converters(base_model_config.embeddings, "layers.0", "model") - converters = self.decoder_converter_class.get_converters(base_model_config.decoder, "layers", "model.layers") - self.head_converter_class.get_converters( - base_model_config.decoder, base_model_config.decoder.block, "layers", len(base_model_config.decoder) + 1 - ) - return converters - class LlamaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: GPTModel diff --git 
a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 680d8bfb2..7f0fefc18 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -104,7 +104,7 @@ def inner_forward( self._inference_runner.forward(input_, kwargs, iteration=iteration) # TODO: Make a proper way of returning the model output. - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[AttentionKwargs.sequence_first]: logits = kwargs["logits"].transpose(0, 1) else: logits = kwargs["logits"] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index aa3a3d24b..08196aa55 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -281,7 +281,7 @@ def preprocess( if self._config.output_layer.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) - kwargs[LanguageModelKwargs.labels] = labels + kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) # TODO ====== Preference spans ====== diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 1f9feceb4..bbd69ae8a 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -316,7 +316,9 @@ def new_decorator(*args, **kwargs): return new_decorator -def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple = ()): +def compare_nested( + config_a, config_b, errors: list | None = None, prefix: tuple = (), ignore_missing: tuple[str, ...] = () +): if errors is None: errors = [] # Check for equality of both values and types. diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index cf01cd707..6b7aa993b 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,7 +7,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig @@ -101,17 +101,17 @@ def _lm_head( @pytest.mark.slow @pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) @pytest.mark.parametrize( - ("config_dict", "distributed_config_dict", "loss_masking"), + ("config_dict", "distributed_config_dict", "loss_masking", "prediction_heads"), ( - ({}, {}, False), - ({}, {"compute_dtype": DataType.bfloat16}, False), - ({"embeddings_layer": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False), - ({"sequence_first": True}, {}, False), - ({"head": {"logit_z_loss": 1e-3}}, {}, False), - ({"head": {"logits_scale_factor": 5.0}}, {}, False), - ({"head": {"tied_weight": False}}, {}, False), - ({"head": {"prediction_heads": 2}}, {}, False), - ({}, {}, True), + ({}, {}, False, 1), + ({}, {"compute_dtype": DataType.bfloat16}, False, 1), + ({"embeddings": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False, 1), + ({"sequence_first": True}, {}, False, 1), + ({"head": {"logit_z_loss": 1e-3}}, {}, False, 1), + ({"head": {"logits_scale_factor": 5.0}}, {}, False, 1), + ({"tied_embedding_weight": False}, {}, False, 1), + ({}, {}, False, 2), + ({}, {}, True, 1), ( { "head": { @@ -121,6 +121,7 
@@ def _lm_head( }, {}, False, + 1, ), ( { @@ -131,16 +132,19 @@ def _lm_head( }, {}, False, + 1, ), ( { "head": { "distillation_model": "distillation", "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + "language_model_loss_factor": 1.0, } }, {}, True, + 1, ), ( { @@ -151,6 +155,7 @@ def _lm_head( }, {}, True, + 1, ), ), ) @@ -159,24 +164,35 @@ def test_lm_head( config_dict: dict[str, typing.Any], distributed_config_dict: dict[str, typing.Any], loss_masking: bool, + prediction_heads: int, ): + head_config = { + "cross_entropy_implementation": cross_entropy_impl, + "normalization": {"type": "rms_norm"}, + } config = GPTBaseModelConfig.from_dict( { "decoder": { "num_blocks": 0, }, - "embeddings_layer": { + "embeddings": { "vocab_size": VOCAB_SIZE, "hidden_size": HIDDEN_SIZE, }, - "head": { - "cross_entropy_implementation": cross_entropy_impl, - "normalization": {"type": "rms_norm"}, - }, + "head": ( + head_config + if prediction_heads == 1 + else { + "type": "multi_token_prediction", + "head": head_config, + "prediction_heads": prediction_heads, + } + ), }, config_dict, update_type=UpdateType.update, ) + head_config: LanguageModelHeadConfig = config.head if prediction_heads == 1 else config.head.head model, distributed = get_base_model( GPTModelConfig.from_dict( @@ -188,22 +204,22 @@ def test_lm_head( ) sequence_first = config.sequence_first or ( - config.head.cross_entropy_splits is not None and config.head.cross_entropy_splits > 1 + head_config.cross_entropy_splits is not None and head_config.cross_entropy_splits > 1 ) input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( distributed.config.optimization_dtype.torch - if config.embeddings_layer.full_precision_residual + if config.embeddings.full_precision_residual else distributed.config.compute_dtype.torch ), device=distributed.device, requires_grad=True, ) label_shape = ( - (SEQUENCE_LENGTH + config.head.prediction_heads - 1, BATCH_SIZE) + (SEQUENCE_LENGTH + config.head.max_prediction_distance - 1, BATCH_SIZE) if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.head.prediction_heads - 1) + else (BATCH_SIZE, SEQUENCE_LENGTH + config.head.max_prediction_distance - 1) ) if loss_masking: loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) @@ -213,7 +229,7 @@ def test_lm_head( AttentionKwargs.sequence_first: sequence_first, AttentionKwargs.grad_output: 1.0, } - if config.head.distillation_model is None: + if head_config.distillation_model is None: target = torch.randint( 0, VOCAB_SIZE, @@ -226,31 +242,30 @@ def test_lm_head( kwargs[LanguageModelKwargs.labels] = target else: - assert config.head.prediction_heads == 1 + assert config.head.max_prediction_distance == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), dtype=input_.dtype, device=distributed.device, ) - kwargs[f"{config.head.distillation_model}_logits"] = target + kwargs[f"{head_config.distillation_model}_logits"] = target if loss_mask is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask - if config.head.tied_weight or config.head.prediction_heads > 1: + if config.tied_embedding_weight or config.head.max_prediction_distance > 1: logit_weight = ( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device ) - .normal_(config.embeddings_layer.hidden_size**-0.5) + .normal_(config.embeddings.hidden_size**-0.5) .requires_grad_(True) ) - kwargs[WORD_EMBEDDINGS_WEIGHT 
if config.head.tied_weight else OUTPUT_WEIGHTS] = logit_weight + kwargs[WORD_EMBEDDINGS_WEIGHT if config.tied_embedding_weight else OUTPUT_WEIGHTS] = logit_weight else: logit_weight = None - for prediction_distance, layer_index in enumerate(model.model_head_indices): + for prediction_distance, head in enumerate((model.head,) if prediction_heads == 1 else model.head.heads): # Prepare the LM head - head: LanguageModelHead = model[layer_index] Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, prediction_distance) stage = get_stage([head], distributed) @@ -276,9 +291,9 @@ def test_lm_head( loss_mask, rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, - logit_scale_factor=config.head.logits_scale_factor, - logit_z_loss=config.head.logit_z_loss, - distillation_loss_implementation=config.head.distillation_loss_implementation, + logit_scale_factor=head_config.logits_scale_factor, + logit_z_loss=head_config.logit_z_loss, + distillation_loss_implementation=head_config.distillation_loss_implementation, ) # Prepare LM head inputs @@ -291,13 +306,18 @@ def test_lm_head( output_grad = torch.randn_like(shared_hidden) loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss" - Assert.eq(head._loss_name, loss_name) loss_keys = {loss_name} if ref_z_loss is not None: - loss_keys.add("z_loss") - if config.head.distillation_model is not None: + loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") + if head_config.distillation_model is not None: loss_keys.add("distillation_loss") - loss_keys.add("distil_lm_loss") + if head_config.language_model_loss_factor > 0: + loss_keys.add("distillation_language_model_loss") + + Assert.eq( + {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, + {loss_key: 1 for loss_key in loss_keys}, + ) losses = {key: [] for key in loss_keys} output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) @@ -305,7 +325,7 @@ def test_lm_head( threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 - ) * config.head.logits_scale_factor + ) * head_config.logits_scale_factor Assert.eq(losses.keys(), loss_keys) Assert.eq(len(losses[loss_name]), 1) diff --git a/tests/test_attention.py b/tests/test_attention.py index dceaa8282..a19cba8f0 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -2,13 +2,13 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.attention.attention import Attention from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs -from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert -def test_varlen_preprocessor(): +def test_varlen_preprocessing(): sequence_lengths = [torch.tensor([8, 13, 4, 11], dtype=torch.int32), torch.tensor([11, 16, 9], dtype=torch.int32)] # First micro-sequence: # [0...7,0...3] + [0...10,0] -> [0,8,12,23,24] @@ -28,8 +28,12 @@ def test_varlen_preprocessor(): ] micro_sequence_length = 12 sequence_length = 36 - varlen_preprocessor = FlashAttnVarlenPreprocessor( - AttentionConfig(head_size=64), DistributedConfig(compute_dtype="bfloat16") + attention = Attention( + 
AttentionConfig(head_size=64), + DistributedConfig(compute_dtype="bfloat16"), + hidden_dim=TensorDim("", 1), + lr_scale=None, + peft=None, ) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { @@ -40,6 +44,6 @@ def test_varlen_preprocessor(): AttentionKwargs.sequence_length: sequence_length, AttentionKwargs.sequence_lengths: sequence_lengths, } - varlen_preprocessor.preprocess(torch.empty(1, device="cpu"), kwargs) + attention.preprocess(torch.empty(1, device="cpu"), kwargs) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_config.py b/tests/test_config.py index d74cbcc80..326200537 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -127,7 +127,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): } ) serialized_config = pretrained_config.model.to_dict() - expected_config = {"type": "gpt", "distributed": DistributedConfig().to_dict()} + expected_config = {"distributed": DistributedConfig().to_dict()} if load_config == ModelConfigType.fast_llm: expected_config["multi_stage"] = {"zero_stage": 3} @@ -139,40 +139,38 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "vocab_size": 1000, }, "decoder": { - "type": "fixed", "block": { - "type": "decoder", "mixer": { - "type": "attention", - "rotary": {"type": "default"}, "window_size": 32, "head_groups": 1, }, "mlp": { - "type": "mlp", "intermediate_size": 4096, # Implicit default, default value "activation": "silu", # Implicit default, non-default value }, - "normalization": {"type": "rms_norm", "implementation": "triton"}, + "normalization": {"implementation": "triton"}, }, "num_blocks": 12, }, - "head": {"normalization": {"type": "layer_norm"}}, "tied_embedding_weight": False, - "peft": {"type": "lora", "freeze_others": False}, + "peft": {"freeze_others": False}, } else: - base_model_update["decoder"]["type"] = "fixed" - base_model_update["decoder"]["block"]["type"] = "decoder" - base_model_update["decoder"]["block"]["normalization"]["type"] = "layer_norm" - base_model_update["decoder"]["block"]["mixer"]["type"] = "attention" - base_model_update["decoder"]["block"]["mixer"]["rotary"] = {"type": "none"} - base_model_update["decoder"]["block"]["mlp"] = {"type": "mlp"} - base_model_update["head"] = {"type": "language_model_head", "normalization": {"type": "layer_norm"}} - base_model_update["peft"] = {"type": "lora", "freeze_others": False} expected_config["base_model"] = base_model_update - check_equal_nested(serialized_config, expected_config) + check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) + + +def _trim_type(config: dict): + # Serialization inserts dynamic types, we ignore them during the comparison. 
+ if "type" in config: + del config["type"] + for key in list(config): + if isinstance(value := config[key], dict): + _trim_type(value) + if not value: + del config[key] + return config def _check_dim(dim: DistributedDim, name: str, rank: int, size: int, global_rank: int): diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index cc5a60a8a..a4f1e19c8 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -42,14 +42,14 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, DecoderBlock) else 0 - for layer in model_ref.base_model.layers + sum(p.numel() for p in layer.unwrap().mlp.parameters()) if isinstance(layer.unwrap(), DecoderBlock) else 0 + for layer in model_ref.base_model.get_layers() ] # Make sure each layer has its own buffer so the check below works. Assert.eq( - num_stages := len(model_ref.base_model.layers), - len(model_frozen.base_model.layers), + num_stages := len(model_ref.base_model.get_layers()), + len(model_frozen.base_model.get_layers()), len(model_ref.stages), len(model_frozen.stages), ) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 863be2cae..0ef18279d 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -110,7 +110,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon DistributedTestingConfig( name="ce4", compare="simple", - config_args=["model.base_model.output_layer.cross_entropy_splits=4"], + config_args=["model.base_model.head.cross_entropy_splits=4"], num_gpus=1, compare_config=_compare_layer_mismatch, ), diff --git a/tests/utils/utils.py b/tests/utils/utils.py index b086c291f..b6cad34d6 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -11,7 +11,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier -from fast_llm.engine.base_model.base_model import BaseModel, Layer +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig @@ -32,19 +32,17 @@ def result_path(): def get_base_model(config: FastLLMModelConfig): # Create a base model (and distributed). # Using a full model config so we have the model type and distributed config in the same argument. - base_model = config.get_model_class().base_model_class(config.base_model, config.distributed) + base_model = config.get_base_model_config_class().get_base_model(config.base_model, config.distributed) base_model.setup(distributed := Distributed(config.distributed)) return base_model, distributed -def get_stage(base_model: BaseModel | list[Layer], distributed: Distributed): +def get_stage(layers: list[Layer], distributed: Distributed): # Create a fast-llm stage which allocates and initializes meta tensors correctly. 
stage = Stage( config=StageConfig(), - layers=base_model, + layers=layers, distributed_config=distributed.config, - begin=0, - end=1, index=0, ) stage.setup(distributed=distributed) From 678306a056446fc80ae0559a0d3fb56209dd98c0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 2 Oct 2025 16:06:08 -0400 Subject: [PATCH 10/16] tied weights --- fast_llm/engine/base_model/base_model.py | 17 ++-- fast_llm/engine/multi_stage/multi_stage.py | 92 +++++++++++++------ fast_llm/engine/multi_stage/stage.py | 2 + fast_llm/engine/multi_stage/stage_base.py | 30 ++++-- fast_llm/engine/schedule/runner.py | 5 - fast_llm/layers/language_model/config.py | 4 +- fast_llm/layers/language_model/embedding.py | 2 - fast_llm/layers/language_model/head.py | 37 ++++---- .../language_model/multi_token_prediction.py | 3 + fast_llm/models/gpt/model.py | 26 ++---- 10 files changed, 131 insertions(+), 87 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 0a2b83425..d61630e07 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -149,13 +149,16 @@ def preprocess( # TODO ====== Move batch splitting elsewhere, align interface with LayerBase ====== pass - def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: - # TODO ====== Tied weights ====== - # Return tuples of independently defined metas to tie together. - # For each tied weight, return the weight and the tuple of layers sharing it. - # The weight should be defined in the first layer in the set. - # Warning: This may return buffers instead of metas after stage setup. - # The name (dict key) is used to insert the weight in the kwargs of the forward pass. + def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]: + """ + Return tuples of independently defined metas to tie together. + Metas should be compatible, i.e. have the same tensor dimensions. + Tied weights are named (dict keys) for convenience only. + Warning: Initialization and optimization properties are defined on the first appearance of the tied weight. + To prevent any confusion, the metas should be provided in the same order they appear in the model. + TODO: Improve? + Note: This may return buffers instead of metas after stage setup. + """ return {} def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 77dc4e7dd..ffbfe1338 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -3,7 +3,6 @@ import typing import warnings -import numpy as np import torch from torch._C._distributed_c10d import ProcessGroup @@ -63,41 +62,49 @@ def __init__( self._config.distributed.pipeline_parallel * self._config.multi_stage.stages_per_pipeline_stage, ) + # Keep track of which stage each parameter belongs to. + self._parameter_stages: dict[str, int] = {} + for stage_index in range(self._num_stages): + for layer in self._layers[stage_splits[stage_index] : stage_splits[stage_index + 1]]: + for meta in layer.parameters(): + assert meta.tensor_name not in self._parameter_stages + self._parameter_stages[meta.tensor_name] = stage_index + + # Determine which stages belong to this pipeline rank. 
+ self._stage_pipeline_ranks = { + stage_index: (stage_index // self._config.multi_stage.stages_per_pipeline_stage) + % self._config.distributed.pipeline_parallel + for stage_index in (range(self._num_stages)) + } + + # Set up tied weights. + self._tied_parameters = self._get_tied_parameters() + self._tied_parameter_duplicates = [[] for _ in range(self._num_stages)] + for tied_parameter in self._tied_parameters.values(): + for meta in tied_parameter.metas[1:]: + self._tied_parameter_duplicates[self._parameter_stages[meta.tensor_name]].append(meta.tensor_name) + # Create the stages. self._stages = [ Stage( config=self._config.multi_stage, - layers=self._layers[stage_splits[i] : stage_splits[i + 1]], + layers=self._layers[stage_splits[stage_index] : stage_splits[stage_index + 1]], distributed_config=self._config.distributed, - index=i, + index=stage_index, + tied_parameter_duplicates=tied_parameter_duplicates_, ) - for i in (range(self._num_stages)) + for stage_index, tied_parameter_duplicates_ in enumerate(self._tied_parameter_duplicates) ] if self._verbose: log_main_rank(lambda: f" Total parameters: {sum(stage_.parameter_count for stage_ in self._stages):,} ") - # Keep track of which stage each parameter belongs to. - self._parameter_stages: dict[str, int] = {} - for stage_index, stage in enumerate(self._stages): - for parameter_name in stage.parameter_names: - assert parameter_name not in self._parameter_stages - self._parameter_stages[parameter_name] = stage_index - - # Determine which stages belong to this pipeline rank. - self._stage_pipeline_ranks = { - stage_index: (stage_index // self._config.multi_stage.stages_per_pipeline_stage) - % self._config.distributed.pipeline_parallel - for stage_index in (range(self._num_stages)) - } self._stages_owned = { stage_index: self._stages[stage_index] for stage_index, stage_rank in self._stage_pipeline_ranks.items() if stage_rank == self._config.distributed.pipeline_rank } - # Set up tied weights. 
- self._tied_parameters = self._get_tied_parameters(stage_splits[1:]) self._tied_weight_main_stages_on_device = { stage_index: self._stages[stage_index] for stage_index in sorted( @@ -318,6 +325,12 @@ def _setup_stages(self) -> None: if self._mode.support_forward and weight_buffer_index is not None else [] ) + tied_weight_duplicate_buffers = { + parameter_name: self._stages[self._parameter_stages[parameter_name]].get_parameter_buffer( + parameter_name + ) + for parameter_name in self._tied_parameter_duplicates[stage_index] + } stage.setup( distributed=self._distributed, weight_shards=stage_weight_shards, @@ -326,6 +339,7 @@ def _setup_stages(self) -> None: grad_buffers=stage_grad_buffers, mode=self._mode if stage_index in self._stages_on_device else StageMode.off_device, is_tied_weight_copy=stage_index in self._stages_on_device and stage_index not in self._stages_owned, + tied_weight_duplicate_buffers=tied_weight_duplicate_buffers, weight_buffer_shared_with=weight_buffer_shared_with, ) @@ -533,17 +547,43 @@ def _get_buffer_placement(self, num_shared_buffers: int | None) -> tuple[list[se } return buffer_contents, buffer_indices - def _get_tied_parameters(self, stage_ends) -> dict[str, "TiedParameter"]: + def _get_tied_parameters(self) -> dict[str, "TiedParameter"]: tied_parameters = {} - for name, (meta, layer_indexes) in self._base_model.get_tied_weights().items(): - Assert.eq(list(layer_indexes), sorted(layer_indexes)) - Assert.incl(meta, list(self._base_model[layer_indexes[0]].parameters())) - stage_indexes = sorted({np.searchsorted(stage_ends, i, side="right").item() for i in layer_indexes}) + for name, metas in self._base_model.get_tied_parameters().items(): + if len(metas) <= 1: + continue + stage_indexes = [self._parameter_stages[meta.tensor_name] for meta in metas] + # TODO: Ambiguous if multiple instances are on the same stage? + Assert.eq( + sorted(stage_indexes), + stage_indexes, + msg="Tied parameters should be provided in the order they appear in the model.", + ) + for meta in metas[1:]: + # TODO: Improve. Compare initializations? (Not currently possible) + if ( + len(meta.dims) != len(metas[0].dims) + or any(dim is not dim_ for dim, dim_ in zip(meta.dims, metas[0].dims, strict=True)) + or meta.sequence_tensor_parallel != metas[0].sequence_tensor_parallel + ): + raise ValueError( + f"Tied parameter group `{name}` has incompatible tied parameters {metas[0]} and {meta}." + ) + if ( + meta.requires_grad != metas[0].requires_grad + or meta.lr_scale != metas[0].lr_scale + or meta.param_weight_decay != metas[0].param_weight_decay + ): + logger.warning( + f"Tied parameters `{metas[0]}` and `{meta}` in tied parameter group `{name}` have different optimization parameters." + f" Only those of `{metas[0].tensor_name}` will be used." + ) + all_ranks = {self._stage_pipeline_ranks[stage_index] for stage_index in stage_indexes} tied_parameters[name] = TiedParameter( name=name, - meta=meta, + metas=tuple(metas), all_ranks=all_ranks, on_device=self._config.distributed.pipeline_rank in all_ranks, main_stage=stage_indexes[0], @@ -555,7 +595,7 @@ def _get_tied_parameters(self, stage_ends) -> dict[str, "TiedParameter"]: class TiedParameter: name: str # Parameter definition. - meta: ParameterMeta + metas: tuple[ParameterMeta, ...] # Whether the local rank is involved at all. on_device: bool # Process group for reduction. 
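A minimal sketch of the intended usage of the new tied-parameter convention, for reference (the module names `embeddings` and `head` and the dict key are illustrative assumptions only; the actual GPT implementation appears in a later hunk of this patch):

    def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]:
        # The first meta defines initialization and optimizer settings for the whole group,
        # so the embedding weight is listed first, following the order the modules appear in the model.
        return {"word_embeddings_weight": [self.embeddings.word_embeddings_weight, *self.head.get_output_weights()]}

Each returned group is then resolved into a `TiedParameter` entry: the stage holding the first meta owns the parameter, and the stages holding the remaining metas reuse its buffer as duplicates.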
diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index bb3133256..dfe324669 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -47,6 +47,7 @@ def setup( # noqa grad_buffers: list[torch.Tensor | None] | None = None, mode: StageMode = StageMode.training, is_tied_weight_copy: bool = False, + tied_weight_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, weight_buffer_shared_with: collections.abc.Sequence["Stage"] = (), ) -> None: super().setup( @@ -56,6 +57,7 @@ def setup( # noqa weight_buffers=weight_buffers, grad_buffers=grad_buffers, mode=mode, + tied_weight_duplicate_buffers=tied_weight_duplicate_buffers, ) self._is_tied_weight_copy = is_tied_weight_copy if self._mode.support_forward: diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 4778780ee..7a80193fe 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -31,6 +31,7 @@ def __init__( layers: list[Layer], distributed_config: DistributedConfig, index: int, + tied_parameter_duplicates: typing.Iterable[str] = (), ): super().__init__(config) self._distributed_config = distributed_config.validate() @@ -39,9 +40,10 @@ def __init__( self._is_setup = False self._index = index self._layers = layers + self._tied_parameter_duplicates = set(tied_parameter_duplicates) - parameter_metas, frozen_metas = self._get_parameter_metas() - self._parameter_metas = parameter_metas + frozen_metas + parameter_metas, frozen_metas, duplicate_metas = self._get_parameter_metas() + self._parameter_metas = parameter_metas + frozen_metas + duplicate_metas self._fsdps = [] if parameter_metas: self._fsdps.append( @@ -106,6 +108,7 @@ def setup( weight_buffers: list[torch.Tensor | None] | None, grad_buffers: list[torch.Tensor | None] | None, mode: StageMode = StageMode.training, + tied_weight_duplicate_buffers: dict[str, torch.nn.Parameter] | None, ) -> None: assert not self._is_setup distributed.check_config(self._distributed_config) @@ -121,6 +124,8 @@ def setup( weight_buffers = [None for _ in self._fsdps] if grad_buffers is None: grad_buffers = [None for _ in self._fsdps] + if tied_weight_duplicate_buffers is None: + assert not self._tied_parameter_duplicates for fsdp, weight_shard, grad_shard, weight_buffer, grad_buffer in zip( self._fsdps, weight_shards, grad_shards, weight_buffers, grad_buffers, strict=True @@ -142,7 +147,10 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) - module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) + if meta.tensor_name in self._tied_parameter_duplicates: + module._parameters[key] = tied_weight_duplicate_buffers[meta.tensor_name] + else: + module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 i = 0 @@ -172,6 +180,9 @@ def initialize_weights(self) -> None: ] for meta in metas: + if meta.tensor_name in self._tied_parameter_duplicates: + # Initialization is not managed by this stage. 
+ continue fsdp = self._fsdps[fsdp_index := self._fsdp_index[meta.tensor_name]] parameter = weight_shards_split[fsdp_index][meta.tensor_name] # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) @@ -309,24 +320,31 @@ def _export_shard( for fsdp, shard in zip(self._fsdps, shards, strict=True): yield from fsdp.export_shard(shard, data_type) - def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]]: + def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta], list[ParameterMeta]]: # Get all the stage parameters, # then separate the parameters with and without weight decay, # and squeeze the non-tensor parallel and sequence parallel ones in the middle. # This allows running the optimizer, grad norm and sequence_parallel reduction on contiguous buffers. parameter_metas: list[ParameterMeta] = [] frozen_metas: list[ParameterMeta] = [] + duplicate_metas: list[ParameterMeta] = [] meta: ParameterMeta for layer in self._layers: for name, meta in layer.named_parameters(): Assert.custom(isinstance, meta, ParameterMeta) Assert.eq(meta.dtype, self._distributed_config.optimization_dtype.torch) - if meta.requires_grad: + if name in self._tied_parameter_duplicates: + duplicate_metas.append(meta) + elif meta.requires_grad: parameter_metas.append(meta) else: frozen_metas.append(meta) - return self._reorder_parameter_metas(parameter_metas), self._reorder_parameter_metas(frozen_metas) + return ( + self._reorder_parameter_metas(parameter_metas), + self._reorder_parameter_metas(frozen_metas), + self._reorder_parameter_metas(duplicate_metas), + ) @classmethod def _reorder_parameter_metas(cls, parameter_metas): diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 58449f207..d08932c49 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -339,11 +339,6 @@ def _preprocess_data( num_micro_batches=batch_config.sequential_micro_batches, micro_batch_splits=batch_config.micro_batch_splits, ) - for name, tied_parameter in self._tied_parameters.items(): - if tied_parameter.on_device: - kwargs[name] = self._stages[tied_parameter.main_stage].get_parameter_buffer( - tied_parameter.meta.tensor_name - ) data_index = context.schedule.get_data_index(micro_batch, micro_batch_split) if self._stages_owned[0]: context.inputs[context.schedule.get_step(StepType.forward, 0, data_index).global_index] = input_ diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9bdc12c16..3f2e4da28 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -15,7 +15,7 @@ if typing.TYPE_CHECKING: from fast_llm.layers.language_model.embedding import LanguageModelEmbedding - from fast_llm.layers.language_model.head import LanguageModelHead + from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction @@ -107,7 +107,7 @@ def get_layer( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - ): + ) -> LanguageModelHeadBase: return self.layer_class( self, distributed_config, diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 6e3bbc901..4dddb0d98 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -14,8 +14,6 @@ from fast_llm.tensor import TensorMeta from 
fast_llm.utils import Assert -WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" - class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[ConfigType]): """ diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c284d269a..0cbfafa87 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,3 +1,4 @@ +import abc import functools import logging import typing @@ -22,6 +23,7 @@ from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( LanguageModelEmbeddingsConfig, + LanguageModelHeadBaseConfig, LanguageModelHeadConfig, LanguageModelKwargs, ) @@ -33,7 +35,13 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]): +class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](Block[ConfigType]): + @abc.abstractmethod + def get_output_weights(self) -> list[torch.Tensor]: + pass + + +class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType]): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). TODO: Cleanup (dynamic type? composition?) @@ -98,16 +106,12 @@ def __init__( self._vocab_dim = TensorDim( "vocab", embeddings_config.vocab_size, self._parallel_dim if self._vocab_parallel else None ) - # Only the first head defines the output weights - # TODO ====== tied_weight ====== - if self._prediction_distance == 0: # and not self._config.tied_weight: - # untie embedding weights - self.output_weights = self._config.output_weight.get_parameter( - (self._vocab_dim, self._hidden_dim), - default_initialization=init_normal_(std=self._hidden_size**-0.5), - lr_scale=self._lr_scale, - peft=self._peft, - ) + self.output_weights = self._config.output_weight.get_parameter( + (self._vocab_dim, self._hidden_dim), + default_initialization=init_normal_(std=self._hidden_size**-0.5), + lr_scale=self._lr_scale, + peft=self._peft, + ) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -185,7 +189,7 @@ def _forward_backward( self._parallel_dim.size if self._sequence_parallel_logits else 1 ) - output_weights = self._get_output_weights(kwargs) + output_weights = self.output_weights loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( ln_output.detach(), targets, output_weights, grad_output, kwargs, losses ) @@ -244,13 +248,8 @@ def _get_targets( targets = None return targets - def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - # TODO ====== tied_weight ====== - # if self._config.tied_weight: - # return kwargs[WORD_EMBEDDINGS_WEIGHT] - if self._prediction_distance > 0: - return kwargs[OUTPUT_WEIGHTS] - return self.output_weights + def get_output_weights(self) -> list[torch.Tensor]: + return [self.output_weights] def _logits_cross_entropy_forward_backward_split( self, diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index 8a27cd3d6..10fb4fe97 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -80,6 +80,9 @@ def get_layers(self) -> list[Layer]: for module in (block, head) ] + def get_output_weights(self) -> list[torch.Tensor]: + return sum((head.get_output_weights() for head in self.heads), []) + def preprocess(self, batch: "torch.Tensor", kwargs:
dict[str, typing.Any]) -> None: self._layers_with_namespace[0].preprocess(batch, kwargs) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 08196aa55..0b4a5d381 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -13,6 +13,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -36,7 +37,7 @@ def __init__( super().__init__(config, distributed_config) self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size) - self.embeddings = self._config.embeddings.get_layer( + self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer( distributed_config, hidden_dim=self._hidden_dim, lr_scale=None, @@ -325,26 +326,11 @@ def preprocess( return preprocessed - def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: - # TODO ====== Tied weights ====== + def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]: + output_weights = self.head.get_output_weights() if self._config.tied_embedding_weight: - raise NotImplementedError() - return {} - # if self._config.output_layer.tied_weight: - # return { - # WORD_EMBEDDINGS_WEIGHT: ( - # self.embedding.word_embeddings_weight, - # ) - # } - # elif self._config.output_layer.prediction_heads > 1: - # return { - # OUTPUT_WEIGHTS: ( - # self.model_head.output_weights, - # tuple(self.model_head_indices), - # ) - # } - # else: - # return {} + output_weights.insert(0, self.embeddings.word_embeddings_weight) + return {output_weights[0].tensor_name: output_weights} if len(output_weights) > 1 else {} def get_loss_definitions(self, count: int = 1) -> list[LossDef]: return ( From dc3c58b81903fce6c4328f9789c9ec9c50b0214c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 2 Oct 2025 17:12:27 -0400 Subject: [PATCH 11/16] fixes --- fast_llm/engine/multi_stage/fsdp.py | 6 ++-- fast_llm/engine/multi_stage/multi_stage.py | 19 +++++++----- fast_llm/engine/multi_stage/stage.py | 3 ++ fast_llm/engine/multi_stage/stage_base.py | 4 +-- fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/language_model/embedding.py | 4 +++ tests/layers/test_lm_head.py | 33 +++++++++++++++++---- 7 files changed, 52 insertions(+), 19 deletions(-) diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 868cc2db4..827079f6e 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -56,11 +56,11 @@ def __init__( # The index range of the parameters in the buffer.
self._parameter_begins_in_buffer = { parameter_meta.tensor_name: offset - for parameter_meta, offset in zip(parameter_metas, parameter_offsets[:-1]) + for parameter_meta, offset in zip(parameter_metas, parameter_offsets[:-1], strict=True) } self._parameter_ends_in_buffer = { parameter_meta.tensor_name: offset - for parameter_meta, offset in zip(parameter_metas, parameter_offsets[1:]) + for parameter_meta, offset in zip(parameter_metas, parameter_offsets[1:], strict=True) } # Shard properties @@ -377,7 +377,7 @@ def reduce_gradients( assert self._mode.support_backward if not self._requires_grad: return - for buffer, meta in zip(self._parameter_buffers.values(), self._parameter_metas.values()): + for buffer, meta in zip(self._parameter_buffers.values(), self._parameter_metas.values(), strict=True): if buffer.param_grad_is_zero: # noqa assert allow_no_grad or meta.allow_no_grad, meta triton_fill(buffer.grad_buffer, 0) # noqa diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index ffbfe1338..13f67773f 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -79,10 +79,13 @@ def __init__( # Set up tied weights. self._tied_parameters = self._get_tied_parameters() - self._tied_parameter_duplicates = [[] for _ in range(self._num_stages)] + self._tied_parameter_duplicates = [{} for _ in range(self._num_stages)] for tied_parameter in self._tied_parameters.values(): for meta in tied_parameter.metas[1:]: - self._tied_parameter_duplicates[self._parameter_stages[meta.tensor_name]].append(meta.tensor_name) + self._tied_parameter_duplicates[self._parameter_stages[meta.tensor_name]][ + meta.tensor_name + ] = tied_parameter + print("IUHWO", self._base_model.get_tied_parameters(), self._tied_parameters, self._tied_parameter_duplicates) # Create the stages. self._stages = [ @@ -91,7 +94,7 @@ def __init__( layers=self._layers[stage_splits[stage_index] : stage_splits[stage_index + 1]], distributed_config=self._config.distributed, index=stage_index, - tied_parameter_duplicates=tied_parameter_duplicates_, + tied_parameter_duplicates=tied_parameter_duplicates_.keys(), ) for stage_index, tied_parameter_duplicates_ in enumerate(self._tied_parameter_duplicates) ] @@ -326,10 +329,10 @@ def _setup_stages(self) -> None: else [] ) tied_weight_duplicate_buffers = { - parameter_name: self._stages[self._parameter_stages[parameter_name]].get_parameter_buffer( - parameter_name + parameter_name: self._stages[tied_parameter.main_stage].get_parameter_buffer( + tied_parameter.metas[0].tensor_name ) - for parameter_name in self._tied_parameter_duplicates[stage_index] + for parameter_name, tied_parameter in self._tied_parameter_duplicates[stage_index].items() } stage.setup( distributed=self._distributed, @@ -563,7 +566,7 @@ def _get_tied_parameters(self) -> dict[str, "TiedParameter"]: # TODO: Improve. Compare initializations? (Not currently possible) if ( len(meta.dims) != len(metas[0].dims) - or any(dim is not dim_ for dim, dim_ in zip(meta.dims, metas[0].dims, strict=True)) + or any(dim != dim_ for dim, dim_ in zip(meta.dims, metas[0].dims, strict=True)) or meta.sequence_tensor_parallel != metas[0].sequence_tensor_parallel ): raise ValueError( @@ -599,7 +602,7 @@ class TiedParameter: # Whether the local rank is involved at all. on_device: bool # Process group for reduction. 
- group: ProcessGroup | None = dataclasses.field(init=False) + group: ProcessGroup | None = dataclasses.field(repr=False, init=False) all_ranks: set[int] # The index of the main stage. main_stage: int diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index dfe324669..93fe1c692 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -70,6 +70,9 @@ def setup( # noqa self._accumulators = [] with torch.enable_grad(): for meta in self._parameter_metas: + if meta.tensor_name in self._tied_parameter_duplicates: + # Already handled in the main stage. + continue buffer = self.get_parameter_buffer(meta.tensor_name) if not buffer.requires_grad: continue diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 7a80193fe..450474ac0 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -330,10 +330,10 @@ def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta] duplicate_metas: list[ParameterMeta] = [] meta: ParameterMeta for layer in self._layers: - for name, meta in layer.named_parameters(): + for meta in layer.parameters(): Assert.custom(isinstance, meta, ParameterMeta) Assert.eq(meta.dtype, self._distributed_config.optimization_dtype.torch) - if name in self._tied_parameter_duplicates: + if meta.tensor_name in self._tied_parameter_duplicates: duplicate_metas.append(meta) elif meta.requires_grad: parameter_metas.append(meta) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 3f2e4da28..93eb54815 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -107,7 +107,7 @@ def get_layer( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - ) -> LanguageModelHeadBase: + ) -> "LanguageModelHeadBase": return self.layer_class( self, distributed_config, diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 4dddb0d98..ae65f5ac6 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -26,6 +26,10 @@ class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[Co layer_count: float = 1000.0 _config: ConfigType + # Position embedding preprocessing + _position_ids: torch.Tensor + _tensor_cache_max_sequence_length: int = -1 + def __init__( self, config: ConfigType, diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 6b7aa993b..dd311d520 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -8,7 +8,6 @@ from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs -from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert @@ -21,9 +20,7 @@ def _reverse_kl_loss( loss_mask: torch.Tensor | None, teacher_softmax_temperature: float = 1.0, ): - scaled_target = target / teacher_softmax_temperature - - scaled_target = torch.clamp(target, min=-50, max=50) + scaled_target = torch.clamp(target / teacher_softmax_temperature, min=-50, max=50) teacher_log_probs = 
torch.log_softmax(scaled_target, dim=-1) with torch.enable_grad(): @@ -109,7 +106,7 @@ def _lm_head( ({"sequence_first": True}, {}, False, 1), ({"head": {"logit_z_loss": 1e-3}}, {}, False, 1), ({"head": {"logits_scale_factor": 5.0}}, {}, False, 1), - ({"tied_embedding_weight": False}, {}, False, 1), + ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), ({}, {}, True, 1), ( @@ -166,6 +163,8 @@ def test_lm_head( loss_masking: bool, prediction_heads: int, ): + torch.cuda.manual_seed(0) + torch.manual_seed(0) head_config = { "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, @@ -265,6 +264,14 @@ def test_lm_head( logit_weight = None for prediction_distance, head in enumerate((model.head,) if prediction_heads == 1 else model.head.heads): + print( + "AIUFHGUKI", + prediction_distance, + head.config, + head._prediction_distance, + head._prediction_heads, + head._is_last_head, + ) # Prepare the LM head Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, prediction_distance) @@ -319,6 +326,10 @@ def test_lm_head( {loss_key: 1 for loss_key in loss_keys}, ) losses = {key: [] for key in loss_keys} + print("head_input", head_input) + for kew, value in kwargs.items(): + print("kwargs", kew, value) + output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) @@ -333,6 +344,18 @@ def test_lm_head( Assert.eq(len(losses["z_loss"]), 1) Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) + print("losses", losses) + print("ref_loss", ref_loss) + + print("input_grad", head_input.grad if head._is_last_head else head_input.grad.unbind()[1]) + print("ref_input.grad", ref_input.grad) + + print("head.final_norm.weight.grad_buffer", head.final_norm.weight.grad_buffer) + print("ref_rms_weight.grad", ref_rms_weight.grad) + + print("logit_weight.grad_buffer", logit_weight.grad_buffer) + print("ref_logit_weight.grad", ref_logit_weight.grad) + Assert.rms_close_relative(losses[loss_name][0], ref_loss, threshold, min_threshold) if head._is_last_head: From 9eaebf82b1495be1f32e7467647e0cbb70e31f3c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 2 Oct 2025 18:15:49 -0400 Subject: [PATCH 12/16] fixes --- fast_llm/engine/multi_stage/multi_stage.py | 2 +- fast_llm/engine/multi_stage/stage.py | 4 +- fast_llm/engine/multi_stage/stage_base.py | 14 +++++-- tests/layers/test_lm_head.py | 47 ++++++---------------- tests/utils/utils.py | 10 ++++- 5 files changed, 35 insertions(+), 42 deletions(-) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 13f67773f..c0078c15c 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -342,7 +342,7 @@ def _setup_stages(self) -> None: grad_buffers=stage_grad_buffers, mode=self._mode if stage_index in self._stages_on_device else StageMode.off_device, is_tied_weight_copy=stage_index in self._stages_on_device and stage_index not in self._stages_owned, - tied_weight_duplicate_buffers=tied_weight_duplicate_buffers, + tied_parameter_duplicate_buffers=tied_weight_duplicate_buffers, weight_buffer_shared_with=weight_buffer_shared_with, ) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 93fe1c692..56644e0ff 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -47,7 +47,7 @@ def setup( # noqa grad_buffers: list[torch.Tensor | None] | None = 
None, mode: StageMode = StageMode.training, is_tied_weight_copy: bool = False, - tied_weight_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, + tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, weight_buffer_shared_with: collections.abc.Sequence["Stage"] = (), ) -> None: super().setup( @@ -57,7 +57,7 @@ def setup( # noqa weight_buffers=weight_buffers, grad_buffers=grad_buffers, mode=mode, - tied_weight_duplicate_buffers=tied_weight_duplicate_buffers, + tied_parameter_duplicate_buffers=tied_parameter_duplicate_buffers, ) self._is_tied_weight_copy = is_tied_weight_copy if self._mode.support_forward: diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 450474ac0..fb63ee600 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -108,7 +108,7 @@ def setup( weight_buffers: list[torch.Tensor | None] | None, grad_buffers: list[torch.Tensor | None] | None, mode: StageMode = StageMode.training, - tied_weight_duplicate_buffers: dict[str, torch.nn.Parameter] | None, + tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None, ) -> None: assert not self._is_setup distributed.check_config(self._distributed_config) @@ -124,7 +124,7 @@ def setup( weight_buffers = [None for _ in self._fsdps] if grad_buffers is None: grad_buffers = [None for _ in self._fsdps] - if tied_weight_duplicate_buffers is None: + if tied_parameter_duplicate_buffers is None: assert not self._tied_parameter_duplicates for fsdp, weight_shard, grad_shard, weight_buffer, grad_buffer in zip( @@ -147,8 +147,15 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) + print( + "AAAAAA", + key, + meta.tensor_name, + self._tied_parameter_duplicates, + tied_parameter_duplicate_buffers.keys(), + ) if meta.tensor_name in self._tied_parameter_duplicates: - module._parameters[key] = tied_weight_duplicate_buffers[meta.tensor_name] + module._parameters[key] = tied_parameter_duplicate_buffers.pop(meta.tensor_name) else: module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) i += 1 @@ -158,6 +165,7 @@ def _replace(module: torch.nn.Module): layer.apply(_replace) Assert.eq(i, len(self._parameter_metas)) + assert not tied_parameter_duplicate_buffers, tied_parameter_duplicate_buffers.keys() def initialize_weights(self) -> None: # TODO: Avoid all the _on_device checks diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index dd311d520..0de823e2a 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -8,7 +8,7 @@ from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs -from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -252,37 +252,32 @@ def test_lm_head( kwargs[LanguageModelKwargs.loss_mask] = loss_mask if config.tied_embedding_weight or config.head.max_prediction_distance > 1: - logit_weight = ( + logit_weight = torch.nn.Parameter( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, 
dtype=distributed.config.compute_dtype.torch, device=distributed.device - ) - .normal_(config.embeddings.hidden_size**-0.5) - .requires_grad_(True) + ).normal_(config.embeddings.hidden_size**-0.5) ) - kwargs[WORD_EMBEDDINGS_WEIGHT if config.tied_embedding_weight else OUTPUT_WEIGHTS] = logit_weight else: logit_weight = None for prediction_distance, head in enumerate((model.head,) if prediction_heads == 1 else model.head.heads): - print( - "AIUFHGUKI", - prediction_distance, - head.config, - head._prediction_distance, - head._prediction_heads, - head._is_last_head, - ) # Prepare the LM head Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, prediction_distance) - stage = get_stage([head], distributed) + is_duplicate = config.tied_embedding_weight or prediction_distance > 0 + stage = get_stage( + [head], + distributed, + tied_parameter_duplicates=[head.output_weights.tensor_name] if is_duplicate else [], + tied_parameter_duplicate_buffers={head.output_weights.tensor_name: logit_weight} if is_duplicate else {}, + ) # Get reference outputs and grads - if logit_weight is None: - logit_weight = head.output_weights - else: + if is_duplicate: logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan")) logit_weight.param_grad_is_zero = True + else: + logit_weight = head.output_weights ref_input = input_.detach().requires_grad_() ref_rms_weight = head.final_norm.weight.detach().requires_grad_() @@ -326,10 +321,6 @@ def test_lm_head( {loss_key: 1 for loss_key in loss_keys}, ) losses = {key: [] for key in loss_keys} - print("head_input", head_input) - for kew, value in kwargs.items(): - print("kwargs", kew, value) - output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) @@ -344,18 +335,6 @@ def test_lm_head( Assert.eq(len(losses["z_loss"]), 1) Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) - print("losses", losses) - print("ref_loss", ref_loss) - - print("input_grad", head_input.grad if head._is_last_head else head_input.grad.unbind()[1]) - print("ref_input.grad", ref_input.grad) - - print("head.final_norm.weight.grad_buffer", head.final_norm.weight.grad_buffer) - print("ref_rms_weight.grad", ref_rms_weight.grad) - - print("logit_weight.grad_buffer", logit_weight.grad_buffer) - print("ref_logit_weight.grad", ref_logit_weight.grad) - Assert.rms_close_relative(losses[loss_name][0], ref_loss, threshold, min_threshold) if head._is_last_head: diff --git a/tests/utils/utils.py b/tests/utils/utils.py index b6cad34d6..098f0240e 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -37,15 +37,21 @@ def get_base_model(config: FastLLMModelConfig): return base_model, distributed -def get_stage(layers: list[Layer], distributed: Distributed): +def get_stage( + layers: list[Layer], + distributed: Distributed, + tied_parameter_duplicates: typing.Iterable[str] = (), + tied_parameter_duplicate_buffers: dict[str, torch.nn.Parameter] | None = None, +): # Create a fast-llm stage which allocates and initializes meta tensors correctly. 
stage = Stage( config=StageConfig(), layers=layers, distributed_config=distributed.config, index=0, + tied_parameter_duplicates=tied_parameter_duplicates, ) - stage.setup(distributed=distributed) + stage.setup(distributed=distributed, tied_parameter_duplicate_buffers=tied_parameter_duplicate_buffers) stage.initialize_weights() stage.restore_parameters() stage.reset_gradients() From 209739e6fdfc56c2945a355fd36bc4d998dc54a2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 2 Oct 2025 19:16:08 -0400 Subject: [PATCH 13/16] misc --- Megatron-LM | 2 +- fast_llm/engine/multi_stage/stage.py | 4 ++-- fast_llm/engine/multi_stage/stage_base.py | 7 ------- fast_llm/layers/language_model/embedding.py | 2 +- fast_llm/layers/language_model/head.py | 2 +- tests/utils/compare_tensor_logs.py | 4 +++- tests/utils/distributed_configs.py | 4 ++-- 7 files changed, 10 insertions(+), 15 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 30e7aeccd..dee27459d 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 30e7aeccd87ec22e424f35c6e61f05ceb878a8df +Subproject commit dee27459d46fecc513be76732a0095bb38be32fb diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 56644e0ff..03429beed 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -232,7 +232,7 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] ): check_parallel_match(output, self._distributed.tensor_group, f"layer {self._layers[i].module_name} fw") if self._config.debug_layer_outputs: - name = f"layer {self._layers[i].module_name} fw" + name = f"{self._layers[i].module_name} fw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: @@ -265,7 +265,7 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any ) ) if self._config.debug_layer_gradients: - name = f"layer {self._layers[i].module_name} bw" + name = f"{self._layers[i].module_name} bw" if (nmb := kwargs.get("num_micro_batches", 1)) > 1: name = f"{name}, mb={kwargs.get('micro_batch',0)}/{nmb}" if (nms := kwargs.get("micro_batch_splits", 1)) > 1: diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index fb63ee600..ef6ebc27e 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -147,13 +147,6 @@ def _replace(module: torch.nn.Module): nonlocal i for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) - print( - "AAAAAA", - key, - meta.tensor_name, - self._tied_parameter_duplicates, - tied_parameter_duplicate_buffers.keys(), - ) if meta.tensor_name in self._tied_parameter_duplicates: module._parameters[key] = tied_parameter_duplicate_buffers.pop(meta.tensor_name) else: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index ae65f5ac6..0ad3225c8 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -119,7 +119,7 @@ def forward( if isinstance(input_, TensorMeta): return TensorMeta.from_dims( kwargs[LanguageModelKwargs.hidden_dims], - tensor_name="Embedding output", + tensor_name=f"{self.module_name} output", dtype=self._residual_dtype, ) return self._forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 0cbfafa87..4b0e3d102 100644 --- 
a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -120,7 +120,7 @@ def forward( if self._is_last_head: return TensorMeta.from_dims( (scalar_dim,), - tensor_name="Loss", + tensor_name=f"{self.module_name} output", reductions=( (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), ), diff --git a/tests/utils/compare_tensor_logs.py b/tests/utils/compare_tensor_logs.py index 51ee66d31..1c8ebd76a 100644 --- a/tests/utils/compare_tensor_logs.py +++ b/tests/utils/compare_tensor_logs.py @@ -79,7 +79,9 @@ def _compare_dict_keys(self, dict_ref, dict_test, errors, name): keys_test = set(dict_test) if keys_ref != keys_test: errors.append( - f">>>> {name} do not match. Missing = {keys_ref - keys_test}, extra = {keys_test - keys_ref}." + f">>>> {name} do not match." + f"\n Missing = \n{"\n * ".join(keys_ref - keys_test)}" + f"\n Extra = \n{"\n * ".join(keys_test - keys_ref)}" ) # Avoid set to preserve ordering. diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 0ef18279d..fac595905 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -228,8 +228,8 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", - "model.base_model.embeddings_layer.vocab_parallel=False", - "model.base_model.output_layer.cross_entropy_splits=4", + "model.base_model.embeddings.vocab_parallel=False", + "model.base_model.head.cross_entropy_splits=4", ], num_gpus=2, compare_config=_compare_layer_match, From 2c792df0064646f52c9e26163df501c7ebbe7f70 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 2 Oct 2025 21:26:17 -0400 Subject: [PATCH 14/16] fixes --- fast_llm/engine/checkpoint/huggingface.py | 12 ++-- fast_llm/engine/multi_stage/multi_stage.py | 16 +++-- fast_llm/engine/multi_stage/stage_base.py | 3 +- fast_llm/layers/language_model/config.py | 16 ++--- .../language_model/multi_token_prediction.py | 6 +- fast_llm/models/gpt/config.py | 39 ++++++----- fast_llm/models/gpt/conversion/apriel.py | 22 +++--- fast_llm/models/gpt/conversion/llama.py | 39 ++++------- fast_llm/models/gpt/conversion/mistral.py | 4 +- fast_llm/models/gpt/conversion/mtp_llama.py | 70 +++++++++++++------ fast_llm/models/gpt/trainer.py | 3 +- tests/utils/model_configs.py | 15 +++- 12 files changed, 141 insertions(+), 104 deletions(-) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index e5d14711d..afe381295 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -31,7 +31,7 @@ def export_config(cls, config: BaseModelConfig) -> dict: @classmethod @abc.abstractmethod - def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: + def get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]: pass @@ -39,6 +39,10 @@ class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, architecture: typing.ClassVar[str] base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]] + def __init__(self, model: "FastLLMModel"): + self._exported_config = self._export_config(model.config) + super().__init__(model) + @classmethod @abc.abstractmethod def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]: @@ -126,10 +130,8 @@ def _import_config(cls, config: dict[str, 
typing.Any]) -> FastLLMModelConfig: Assert.eq(config["architecture"], cls.architecture) return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) - def _create_weight_converters( - self, - ) -> list[WeightConverter]: - return self.base_model_converter_class.get_converters(self._model.config.base_model) + def _create_weight_converters(self) -> list[WeightConverter]: + return self.base_model_converter_class.get_converters(self._model.config.base_model, self._exported_config) def _load_weights( self, config: CheckpointLoadConfig, device diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index c0078c15c..c6c8f31a4 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -328,12 +328,16 @@ def _setup_stages(self) -> None: if self._mode.support_forward and weight_buffer_index is not None else [] ) - tied_weight_duplicate_buffers = { - parameter_name: self._stages[tied_parameter.main_stage].get_parameter_buffer( - tied_parameter.metas[0].tensor_name - ) - for parameter_name, tied_parameter in self._tied_parameter_duplicates[stage_index].items() - } + tied_weight_duplicate_buffers = ( + { + parameter_name: self._stages[tied_parameter.main_stage].get_parameter_buffer( + tied_parameter.metas[0].tensor_name + ) + for parameter_name, tied_parameter in self._tied_parameter_duplicates[stage_index].items() + } + if self._mode.support_forward + else None + ) stage.setup( distributed=self._distributed, weight_shards=stage_weight_shards, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index ef6ebc27e..96d80ce06 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -124,8 +124,6 @@ def setup( weight_buffers = [None for _ in self._fsdps] if grad_buffers is None: grad_buffers = [None for _ in self._fsdps] - if tied_parameter_duplicate_buffers is None: - assert not self._tied_parameter_duplicates for fsdp, weight_shard, grad_shard, weight_buffer, grad_buffer in zip( self._fsdps, weight_shards, grad_shards, weight_buffers, grad_buffers, strict=True @@ -148,6 +146,7 @@ def _replace(module: torch.nn.Module): for key in module._parameters: meta = typing.cast(ParameterMeta, module._parameters[key]) if meta.tensor_name in self._tied_parameter_duplicates: + assert tied_parameter_duplicate_buffers is not None module._parameters[key] = tied_parameter_duplicate_buffers.pop(meta.tensor_name) else: module._parameters[key] = self.get_parameter_buffer(meta.tensor_name) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 93eb54815..9b2b2fec1 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -183,9 +183,9 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - enable_dpo: bool | None = Field( - default=False, - desc="Whether to enable DPO loss", + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", hint=FieldHint.feature, ) dpo_beta: float | None = Field( @@ -193,11 +193,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Beta value for DPO loss.", hint=FieldHint.feature, ) - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) distillation_model: 
str | None = Field( default=None, desc="Name of the reference model to use for knowledge distillation." @@ -243,11 +238,16 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 super()._validate() + assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property def max_prediction_distance(self) -> int: return 1 + @property + def enable_dpo(self) -> bool: + return self.dpo_reference_model is not None + @config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index 10fb4fe97..e0eb8175d 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -81,10 +81,12 @@ def get_layers(self) -> list[Layer]: ] def get_output_weights(self) -> list[torch.Tensor]: - return sum((head.output_weights for head in self.heads), []) + return sum((head.get_output_weights() for head in self.heads), []) def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: self._layers_with_namespace[0].preprocess(batch, kwargs) def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self[0].get_loss_definitions(count=count * self._config.prediction_heads) + return self.blocks[0].get_loss_definitions(count=count * self._config.prediction_heads) + [ + loss_definition for head in self.heads for loss_definition in head.get_loss_definitions(count=count) + ] diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 9c8e90d44..1e57f3b8c 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -10,7 +10,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelConfig +from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.conversion.config import ( AprielHybridSSMCheckpointFormat, AutoGPTHuggingfaceCheckpointFormat, @@ -156,33 +156,34 @@ def _validate(self) -> None: if self.model.base_model.embeddings.position_embeddings.enabled: Assert.geq(self.model.base_model.embeddings.num_position_embeddings, self.batch.sequence_length) - distillation_model = self.model.base_model.head.distillation_model - dpo_reference_model = self.model.base_model.head.dpo_reference_model - - if self.model.base_model.head.enable_dpo: - assert dpo_reference_model is not None - Assert.none(distillation_model) + # TODO: Avoid digging inside the model. 
+ head = self.model.base_model.head + if isinstance(head, MultiTokenPredictionConfig): + prediction_heads = head.prediction_heads + head = head.head else: - Assert.none(dpo_reference_model) + prediction_heads = 1 - if distillation_model is None and dpo_reference_model is None: - Assert.empty(self.reference_models) - else: - assert distillation_model is None or dpo_reference_model is None # currently don't support both - expected_names = {name for name in (distillation_model, dpo_reference_model) if name is not None} - Assert.eq(self.reference_models.keys(), expected_names) + expected_names = {name for name in (head.distillation_model, head.dpo_reference_model) if name is not None} + Assert.eq(self.reference_models.keys(), expected_names) for reference_model in self.reference_models.values(): - output_layer = reference_model.model.base_model.head - Assert.none(output_layer.distillation_model) - Assert.none(output_layer.dpo_reference_model) + reference_head = reference_model.model.base_model.head + if isinstance(reference_head, MultiTokenPredictionConfig): + reference_prediction_heads = reference_head.prediction_heads + reference_head = reference_head.heads + else: + reference_prediction_heads = 1 + Assert.geq(reference_prediction_heads, prediction_heads) + + Assert.none(reference_head.distillation_model) + Assert.none(reference_head.dpo_reference_model) # TODO: Support more LM head features. - Assert.none(output_layer.cross_entropy_splits) + Assert.none(reference_head.cross_entropy_splits) Assert.eq( reference_model.model.base_model.embeddings.vocab_parallel, self.model.base_model.embeddings.vocab_parallel, ) - Assert.geq(output_layer.prediction_heads, output_layer.prediction_heads) @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 9434b9116..4b9849630 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -24,11 +24,11 @@ class AprielDiscreteMamba2Converter: @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { "type": "discrete_mamba_2", "state_size": config["ssm_cfg"]["d_state"], - "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "d_inner": config["ssm_cfg"].get("d_inner") or config["hidden_size"] * config["ssm_cfg"].get("expand", 1), "add_linear_biases": config["ssm_cfg"]["bias"], "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, "n_qk_heads": config["ssm_cfg"]["n_qk_heads"], @@ -117,17 +117,17 @@ def get_converters( class AprielMamba2Converter: @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { "type": "mamba_2", "state_size": config["ssm_cfg"]["d_state"], - "d_inner": config["ssm_cfg"].get("d_inner") or hidden_size * config["ssm_cfg"].get("expand", 1), + "d_inner": config["ssm_cfg"].get("d_inner") or config["hidden_size"] * config["ssm_cfg"].get("expand", 1), "add_linear_biases": config["ssm_cfg"]["bias"], "convolution_layer": {"bias": {"enabled": config["ssm_cfg"].get("conv_bias", True)}}, - "d_xb": config["ssm_cfg"].get("d_xb") or hidden_size, + "d_xb": config["ssm_cfg"].get("d_xb") or config["hidden_size"], "dt_layer": {"bias": {"enabled": config["ssm_cfg"].get("dt_proj_bias", True)}}, "dt_rank": ( - math.ceil(hidden_size) + math.ceil(config["hidden_size"]) if 
config["ssm_cfg"].get("dt_rank", "auto") == "auto" else config["ssm_cfg"]["dt_rank"] ), @@ -246,8 +246,8 @@ class AprielBlockConverter: _config_classes = {value: key for key, value in layout_names.items()} @classmethod - def import_config(cls, config: dict, hidden_size: int, layout_name: str = "t") -> dict: - return cls._converter_classes[cls._config_classes[layout_name]].import_config(config, hidden_size) + def import_config(cls, config: dict, layout_name: str = "t") -> dict: + return cls._converter_classes[cls._config_classes[layout_name]].import_config(config) @classmethod def export_config(cls, config) -> dict: @@ -270,18 +270,18 @@ class AprielDecoderConverter(MistralDecoderConverter): block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: layout = config["hybrid_block_layout"] if len(layout) == 1: return { - "block": cls.block_converter_class.import_config(config, hidden_size, layout[0]), + "block": cls.block_converter_class.import_config(config, layout[0]), "num_blocks": config["num_hidden_layers"], } else: return { "type": "pattern", "blocks": { - layout_name: cls.block_converter_class.import_config(config, hidden_size, layout_name) + layout_name: cls.block_converter_class.import_config(config, layout_name) for layout_name in set(layout) }, "pattern": layout, diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 8a82235b4..786d923f2 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -184,7 +184,7 @@ def import_weight( class LlamaAttentionConverter: @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: try: rope_type = config["rope_scaling"]["rope_type"] except (KeyError, TypeError): @@ -224,7 +224,7 @@ def import_config(cls, config: dict, hidden_size: int) -> dict: "dropout": config["attention_dropout"], } if out["head_size"] is None: - out["head_size"] = div(hidden_size, out["heads"]) + out["head_size"] = div(config["hidden_size"], out["heads"]) return out @@ -360,9 +360,9 @@ class LlamaBlockConverter: hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm" @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { - "mixer": cls.mixer_converter_class.import_config(config, hidden_size), + "mixer": cls.mixer_converter_class.import_config(config), "mlp": cls.mlp_converter_class.import_config(config), "normalization": cls.normalization_converter_class.import_config(config), } @@ -412,9 +412,9 @@ class LlamaDecoderConverter: block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: + def import_config(cls, config: dict) -> dict: return { - "block": cls.block_converter_class.import_config(config, hidden_size), + "block": cls.block_converter_class.import_config(config), "num_blocks": config["num_hidden_layers"], } @@ -485,22 +485,11 @@ def export_config(cls, config: LanguageModelHeadConfig) -> dict: @classmethod def get_converters( - cls, config: LanguageModelHeadConfig, fast_llm_prefix: str, tied_embedding_weight: bool + cls, + config: LanguageModelHeadConfig, + exported_config: dict, + fast_llm_prefix: str, ) -> list[WeightConverter]: - # for prediction_distance in 
range(config.prediction_heads): - # if prediction_distance > 0: - # converters += cls.block_converter_class.get_converters( - # block_config, - # f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", - # "", - # drop_on_export=True, - # ) - # converters += cls.normalization_converter_class.get_converters( - # config.normalization, - # f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", - # f"model.norm", - # drop_on_export=prediction_distance > 0, - # ) return [ *cls.normalization_converter_class.get_converters( config.normalization, @@ -510,7 +499,7 @@ def get_converters( get_parameter_converter( f"{fast_llm_prefix}.output_weights", "lm_head.weight", - drop_on_import=tied_embedding_weight, + drop_on_import=exported_config["tie_word_embeddings"], ), ] @@ -525,7 +514,7 @@ class LlamaBaseModelConverter: def import_config(cls, config: dict) -> dict: return { "embeddings": cls.embeddings_converter_class.import_config(config), - "decoder": cls.decoder_converter_class.import_config(config, config["hidden_size"]), + "decoder": cls.decoder_converter_class.import_config(config), "head": cls.head_converter_class.import_config(config), "tied_embedding_weight": config["tie_word_embeddings"], } @@ -541,11 +530,11 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict: ) @classmethod - def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: + def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: return [ *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), - *cls.head_converter_class.get_converters(config.head, "head", config.tied_embedding_weight), + *cls.head_converter_class.get_converters(config.head, exported_config, "head"), ] diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index 4673f5b2c..bfc7d5569 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -16,8 +16,8 @@ class MistralAttentionConverter(LlamaAttentionConverter): @classmethod - def import_config(cls, config: dict, hidden_size: int) -> dict: - return safe_merge_dicts(super().import_config(config, hidden_size), {"window_size": config["sliding_window"]}) + def import_config(cls, config: dict) -> dict: + return safe_merge_dicts(super().import_config(config), {"window_size": config["sliding_window"]}) @classmethod def export_config(cls, config: AttentionConfig) -> dict: diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 194c263f9..6dcbbe4be 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -4,63 +4,93 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import WeightConverter -from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.config import LanguageModelHeadConfig +from fast_llm.layers.block.config import FixedBlockSequenceConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, + LlamaBlockConverter, + LlamaDecoderConverter, LlamaHeadConverter, 
LlamaHuggingfaceCheckpointHandler, get_parameter_converter, ) -from fast_llm.utils import safe_merge_dicts +from fast_llm.utils import Assert, safe_merge_dicts class MTPLlamaHeadConverter(LlamaHeadConverter): @classmethod def import_config(cls, config: dict) -> dict: - return safe_merge_dicts( - super().import_config(config), - {"prediction_heads": config["prediction_heads"]}, - ) + return { + "type": "multi_token_prediction", + "block": LlamaBlockConverter.import_config(config), + "head": super().import_config(config), + "prediction_heads": config["prediction_heads"], + } @classmethod - def export_config(cls, config: LanguageModelHeadConfig) -> dict: + def export_config(cls, config: MultiTokenPredictionConfig) -> dict: + Assert.custom(isinstance, config, MultiTokenPredictionConfig) return safe_merge_dicts( - super().export_config(config), + super().export_config(config.head), {"prediction_heads": config.prediction_heads}, ) @classmethod def get_converters( - cls, config: LanguageModelHeadConfig, block_config: DecoderBlockConfig, fast_llm_prefix: str, start_index: int + cls, + config: LanguageModelHeadConfig, + exported_config: dict, + fast_llm_prefix: str, ) -> list[WeightConverter]: converters = [] for prediction_distance in range(config.prediction_heads): - if prediction_distance > 0: - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{start_index+2*prediction_distance-1}", - f"model.mtp_heads.{prediction_distance - 1}", - ) + converters += cls.block_converter_class.get_converters( + config.block, + f"{fast_llm_prefix}.blocks.{prediction_distance}", + ( + f"model.layers.{exported_config["num_hidden_layers"]-1}" + if prediction_distance == 0 + else f"model.mtp_heads.{prediction_distance - 1}" + ), + ) converters += cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.{start_index+2*prediction_distance}.final_norm", + config.head.normalization, + f"{fast_llm_prefix}.heads.{prediction_distance}.final_norm", f"model.mtp_norms.{prediction_distance}", ) converters.append( get_parameter_converter( - f"{fast_llm_prefix}.{start_index}.output_weights", + f"heads.0.output_weights", "lm_head.weight", - drop_on_import=config.tied_weight, + drop_on_import=exported_config["tie_word_embeddings"], ) ) return converters +class MTPLlamaDecoderConverter(LlamaDecoderConverter): + @classmethod + def import_config(cls, config: dict) -> dict: + return { + "block": cls.block_converter_class.import_config(config), + "num_blocks": config["num_hidden_layers"] - 1, + } + + @classmethod + def export_config(cls, config: FixedBlockSequenceConfig) -> dict: + # TODO: Support PatternBlockSequenceConfig with compatible configs. 
+ Assert.custom(isinstance, config, FixedBlockSequenceConfig) + return safe_merge_dicts( + cls.block_converter_class.export_config(config.block), + {"num_hidden_layers": config.num_blocks + 1}, + ) + + class MTPLlamaBaseModelConverter(LlamaBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[MTPLlamaDecoderConverter]] = MTPLlamaDecoderConverter head_converter_class: typing.ClassVar[type[MTPLlamaHeadConverter]] = MTPLlamaHeadConverter diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 26500212d..54ea13dc4 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -25,7 +25,8 @@ def _get_sampling_parameters( "vocab_size": self._config.model.base_model.embeddings.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - "use_preference_loss_spans": self._config.model.base_model.head.enable_dpo, + # OK since DPO is not supported for MTP. + "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.head.max_prediction_distance, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1e303b9f1..ea608ea4b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -437,12 +437,21 @@ def _update_and_add_testing_config( }, ) + +_llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] + + _update_and_add_testing_config( # Tests multi-token prediction, custom HF model and converter. "llama", "mtp_llama", updates={ - ("model", "base_model", "head", "prediction_heads"): 2, + ("model", "base_model", "head"): { + "type": "multi_token_prediction", + "block": _llama_block, + "head": MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["head"], + "prediction_heads": 2, + }, }, # Megatron doesn't support multi-token prediction. megatron_args=None, @@ -457,6 +466,8 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, compare_factor=2.0, + # Arg update for cross-entropy splits doesn't work here. + skip_tests=("ce4", "ms"), ) _update_and_add_testing_config( @@ -550,8 +561,6 @@ def _update_and_add_testing_config( compare_factor=2.0, ) -_llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] - _update_and_add_testing_config( # Tests hybrid Mamba, llamba converter. 
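
The test update above nests the multi-token-prediction settings instead of keeping a flat `prediction_heads` field on the head. A minimal sketch of the resulting layout, with placeholder dicts standing in for the copied llama block and head configs (the real test reuses `MODEL_CONFIGS["llama"].config_dict[...]`):

    # Placeholders only; the test copies the existing llama config dicts here.
    llama_block: dict = {}
    llama_head: dict = {}

    head_config = {
        "type": "multi_token_prediction",
        "block": llama_block,       # decoder block run before each prediction head
        "head": llama_head,         # regular language-model head settings
        "prediction_heads": 2,      # number of future tokens predicted per position
    }

On export, the MTP converter above maps block 0 onto the last `model.layers.*` entry and the remaining blocks onto `model.mtp_heads.*`, which is why the decoder converter writes `num_hidden_layers` as `num_blocks + 1`.
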
From 84b8d7d7a507a14d6bc8a41bc4f2a2644c96ea8d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 3 Oct 2025 18:48:41 -0400 Subject: [PATCH 15/16] Fixes, language model --- fast_llm/engine/base_model/base_model.py | 6 +- fast_llm/engine/multi_stage/multi_stage.py | 4 +- fast_llm/engine/multi_stage/stage.py | 1 - fast_llm/engine/schedule/runner.py | 2 +- fast_llm/layers/language_model/config.py | 10 +-- .../layers/language_model/language_model.py | 61 +++++++++++++++++++ fast_llm/models/gpt/conversion/mtp_llama.py | 2 +- fast_llm/models/gpt/model.py | 55 +++-------------- tests/test_multi_stage.py | 4 +- tests/utils/model_configs.py | 1 + 10 files changed, 79 insertions(+), 67 deletions(-) create mode 100644 fast_llm/layers/language_model/language_model.py diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index d61630e07..5df59d4cd 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -133,11 +133,11 @@ def __init__( @abc.abstractmethod def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]: - # TODO ====== Remove (Move batch splitting elsewhere) ====== + # TODO Remove (Move batch splitting elsewhere) pass @abc.abstractmethod - def preprocess( + def preprocess_batch( self, batch: typing.Any, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, @@ -146,7 +146,7 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: - # TODO ====== Move batch splitting elsewhere, align interface with LayerBase ====== + # TODO Move batch splitting elsewhere, align interface with LayerBase pass def get_tied_parameters(self) -> dict[str, list[ParameterMeta]]: diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index c6c8f31a4..f45f93862 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -61,7 +61,6 @@ def __init__( self._num_stages, self._config.distributed.pipeline_parallel * self._config.multi_stage.stages_per_pipeline_stage, ) - # Keep track of which stage each parameter belongs to. self._parameter_stages: dict[str, int] = {} for stage_index in range(self._num_stages): @@ -85,7 +84,6 @@ def __init__( self._tied_parameter_duplicates[self._parameter_stages[meta.tensor_name]][ meta.tensor_name ] = tied_parameter - print("IUHWO", self._base_model.get_tied_parameters(), self._tied_parameters, self._tied_parameter_duplicates) # Create the stages. self._stages = [ @@ -335,7 +333,7 @@ def _setup_stages(self) -> None: ) for parameter_name, tied_parameter in self._tied_parameter_duplicates[stage_index].items() } - if self._mode.support_forward + if self._mode.support_forward and stage_index in self._stages_on_device else None ) stage.setup( diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 03429beed..9f5543590 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -145,7 +145,6 @@ def forward( # TODO: Handle variable shape. 
output_global = output - # TODO ====== Use ====== kwargs["hidden_states"][self._layers[i].module_name] = { "layer_type": type(layer).__name__, "tensor": output_global, diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index d08932c49..133b3206b 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -324,7 +324,7 @@ def _preprocess_data( for micro_batch in range(batch_config.sequential_micro_batches): micro_batch_data = next(data_iterator) if not preprocessed: - micro_batch_data = self._multi_stage.base_model.preprocess( + micro_batch_data = self._multi_stage.base_model.preprocess_batch( micro_batch_data, context.schedule.preprocessed_meta, phase=context.phase, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9b2b2fec1..d2fbc4909 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,7 +2,7 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import LossDef, ModuleConfig +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -254,7 +254,6 @@ class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): _abstract = False # Needs to be `DecoderBlockConfig` for the `return_input` interface. # TODO: Make a generic wrapper for returning input instead? - # TODO ====== Tied weight ====== block: DecoderBlockConfig = Field( desc="Configuration for the decoder block before each head.", hint=FieldHint.architecture, @@ -270,7 +269,6 @@ class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # TODO ====== Adjust ====== prediction_loss_coefficient: list[float] | None = Field( default=None, desc="Loss coefficient for each prediction head.", @@ -291,12 +289,6 @@ def layer_class(self) -> "type[MultiTokenPrediction]": return MultiTokenPrediction - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - # TODO ====== Wrong ====== - return self.block.get_loss_definitions(count=count * self.prediction_heads) + self.head.get_loss_definitions( - count=count * self.prediction_heads - ) - @property def max_prediction_distance(self) -> int: return self.prediction_heads diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py new file mode 100644 index 000000000..9a3bef195 --- /dev/null +++ b/fast_llm/layers/language_model/language_model.py @@ -0,0 +1,61 @@ +import logging +import typing + +from fast_llm.config import Configurable +from fast_llm.engine.base_model.base_model import Layer, LayerBase +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.language_model.config import LanguageModelConfig +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding + +logger = logging.getLogger(__name__) + + +class LanguageModel[ConfigType: LanguageModelConfig](Configurable[ConfigType], LayerBase): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + 
): + super().__init__(config, distributed_config) + + self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size) + self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer( + distributed_config, + hidden_dim=self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.decoder = self._config.decoder.get_layer( + distributed_config, + self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + self.head = self._config.head.get_layer( + distributed_config, + self._config.embeddings, + hidden_dim=self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + + def get_layers(self) -> list["Layer"]: + return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() + + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + self.embeddings.preprocess(batch, kwargs) + self.decoder.preprocess(batch, kwargs) + self.head.preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + return ( + self.embeddings.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.head.get_loss_definitions(count) + ) diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 6dcbbe4be..5b83fed69 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -62,7 +62,7 @@ def get_converters( ) converters.append( get_parameter_converter( - f"heads.0.output_weights", + f"{fast_llm_prefix}.heads.0.output_weights", "lm_head.weight", drop_on_import=exported_config["tie_word_embeddings"], ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 0b4a5d381..2c1fb0e4a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -4,8 +4,7 @@ import torch from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import BaseModel, Layer -from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner @@ -13,7 +12,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -22,7 +21,7 @@ logger = logging.getLogger(__name__) -class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): +class GPTBaseModel[ConfigType: GPTBaseModelConfig](LanguageModel[ConfigType], BaseModel[ConfigType]): """ A transformer-based language model generalizing the GPT model architecture. 
""" @@ -35,28 +34,6 @@ def __init__( distributed_config: DistributedConfig, ): super().__init__(config, distributed_config) - - self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size) - self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer( - distributed_config, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - ) - self.decoder = self._config.decoder.get_layer( - distributed_config, - self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - ) - self.head = self._config.head.get_layer( - distributed_config, - self._config.embeddings, - hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, - ) - if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) @@ -64,13 +41,10 @@ def __init__( param, self._config.decoder.block, config.embeddings.hidden_size ) # Noqa - def get_layers(self) -> list["Layer"]: - return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() - def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: - # TODO ====== Remove (Move batch splitting elsewhere) ====== + # TODO Remove (Move batch splitting elsewhere) # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence if isinstance(batch_meta, GPTBatchConfig): @@ -177,7 +151,7 @@ def preprocess_meta( return preprocessed_meta - def preprocess( + def preprocess_batch( self, batch: GPTBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, @@ -186,7 +160,7 @@ def preprocess( iteration: int, metrics: dict | None = None, ) -> list[tuple[torch.Tensor, dict]]: - # TODO ====== Move batch splitting elsewhere, align interface with LayerBase ====== + # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup if preprocessed_meta is None: @@ -209,7 +183,7 @@ def preprocess( (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta ] - reference_batch = reference_model.fast_llm_model.base_model.preprocess( + reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch( batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration ) @@ -285,7 +259,6 @@ def preprocess( kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) - # TODO ====== Preference spans ====== if batch.chosen_spans is not None: chosen_valid_spans = [] for spans in batch.chosen_spans: @@ -317,28 +290,18 @@ def preprocess( rejected_valid_spans.append(valid_spans) kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans - # TODO ====== Turn into super() call ====== - self.embeddings.preprocess(tokens, kwargs) - self.decoder.preprocess(tokens, kwargs) - self.head.preprocess(tokens, kwargs) - + self.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) return preprocessed def get_tied_parameters(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: + # TODO: Integrate to the `LayerBase` interface, move to `LanguageModel`, `MultiTokenPrediction`? 
output_weights = self.head.get_output_weights() if self._config.tied_embedding_weight: output_weights.insert(0, self.embeddings.word_embeddings_weight) return {output_weights[0].tensor_name: output_weights} if len(output_weights) > 1 else {} - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return ( - self.embeddings.get_loss_definitions(count) - + self.decoder.get_loss_definitions(count) - + self.head.get_loss_definitions(count) - ) - class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): # TODO: Can we drop class? diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index a4f1e19c8..c06ba506b 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -5,7 +5,6 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.layers.decoder.block import DecoderBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup @@ -42,8 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.unwrap().mlp.parameters()) if isinstance(layer.unwrap(), DecoderBlock) else 0 - for layer in model_ref.base_model.get_layers() + sum(p.numel() for p in layer.mlp.parameters()) for layer in model_ref.base_model.decoder.layers ] # Make sure each layer has its own buffer so the check below works. diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index ea608ea4b..6b313aa8a 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -452,6 +452,7 @@ def _update_and_add_testing_config( "head": MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["head"], "prediction_heads": 2, }, + ("model", "base_model", "decoder", "num_blocks"): 1, }, # Megatron doesn't support multi-token prediction. megatron_args=None, From 53538a26ab58c846b650b8ba14a6ac2271865413 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 3 Oct 2025 19:14:07 -0400 Subject: [PATCH 16/16] fixes --- fast_llm/models/gpt/huggingface.py | 2 +- tests/test_multi_stage.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 7f0fefc18..9215e6dc7 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -79,7 +79,7 @@ def inner_forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) - batch = self.fast_llm_base_model.preprocess( + batch = self.fast_llm_base_model.preprocess_batch( GPTBatch(input_ids, sequence_lengths=sequence_lenghts), phase=PhaseType.inference, iteration=iteration ) ((input_, kwargs),) = batch diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c06ba506b..407b47767 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -41,7 +41,8 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) for layer in model_ref.base_model.decoder.layers + sum(p.numel() for p in layer.unwrap().mlp.parameters()) if layer.module_name.startswith("decoder") else 0 + for layer in model_ref.base_model.get_layers() ] # Make sure each layer has its own buffer so the check below works.
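
Taken together, the last two patches leave the base-model preprocessing entry point named `preprocess_batch` and move the embeddings/decoder/head wiring into the shared `LanguageModel` class. A minimal sketch of the resulting call pattern, mirroring the updated huggingface wrapper above (`base_model`, `input_ids`, `sequence_lengths` and `iteration` are assumed, not shown):

    from fast_llm.data.data.gpt.data import GPTBatch
    from fast_llm.engine.distributed.config import PhaseType

    # Sketch only: mirrors the updated call sites in this series.
    ((tokens, kwargs),) = base_model.preprocess_batch(
        GPTBatch(input_ids, sequence_lengths=sequence_lengths),
        phase=PhaseType.inference,
        iteration=iteration,
    )
    # Layer wiring now comes from the shared LanguageModel composition:
    layers = base_model.get_layers()  # embeddings + decoder blocks + head(s)

`preprocess_batch` returns one `(tokens, kwargs)` pair per micro-batch; the single-pair unpacking matches the inference path shown in the huggingface wrapper, where only one micro-batch is produced.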