5 changes: 5 additions & 0 deletions Dockerfile
@@ -47,3 +47,8 @@ COPY --chmod=777 ./tests tests
COPY --chmod=777 ./tools tools
COPY --chmod=777 ./fast_llm_external_models fast_llm_external_models
COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/

# Set a dummy default user so we don't run in root by default.
# The image is still compatible with any user id.
RUN useradd user
USER user
Review comment (Collaborator): thanks!
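Note on usage (not part of the diff): because the image stays compatible with any user id, a specific host user can still be mapped in at run time, for example `docker run --rm --user "$(id -u):$(id -g)" <image>` (the image name is a placeholder).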

2 changes: 1 addition & 1 deletion examples/mistral.yaml
@@ -28,7 +28,6 @@ optimizer:
model:
base_model:
embeddings:
hidden_size: 4096
vocab_size: 32000
dropout: 0.0
decoder:
@@ -58,6 +57,7 @@ model:
normalization:
type: rms_norm
epsilon: 1.0e-05
hidden_size: 4096
tied_embedding_weight: false
multi_stage:
zero_stage: 2
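For reference, a minimal sketch (not part of the PR) of where `hidden_size` now lives, written as the dict form accepted by `GPTBaseModelConfig.from_dict`; values mirror the YAML above and all unrelated fields are omitted:

```python
# Hedged sketch: `hidden_size` moves from `embeddings` up to the base-model
# level, next to `tied_embedding_weight`. Values mirror examples/mistral.yaml.
base_model = {
    "embeddings": {"vocab_size": 32000, "dropout": 0.0},  # no `hidden_size` here anymore
    "hidden_size": 4096,            # relocated field
    "tied_embedding_weight": False,
}
```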
5 changes: 2 additions & 3 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -150,7 +150,6 @@ def _load_weights(
].values()
}
elif (config.path / transformers.utils.WEIGHTS_NAME).is_file():
# TODO: Prevent unsafe by default
paths = {config.path / transformers.utils.WEIGHTS_NAME}
elif (config.path / transformers.utils.WEIGHTS_INDEX_NAME).is_file():
logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}")
@@ -170,7 +169,7 @@ def _load_weights(
for key in f.keys():
yield key, "weights", f.get_slice(key)
elif path.suffix == ".bin":
# TODO: Prevent unsafe by default
yield from torch.load(path)
# TODO: Confirm that loading works with `weights_only=True`
yield from torch.load(path, weights_only=True)
Review comment (Collaborator): thanks for making our sec folks happy

else:
raise NotImplementedError(f"Unknown file format for {path}")
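For context, a minimal standalone sketch of the safer loading path introduced above; the file name is a placeholder and this is not code from the PR:

```python
import torch

# With weights_only=True, torch.load restricts unpickling to tensors and a
# small allow-list of types, so a crafted ".bin" pickle cannot execute
# arbitrary code during deserialization.
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)

for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```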
5 changes: 3 additions & 2 deletions fast_llm/engine/config_utils/parameter.py
@@ -1,7 +1,8 @@
import math
import typing

from fast_llm.config import Config, Field, FieldHint, config_class
from fast_llm.config import Field, FieldHint, config_class
from fast_llm.engine.base_model.config import ModuleConfig
from fast_llm.engine.config_utils.initialization import Initialization, InitializationConfig
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.layers.common.peft.config import PeftConfig
@@ -36,7 +37,7 @@ def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]):


@config_class()
class ParameterConfig(Config):
class ParameterConfig(ModuleConfig):
initialization: InitializationConfig = Field(
desc="If provided, override the default initialization method set by the parent layer.",
hint=FieldHint.feature,
1 change: 1 addition & 0 deletions fast_llm/layers/block/block.py
@@ -103,6 +103,7 @@ def __init__(
config: ConfigType,
distributed_config: DistributedConfig,
*,
# TODO: Review. Use `input_dim(s)` and `output_dim(s)` instead?
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
33 changes: 19 additions & 14 deletions fast_llm/layers/language_model/config.py
@@ -2,7 +2,6 @@
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.engine.base_model.config import ModuleConfig
from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
@@ -16,6 +15,7 @@
if typing.TYPE_CHECKING:
from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase
from fast_llm.layers.language_model.language_model import LanguageModel
from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction


@@ -41,12 +41,6 @@ class LanguageModelEmbeddingsConfig(BlockConfig):
desc="Configuration for the word embedding (weight).",
hint=FieldHint.architecture,
)
hidden_size: int = Field(
default=1024,
desc="Size of the model's main hidden dimension, e.g., for its input and output layers.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
vocab_size: int = Field(
default=49152,
desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.",
@@ -295,24 +289,29 @@ def max_prediction_distance(self) -> int:


@config_class()
class LanguageModelConfig(ModuleConfig):
# TODO: block
class LanguageModelConfig(BlockConfig):
decoder: BlockSequenceConfig = Field(
desc="Configuration for the language model decoder.",
hint=FieldHint.architecture,
)
embeddings: LanguageModelEmbeddingsConfig = Field()
head: LanguageModelHeadBaseConfig = Field()
# TODO: Allow overriding in sub-models?
peft: PeftConfig = Field(
desc="Configuration for parameter-efficient fine tuning.",
embeddings: LanguageModelEmbeddingsConfig = Field(
hint=FieldHint.architecture,
desc="Configuration for the language model embeddings.",
)
head: LanguageModelHeadBaseConfig = Field(
hint=FieldHint.architecture, desc="Configuration for the language model head(s)."
)
tied_embedding_weight: bool = Field(
default=False,
desc="Tie the output weights (logits) with the vocabulary embedding.",
hint=FieldHint.architecture,
)
hidden_size: int = Field(
default=1024,
desc="Size of the model's main hidden dimension, e.g., for its input and output layers.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
sequence_first: bool | None = Field(
default=None,
desc="Override the default dimension ordering",
@@ -321,3 +320,9 @@ class LanguageModelConfig(ModuleConfig):
" Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error.",
hint=FieldHint.testing,
)

@property
def layer_class(self) -> "type[LanguageModel]":
from fast_llm.layers.language_model.language_model import LanguageModel

return LanguageModel
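The new `layer_class` property follows a deferred-import pattern; below is a short, self-contained toy sketch of the idea using a standard-library class (the names are made up and this is not the Fast-LLM API):

```python
# Importing the class inside the property avoids import cycles between the
# config module and the (heavier) layer module: the import only happens when
# a layer is actually built. Counter stands in for a real layer class here.
class CounterConfig:
    @property
    def layer_class(self) -> type:
        from collections import Counter  # deferred import

        return Counter

    def get_layer(self, *args, **kwargs):
        # A generic builder can construct the layer from its config.
        return self.layer_class(*args, **kwargs)


config = CounterConfig()
counter = config.get_layer("abracadabra")  # the class is imported lazily here
```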
40 changes: 26 additions & 14 deletions fast_llm/layers/language_model/language_model.py
@@ -1,52 +1,64 @@
import logging
import typing

from fast_llm.config import Configurable
from fast_llm.engine.base_model.base_model import Layer, LayerBase
import torch

from fast_llm.engine.base_model.base_model import Layer
from fast_llm.engine.base_model.config import LossDef
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.layers.block.block import BlockBase
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.language_model.config import LanguageModelConfig
from fast_llm.layers.language_model.embedding import LanguageModelEmbedding

logger = logging.getLogger(__name__)


class LanguageModel[ConfigType: LanguageModelConfig](Configurable[ConfigType], LayerBase):
class LanguageModel[ConfigType: LanguageModelConfig](BlockBase[ConfigType]):
_config: ConfigType

def __init__(
self,
config: ConfigType,
distributed_config: DistributedConfig,
*,
# TODO: Unused, but required by the `BlockBase` interface.
hidden_dim: TensorDim | None = None,
lr_scale: float | None,
peft: PeftConfig | None,
):
super().__init__(config, distributed_config)

self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size)
super().__init__(
config,
distributed_config,
hidden_dim=TensorDim("hidden", config.hidden_size),
lr_scale=lr_scale,
peft=peft,
)
self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer(
distributed_config,
hidden_dim=self._hidden_dim,
lr_scale=None,
peft=self._config.peft,
lr_scale=self._lr_scale,
peft=self._peft,
)
self.decoder = self._config.decoder.get_layer(
distributed_config,
self._hidden_dim,
lr_scale=None,
peft=self._config.peft,
lr_scale=self._lr_scale,
peft=self._peft,
)
self.head = self._config.head.get_layer(
distributed_config,
self._config.embeddings,
hidden_dim=self._hidden_dim,
lr_scale=None,
peft=self._config.peft,
lr_scale=self._lr_scale,
peft=self._peft,
)

def get_layers(self) -> list["Layer"]:
def get_layers(self) -> list[Layer]:
return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers()

def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
# Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable?
self.embeddings.preprocess(batch, kwargs)
self.decoder.preprocess(batch, kwargs)
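For reference, a hedged sketch of the updated construction contract; `config` and `distributed_config` are assumed to be already-built config objects, and this is not code from the PR:

```python
from fast_llm.layers.language_model.language_model import LanguageModel

# LanguageModel now follows the BlockBase constructor interface, so lr_scale
# and peft are passed explicitly (compare GPTBaseModel.__init__ in
# fast_llm/models/gpt/model.py below). `config` is a LanguageModelConfig and
# `distributed_config` a DistributedConfig, both assumed to exist already.
model = LanguageModel(
    config,
    distributed_config,
    lr_scale=None,  # no extra learning-rate scaling
    peft=None,      # no parameter-efficient fine-tuning adapter
)
layers = model.get_layers()  # embeddings + decoder blocks + head, flattened
```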
6 changes: 6 additions & 0 deletions fast_llm/models/gpt/config.py
@@ -10,6 +10,7 @@
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig
from fast_llm.engine.schedule.config import BatchConfig
from fast_llm.engine.training.config import TrainerConfig
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig
from fast_llm.models.gpt.conversion.config import (
AprielHybridSSMCheckpointFormat,
@@ -84,6 +85,11 @@ def micro_batch_splits(self) -> int:
class GPTBaseModelConfig(LanguageModelConfig, BaseModelConfig):
_abstract = False

# TODO: Allow overriding in sub-models?
peft: PeftConfig = Field(
desc="Configuration for parameter-efficient fine tuning.",
hint=FieldHint.architecture,
)
# Debug, to get an exact match with megatron init.
use_megatron_initialization: bool = Field(
default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing
16 changes: 7 additions & 9 deletions fast_llm/models/gpt/conversion/llama.py
@@ -449,19 +449,13 @@ def get_converters(
class LlamaEmbeddingsConverter:
@classmethod
def import_config(cls, config: dict) -> dict:
return {
"vocab_size": config["vocab_size"],
"hidden_size": config["hidden_size"],
}
return {"vocab_size": config["vocab_size"]}

@classmethod
def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict:
Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig)
assert not config.position_embeddings.enabled
return {
"vocab_size": config.vocab_size,
"hidden_size": config.hidden_size,
}
return {"vocab_size": config.vocab_size}

@classmethod
def get_converters(
@@ -516,6 +510,7 @@ def import_config(cls, config: dict) -> dict:
"embeddings": cls.embeddings_converter_class.import_config(config),
"decoder": cls.decoder_converter_class.import_config(config),
"head": cls.head_converter_class.import_config(config),
"hidden_size": config["hidden_size"],
"tied_embedding_weight": config["tie_word_embeddings"],
}

@@ -526,7 +521,10 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict:
cls.embeddings_converter_class.export_config(config.embeddings),
cls.decoder_converter_class.export_config(config.decoder),
cls.head_converter_class.export_config(config.head),
{"tie_word_embeddings": config.tied_embedding_weight},
{
"tie_word_embeddings": config.tied_embedding_weight,
"hidden_size": config.hidden_size,
},
)

@classmethod
8 changes: 3 additions & 5 deletions fast_llm/models/gpt/model.py
@@ -30,16 +30,14 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](LanguageModel[ConfigType], Ba

def __init__(
self,
config: GPTBaseModelConfig,
config: ConfigType,
distributed_config: DistributedConfig,
):
super().__init__(config, distributed_config)
super().__init__(config, distributed_config, lr_scale=config.lr_scale, peft=config.peft)
if self._config.use_megatron_initialization:
for param in self.parameters():
Assert.custom(isinstance, param, ParameterMeta)
param.init_parameter = get_init_megatron(
param, self._config.decoder.block, config.embeddings.hidden_size
) # Noqa
param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size) # Noqa

def preprocess_meta(
self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType
4 changes: 1 addition & 3 deletions fast_llm/utils.py
@@ -316,9 +316,7 @@ def new_decorator(*args, **kwargs):
return new_decorator


def compare_nested(
config_a, config_b, errors: list | None = None, prefix: tuple = (), ignore_missing: tuple[str, ...] = ()
):
def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple = ()):
if errors is None:
errors = []
# Check for equality of both values and types.
12 changes: 4 additions & 8 deletions tests/layers/test_lm_head.py
@@ -171,13 +171,8 @@ def test_lm_head(
}
config = GPTBaseModelConfig.from_dict(
{
"decoder": {
"num_blocks": 0,
},
"embeddings": {
"vocab_size": VOCAB_SIZE,
"hidden_size": HIDDEN_SIZE,
},
"decoder": {"num_blocks": 0},
"embeddings": {"vocab_size": VOCAB_SIZE},
"head": (
head_config
if prediction_heads == 1
@@ -187,6 +182,7 @@ def test_lm_head(
"prediction_heads": prediction_heads,
}
),
"hidden_size": HIDDEN_SIZE,
},
config_dict,
update_type=UpdateType.update,
@@ -255,7 +251,7 @@ def test_lm_head(
logit_weight = torch.nn.Parameter(
torch.empty(
VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device
).normal_(config.embeddings.hidden_size**-0.5)
).normal_(config.hidden_size**-0.5)
)
else:
logit_weight = None