3 changes: 3 additions & 0 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -156,3 +156,6 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None:
shutil.copy(self.configuration_file, config.path)
if self.generation_utils_file:
shutil.copy(self.generation_utils_file, config.path)
gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json"
if gen_config.exists():
shutil.copy(gen_config, config.path)
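For context, the generation_config.json copied above is what transformers reads for default generation settings (bos/eos/pad token ids) when the exported checkpoint directory is loaded. A minimal sketch, assuming a hypothetical export path:

from transformers import GenerationConfig

export_path = "exports/dream_checkpoint"  # hypothetical path standing in for config.path
gen_config = GenerationConfig.from_pretrained(export_path)  # reads generation_config.json
print(gen_config.eos_token_id)  # 151643 for the Dream generation_config.json added below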
92 changes: 17 additions & 75 deletions fast_llm/models/gpt/conversion.py
@@ -118,13 +118,15 @@ def import_weight(
class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler):
_model: GPTModel
_model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig
architecture: typing.ClassVar[str]
"""
Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral)
"""

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]),
ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False),
RenameParamConverter(
fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),)
@@ -320,8 +322,8 @@ class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler)

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = "Starcoder2ForCausalLM"
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "rotary", "type"),),
fast_llm_value=DefaultRotaryConfig.dynamic_type_name,
@@ -447,8 +449,8 @@ class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler)

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = "LlamaForCausalLM"
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["LlamaForCausalLM"]),
# TODO: Llama supports biases
ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
@@ -499,8 +501,8 @@ class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler):

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = "Qwen2ForCausalLM"
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "normalization", "type"),),
fast_llm_value="rms_norm",
@@ -545,8 +547,8 @@ class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = "MistralForCausalLM"
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MistralForCausalLM"]),
IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None),
]

@@ -569,8 +571,8 @@ class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = "MixtralForCausalLM"
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MixtralForCausalLM"]),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk
),
@@ -613,8 +615,8 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = "MTPLlamaForCausalLM"
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MTPLlamaForCausalLM"]),
ConstantExportParamConverter(
export_names=(("auto_map",),),
export_value={
@@ -685,7 +687,12 @@ def _create_lm_head_converters(self) -> list[WeightConverter]:
return converters


class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonHuggingfaceCheckpointHandler):
class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen2HuggingfaceCheckpointHandler):
"""
Handler for DiffusionDream Huggingface checkpoints.
Inherits from Qwen2HuggingfaceCheckpointHandler (and CustomModelingExportMixin),
but overrides _create_config_converters to update architectures and auto_map.
"""

from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream

@@ -697,33 +704,8 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = "DreamModel"
return super()._create_config_converters() + [
# From Qwen2HuggingfaceCheckpointHandler - Change architectures to DiffusionDream
ConstantImportParamConverter(
fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
),
RenameParamConverter(
fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
),
ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
),
RopeScalingParamConverter(
fast_llm_names=(
("transformer", "rotary", "type"),
("transformer", "rotary", "scale_factor"),
("transformer", "rotary", "low_frequency_factor"),
("transformer", "rotary", "high_frequency_factor"),
("transformer", "rotary", "original_context_length"),
("transformer", "rotary", "attention_factor"),
("transformer", "rotary", "beta_fast"),
("transformer", "rotary", "beta_slow"),
),
export_names=(("rope_scaling",),),
),
IgnoreImportQwen2SlidingWindowParamsConverter(),
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DreamModel"]),
ConstantExportParamConverter(
export_names=(("auto_map",),),
export_value={
@@ -733,26 +715,8 @@ def _create_config_converters(cls) -> list[ParamConverter]:
),
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
# From Qwen2HuggingfaceCheckpointHandler
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]


class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler):
class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlamaHuggingfaceCheckpointHandler):

from fast_llm.models.gpt.external.diffusion_llama import (
configuration_diffusion_llama,
@@ -768,12 +732,8 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = "DiffusionLlamaModel"
return super()._create_config_converters() + [
# From LlamaHuggingfaceCheckpointHandler - Update architectures to DiffusionLlama
# TODO: Llama supports biases
ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DiffusionLlamaModel"]),
ConstantExportParamConverter(
export_names=(("auto_map",),),
export_value={
@@ -789,24 +749,6 @@ def _create_config_converters(cls) -> list[ParamConverter]:
# ),
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
# From LlamaHuggingfaceCheckpointHandler
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]


class AutoGPTHuggingfaceCheckpointHandler(
AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC
fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +17,6 @@
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging


logger = logging.get_logger(__name__)


@@ -47,7 +45,7 @@ def __init__(
max_window_layers=28,
attention_dropout=0.0,
mask_token_id=151666,
pad_token_id=151643, # vocab_size is set to 8192 for test cases this would fail on Embedding layer check: # pad_token_id=None,
pad_token_id=None,  # vocab_size is set to 8192 for test cases, so the real value would fail the Embedding layer check: # pad_token_id=151643,
**kwargs,
):
self.vocab_size = vocab_size
@@ -77,7 +75,7 @@ def __init__(
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
fast_llm/models/gpt/external/diffusion_dream/generation_config.json
@@ -0,0 +1,7 @@
{
"_from_model_config": true,
"bos_token_id": 151643,
"eos_token_id": 151643,
"pad_token_id": 151643,
"transformers_version": "4.46.2"
}
fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -20,7 +19,7 @@
"""LLaMA model configuration"""

import math
from typing import Optional, Tuple
from typing import Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import is_torch_available, logging
@@ -30,13 +29,14 @@
if is_torch_available():
import torch


# Update yarn implementation for RoPE (Taken from Llama but updated to use original_max_position_embeddings)
def _compute_default_rope_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
@@ -72,9 +72,10 @@ def _compute_default_rope_parameters(
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor


def _compute_yarn_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with NTK scaling. Please refer to the
[original paper](https://arxiv.org/abs/2309.00071)
@@ -101,7 +102,7 @@ def _compute_yarn_parameters(
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)

# Apriel: use rope_scaling["original_max_position_embeddings"] instead of config.max_position_embeddings
max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
factor = config.rope_scaling["factor"]
@@ -152,14 +153,14 @@ def linear_ramp_factor(min, max, dim):

return inv_freq, attention_factor
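
# Illustrative only, not part of this file: a rope_scaling block of the shape
# _compute_yarn_parameters above consumes; values are hypothetical.
example_rope_scaling = {
    "rope_type": "yarn",                       # legacy configs may use "type" instead
    "factor": 4.0,                             # required scaling factor
    "original_max_position_embeddings": 4096,  # used here instead of config.max_position_embeddings
    "attention_factor": None,                  # optional
    "beta_fast": 32,                           # optional, must be greater than beta_slow
    "beta_slow": 1,                            # optional
}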


def _check_received_keys(
rope_type: str,
received_keys: set,
required_keys: set,
optional_keys: Optional[set] = None,
ignore_keys: Optional[set] = None,
):

"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
if "type" in received_keys:
@@ -189,6 +190,7 @@ def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Opt
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)


def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
@@ -218,6 +220,8 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
)


# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
@@ -232,6 +236,7 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
"yarn": _validate_yarn_parameters,
}


def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
"""
Validate the RoPE config arguments, given a `PretrainedConfig` object
@@ -250,6 +255,7 @@ def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set]
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
)


class DiffusionLlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
@@ -397,7 +403,7 @@ def __init__(
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=False, # cache not implemented in diffusion
use_cache=False, # cache not implemented in diffusion
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
@@ -409,7 +415,7 @@ def __init__(
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
# mask_token_id= TODO: add the mask_token_id we will be using,
mask_token_id=131072, # TODO: add the mask_token_id we will be using,
**kwargs,
):
self.vocab_size = vocab_size
@@ -435,6 +441,8 @@ def __init__(
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
self.mask_token_id = mask_token_id
self.pad_token_id = pad_token_id
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
@@ -450,4 +458,5 @@ def __init__(
)
# TODO: self.mask_token_id = mask_token_id

__all__ = ["LlamaConfig"]

__all__ = ["DiffusionLlamaConfig"]