diff --git a/curated_transformers/generation/__init__.py b/curated_transformers/generation/__init__.py index 9f125d3e..ee428b52 100644 --- a/curated_transformers/generation/__init__.py +++ b/curated_transformers/generation/__init__.py @@ -5,7 +5,7 @@ from .falcon import FalconGenerator from .generator import Generator from .generator_wrapper import GeneratorWrapper -from .hf_hub import FromHFHub +from .hf_hub import FromHF from .llama import LlamaGenerator from .logits import ( CompoundLogitsTransform, @@ -32,7 +32,7 @@ "DollyV2Generator", "EndOfSequenceCondition", "FalconGenerator", - "FromHFHub", + "FromHF", "Generator", "GeneratorConfig", "GeneratorWrapper", diff --git a/curated_transformers/generation/auto_generator.py b/curated_transformers/generation/auto_generator.py index 36e6df3e..4a4d727a 100644 --- a/curated_transformers/generation/auto_generator.py +++ b/curated_transformers/generation/auto_generator.py @@ -7,13 +7,13 @@ from .dolly_v2 import DollyV2Generator from .falcon import FalconGenerator from .generator_wrapper import GeneratorWrapper -from .hf_hub import FromHFHub +from .hf_hub import FromHF from .llama import LlamaGenerator from .mpt import MPTGenerator # For the time being, we enable support for a generator on a case-by-case basis. # In the future we might defer all unknown generators to DefaultGenerator. -GENERATOR_MAP: Dict[str, Type[FromHFHub]] = { +GENERATOR_MAP: Dict[str, Type[FromHF]] = { "dolly-v2": DollyV2Generator, "falcon": FalconGenerator, "llama": LlamaGenerator, @@ -70,7 +70,7 @@ def from_hf_hub( return generator -def _resolve_generator_class(name: str) -> Type[FromHFHub]: +def _resolve_generator_class(name: str) -> Type[FromHF]: for substring, generator_cls in GENERATOR_MAP.items(): if substring in name.lower(): return generator_cls diff --git a/curated_transformers/generation/default_generator.py b/curated_transformers/generation/default_generator.py index 358161d3..c1b81cad 100644 --- a/curated_transformers/generation/default_generator.py +++ b/curated_transformers/generation/default_generator.py @@ -13,14 +13,14 @@ from .config import GeneratorConfig, SampleGeneratorConfig from .generator import Generator from .generator_wrapper import GeneratorWrapper -from .hf_hub import FromHFHub +from .hf_hub import FromHF from .string_generator import StringGenerator # Only provided as typing.Self in Python 3.11+. Self = TypeVar("Self", bound="DefaultGenerator") -class DefaultGenerator(Generic[CacheT], GeneratorWrapper, FromHFHub): +class DefaultGenerator(Generic[CacheT], GeneratorWrapper, FromHF): """ Generator wrapper for models that do not need specific prompting. """ diff --git a/curated_transformers/generation/falcon.py b/curated_transformers/generation/falcon.py index b92b5ab0..1f2fa74e 100644 --- a/curated_transformers/generation/falcon.py +++ b/curated_transformers/generation/falcon.py @@ -4,13 +4,13 @@ from ..tokenizers.chunks import InputChunks, TextChunk from ..tokenizers.tokenizer import Tokenizer from .default_generator import DefaultGenerator -from .hf_hub import FromHFHub +from .hf_hub import FromHF # Only provided as typing.Self in Python 3.11+. Self = TypeVar("Self", bound="FalconGenerator") -class FalconGenerator(DefaultGenerator, FromHFHub): +class FalconGenerator(DefaultGenerator, FromHF): """ Generator for Falcon model variants. 
""" diff --git a/curated_transformers/generation/hf_hub.py b/curated_transformers/generation/hf_hub.py index dbb7b363..0e6b3cf3 100644 --- a/curated_transformers/generation/hf_hub.py +++ b/curated_transformers/generation/hf_hub.py @@ -6,10 +6,10 @@ from ..quantization.bnb.config import BitsAndBytesConfig # Only provided as typing.Self in Python 3.11+. -Self = TypeVar("Self", bound="FromHFHub") +Self = TypeVar("Self", bound="FromHF") -class FromHFHub(ABC): +class FromHF(ABC): """ Mixin class for downloading generators from Hugging Face Hub. diff --git a/curated_transformers/generation/llama.py b/curated_transformers/generation/llama.py index 8f2eb512..039a50a1 100644 --- a/curated_transformers/generation/llama.py +++ b/curated_transformers/generation/llama.py @@ -3,13 +3,13 @@ from ..models.llama import LlamaCausalLM from ..tokenizers.tokenizer import Tokenizer from .default_generator import DefaultGenerator -from .hf_hub import FromHFHub +from .hf_hub import FromHF # Only provided as typing.Self in Python 3.11+. Self = TypeVar("Self", bound="LlamaGenerator") -class LlamaGenerator(DefaultGenerator, FromHFHub): +class LlamaGenerator(DefaultGenerator, FromHF): """ Generator for Llama and Llama 2 model variants. """ diff --git a/curated_transformers/generation/mpt.py b/curated_transformers/generation/mpt.py index fade3cce..1dde3e1a 100644 --- a/curated_transformers/generation/mpt.py +++ b/curated_transformers/generation/mpt.py @@ -3,10 +3,10 @@ from ..models.mpt import MPTCausalLM from ..tokenizers.tokenizer import Tokenizer from .default_generator import DefaultGenerator -from .hf_hub import FromHFHub +from .hf_hub import FromHF -class MPTGenerator(DefaultGenerator, FromHFHub): +class MPTGenerator(DefaultGenerator, FromHF): """ Generator for MPT model variants. """ diff --git a/curated_transformers/layers/attention.py b/curated_transformers/layers/attention.py index 4a31bc5f..d72c9f60 100644 --- a/curated_transformers/layers/attention.py +++ b/curated_transformers/layers/attention.py @@ -11,7 +11,6 @@ from torch import Tensor from torch.nn import Dropout, Linear, Module -from ..semver import Default, FutureMandatory from .cache import KeyValueCache from .embeddings import QueryKeyRotaryEmbeddings @@ -346,7 +345,7 @@ def __init__( *, n_query_heads: int, n_key_value_heads: int, - qkv_split: FutureMandatory[QkvSplit] = Default, + qkv_split: QkvSplit, ): """ Construct an attention head configuration. 
This constructor must @@ -366,16 +365,13 @@ def __init__( """ self._n_query_heads = n_query_heads self._n_key_value_heads = n_key_value_heads - - qkv_split = QkvSplitGroupedByKVHeads() if qkv_split is Default else qkv_split - assert isinstance(qkv_split, QkvSplit) self._qkv_split = qkv_split @classmethod def uniform( cls, n_attention_heads: int, - qkv_split: FutureMandatory[QkvSplit] = Default, + qkv_split: QkvSplit, ) -> "AttentionHeads": """ Construct a head configuration where query, key, and value have the @@ -398,7 +394,7 @@ def uniform( def multi_query( cls, n_query_heads: int, - qkv_split: FutureMandatory[QkvSplit] = Default, + qkv_split: QkvSplit, ) -> "AttentionHeads": """ Construct a multi-query attention configuration: key has one head, @@ -425,7 +421,7 @@ def key_value_broadcast( *, n_query_heads: int, n_key_value_heads: int, - qkv_split: FutureMandatory[QkvSplit] = Default, + qkv_split: QkvSplit, ) -> "AttentionHeads": """ Construct a head configuration where query has a larger number diff --git a/curated_transformers/models/__init__.py b/curated_transformers/models/__init__.py index 825d02e1..a0e1f5e1 100644 --- a/curated_transformers/models/__init__.py +++ b/curated_transformers/models/__init__.py @@ -12,7 +12,7 @@ ) from .falcon import FalconCausalLM, FalconConfig, FalconDecoder from .gpt_neox import GPTNeoXCausalLM, GPTNeoXConfig, GPTNeoXDecoder -from .hf_hub import FromHFHub +from .hf_hub import FromHF from .llama import LlamaCausalLM, LlamaConfig, LlamaDecoder from .module import CausalLMModule, DecoderModule, EncoderModule from .mpt import MPTCausalLM, MPTConfig, MPTDecoder @@ -37,7 +37,7 @@ "FalconCausalLM", "FalconConfig", "FalconDecoder", - "FromHFHub", + "FromHF", "GPTNeoXCausalLM", "GPTNeoXConfig", "GPTNeoXDecoder", diff --git a/curated_transformers/models/albert/encoder.py b/curated_transformers/models/albert/encoder.py index 05284c23..954a6c2a 100644 --- a/curated_transformers/models/albert/encoder.py +++ b/curated_transformers/models/albert/encoder.py @@ -10,7 +10,7 @@ EmbeddingLayerNorms, TransformerEmbeddings, ) -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..module import EncoderModule from ..output import ModelOutput @@ -22,7 +22,7 @@ Self = TypeVar("Self", bound="ALBERTEncoder") -class ALBERTEncoder(EncoderModule[ALBERTConfig], FromHFHub[ALBERTConfig]): +class ALBERTEncoder(EncoderModule[ALBERTConfig], FromHF[ALBERTConfig]): """ ALBERT (`Lan et al., 2022`_) encoder. 
diff --git a/curated_transformers/models/albert/layer_group.py b/curated_transformers/models/albert/layer_group.py index 874e6c04..1289ac1d 100644 --- a/curated_transformers/models/albert/layer_group.py +++ b/curated_transformers/models/albert/layer_group.py @@ -9,6 +9,7 @@ AttentionHeads, AttentionMask, QkvMode, + QkvSplitGroupedByKVHeads, ScaledDotProductAttention, SelfAttention, ) @@ -45,7 +46,8 @@ def __init__( EncoderLayer( attention_layer=SelfAttention( attention_heads=AttentionHeads.uniform( - attention_config.n_query_heads + attention_config.n_query_heads, + QkvSplitGroupedByKVHeads(), ), attention_scorer=ScaledDotProductAttention( dropout_prob=attention_config.dropout_prob, diff --git a/curated_transformers/models/auto_model.py b/curated_transformers/models/auto_model.py index 172c5b1d..5aa2b286 100644 --- a/curated_transformers/models/auto_model.py +++ b/curated_transformers/models/auto_model.py @@ -14,7 +14,7 @@ from ..repository.hf_hub import HfHubRepository from ..repository.repository import ModelRepository, Repository from .config import TransformerConfig -from .hf_hub import FromHFHub +from .hf_hub import FromHF from .module import CausalLMModule, DecoderModule, EncoderModule, TransformerModule ModelT = TypeVar("ModelT") @@ -33,14 +33,14 @@ class AutoModel(ABC, Generic[ModelT]): def _resolve_model_cls( cls, repo: ModelRepository, - ) -> Type[FromHFHub]: + ) -> Type[FromHF]: config = repo.model_config() for entrypoint, module_cls in cls._registry.get_entry_points().items(): - if not issubclass(module_cls, FromHFHub): + if not issubclass(module_cls, FromHF): warnings.warn( f"Entry point `{entrypoint}` cannot load from Hugging Face Hub " - "since the FromHFHub mixin is not implemented" + "since the FromHF mixin is not implemented" ) continue @@ -70,7 +70,7 @@ def _instantiate_model( repo: Repository, device: Optional[torch.device], quantization_config: Optional[BitsAndBytesConfig], - ) -> FromHFHub: + ) -> FromHF: module_cls = cls._resolve_model_cls(ModelRepository(repo)) module = module_cls.from_repo( repo=repo, diff --git a/curated_transformers/models/bert/encoder.py b/curated_transformers/models/bert/encoder.py index e7d3bdee..7edc84e6 100644 --- a/curated_transformers/models/bert/encoder.py +++ b/curated_transformers/models/bert/encoder.py @@ -8,6 +8,7 @@ from ...layers.attention import ( AttentionHeads, QkvMode, + QkvSplitGroupedByKVHeads, ScaledDotProductAttention, SelfAttention, ) @@ -20,7 +21,7 @@ TransformerEmbeddings, TransformerLayerNorms, ) -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerEncoder from ._hf import HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf @@ -30,7 +31,7 @@ Self = TypeVar("Self", bound="BERTEncoder") -class BERTEncoder(TransformerEncoder[BERTConfig], FromHFHub[BERTConfig]): +class BERTEncoder(TransformerEncoder[BERTConfig], FromHF[BERTConfig]): """ BERT (`Devlin et al., 2018`_) encoder. 
@@ -85,7 +86,8 @@ def __init__( EncoderLayer( attention_layer=SelfAttention( attention_heads=AttentionHeads.uniform( - config.layer.attention.n_query_heads + config.layer.attention.n_query_heads, + QkvSplitGroupedByKVHeads(), ), attention_scorer=ScaledDotProductAttention( dropout_prob=config.layer.attention.dropout_prob, diff --git a/curated_transformers/models/falcon/causal_lm.py b/curated_transformers/models/falcon/causal_lm.py index f971c9d6..11e943ed 100644 --- a/curated_transformers/models/falcon/causal_lm.py +++ b/curated_transformers/models/falcon/causal_lm.py @@ -5,7 +5,7 @@ from torch.nn import Linear from ...quantization.quantizable import Quantizable -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerCausalLM from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf @@ -17,7 +17,7 @@ class FalconCausalLM( - TransformerCausalLM[FalconConfig], FromHFHub[FalconConfig], Quantizable + TransformerCausalLM[FalconConfig], FromHF[FalconConfig], Quantizable ): """ Falcon (`Penedo et al., 2019`_) causal language model. diff --git a/curated_transformers/models/falcon/decoder.py b/curated_transformers/models/falcon/decoder.py index bd7508d4..c37fc74f 100644 --- a/curated_transformers/models/falcon/decoder.py +++ b/curated_transformers/models/falcon/decoder.py @@ -9,6 +9,7 @@ AttentionHeads, AttentionLinearBiases, QkvMode, + QkvSplitGroupedByKVHeads, ScaledDotProductAttention, SelfAttention, ) @@ -22,7 +23,7 @@ TransformerEmbeddings, TransformerLayerNorms, ) -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerDecoder from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf @@ -33,7 +34,7 @@ Self = TypeVar("Self", bound="FalconDecoder") -class FalconDecoder(TransformerDecoder[FalconConfig], FromHFHub[FalconConfig]): +class FalconDecoder(TransformerDecoder[FalconConfig], FromHF[FalconConfig]): """ Falcon (`Penedo et al., 2019`_) decoder. 
@@ -166,6 +167,7 @@ def _create_new_decoder_architecture_layer( attention_heads=AttentionHeads.key_value_broadcast( n_query_heads=n_attention_heads, n_key_value_heads=config.layer.attention.n_key_value_heads, + qkv_split=QkvSplitGroupedByKVHeads(), ), attention_scorer=ScaledDotProductAttention( dropout_prob=config.layer.attention.dropout_prob, diff --git a/curated_transformers/models/falcon/layer.py b/curated_transformers/models/falcon/layer.py index 136b2700..f72f458d 100644 --- a/curated_transformers/models/falcon/layer.py +++ b/curated_transformers/models/falcon/layer.py @@ -10,6 +10,7 @@ AttentionMask, KeyValueCache, QkvMode, + QkvSplitGroupedByKVHeads, ScaledDotProductAttention, SelfAttention, ) @@ -71,6 +72,7 @@ def __init__( attention_heads=AttentionHeads.key_value_broadcast( n_query_heads=attention_config.n_query_heads, n_key_value_heads=attention_config.n_key_value_heads, + qkv_split=QkvSplitGroupedByKVHeads(), ), rotary_embeds=rotary_embeds, qkv_mode=( diff --git a/curated_transformers/models/gpt_neox/causal_lm.py b/curated_transformers/models/gpt_neox/causal_lm.py index d27aed03..85637def 100644 --- a/curated_transformers/models/gpt_neox/causal_lm.py +++ b/curated_transformers/models/gpt_neox/causal_lm.py @@ -5,7 +5,7 @@ from torch.nn import Linear from ...quantization import Quantizable -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerCausalLM from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf @@ -17,7 +17,7 @@ class GPTNeoXCausalLM( - TransformerCausalLM[GPTNeoXConfig], FromHFHub[GPTNeoXConfig], Quantizable + TransformerCausalLM[GPTNeoXConfig], FromHF[GPTNeoXConfig], Quantizable ): """ GPT-NeoX (`Black et al., 2022`_) causal language model. diff --git a/curated_transformers/models/gpt_neox/decoder.py b/curated_transformers/models/gpt_neox/decoder.py index e2ca6f66..6649ebaf 100644 --- a/curated_transformers/models/gpt_neox/decoder.py +++ b/curated_transformers/models/gpt_neox/decoder.py @@ -8,6 +8,7 @@ from ...layers.attention import ( AttentionHeads, QkvMode, + QkvSplitGroupedByKVHeads, ScaledDotProductAttention, SelfAttention, ) @@ -21,7 +22,7 @@ TransformerEmbeddings, TransformerLayerNorms, ) -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerDecoder from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf @@ -31,7 +32,7 @@ Self = TypeVar("Self", bound="GPTNeoXDecoder") -class GPTNeoXDecoder(TransformerDecoder[GPTNeoXConfig], FromHFHub): +class GPTNeoXDecoder(TransformerDecoder[GPTNeoXConfig], FromHF): """ GPT-NeoX (`Black et al., 2022`_) decoder. 
@@ -82,7 +83,10 @@ def __init__( [ DecoderLayer( attention_layer=SelfAttention( - attention_heads=AttentionHeads.uniform(n_attention_heads), + attention_heads=AttentionHeads.uniform( + n_attention_heads, + QkvSplitGroupedByKVHeads(), + ), attention_scorer=ScaledDotProductAttention( dropout_prob=config.layer.attention.dropout_prob, linear_biases=None, diff --git a/curated_transformers/models/hf_hub/__init__.py b/curated_transformers/models/hf_hub/__init__.py index d0c7bd0f..28a1021a 100644 --- a/curated_transformers/models/hf_hub/__init__.py +++ b/curated_transformers/models/hf_hub/__init__.py @@ -1 +1 @@ -from .mixin import FromHFHub +from .mixin import FromHF diff --git a/curated_transformers/models/hf_hub/mixin.py b/curated_transformers/models/hf_hub/mixin.py index d940fb72..a16a9b9e 100644 --- a/curated_transformers/models/hf_hub/mixin.py +++ b/curated_transformers/models/hf_hub/mixin.py @@ -15,27 +15,18 @@ from ..module import ConfigT, TransformerModule # Only provided as typing.Self in Python 3.11+. -Self = TypeVar("Self", bound="FromHFHub") +Self = TypeVar("Self", bound="FromHF") -class FromHFHub(ABC, Generic[ConfigT]): +class FromHF(ABC, Generic[ConfigT]): """ Mixin class for downloading models from Hugging Face Hub. - A module using this mixin can implement the ``convert_hf_state_dict`` - and ``from_hf_config`` methods. The mixin will then provide the - ``from_hf_hub`` method to download a model from the Hugging Face Hub. + Implementation of the mixin's abstract methods will provide various ``from_`` + methods to load a model, including the ``from_hf_hub`` method to download a + model from the Hugging Face Hub. """ - @classmethod - def convert_hf_state_dict( - cls, params: Mapping[str, Tensor] - ) -> Mapping[str, Tensor]: - """ - Alias for :meth:`.state_dict_from_hf`. - """ - return cls.state_dict_from_hf(params) - @classmethod @abstractmethod def state_dict_from_hf(cls, params: Mapping[str, Tensor]) -> Mapping[str, Tensor]: @@ -339,7 +330,7 @@ def from_repo_( self, # type:ignore filepaths=checkpoint_filenames, checkpoint_type=checkpoint_type, - state_dict_converter=type(self).convert_hf_state_dict, + state_dict_converter=type(self).state_dict_from_hf, tensor_to_param_converter=tensor2param, device=device, ) diff --git a/curated_transformers/models/llama/causal_lm.py b/curated_transformers/models/llama/causal_lm.py index 4def2088..82221d0f 100644 --- a/curated_transformers/models/llama/causal_lm.py +++ b/curated_transformers/models/llama/causal_lm.py @@ -5,7 +5,7 @@ from torch.nn import Linear from ...quantization import Quantizable -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerCausalLM from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf @@ -16,9 +16,7 @@ Self = TypeVar("Self", bound="LlamaCausalLM") -class LlamaCausalLM( - TransformerCausalLM[LlamaConfig], FromHFHub[LlamaConfig], Quantizable -): +class LlamaCausalLM(TransformerCausalLM[LlamaConfig], FromHF[LlamaConfig], Quantizable): """ Llama (`Touvron et al., 2023 [a]`_, `Touvron et al., 2023 [b]`_) causal language model. 
diff --git a/curated_transformers/models/llama/decoder.py b/curated_transformers/models/llama/decoder.py index 3c7b57f9..853caccf 100644 --- a/curated_transformers/models/llama/decoder.py +++ b/curated_transformers/models/llama/decoder.py @@ -8,6 +8,7 @@ from ...layers.attention import ( AttentionHeads, QkvMode, + QkvSplitGroupedByKVHeads, ScaledDotProductAttention, SelfAttention, ) @@ -22,7 +23,7 @@ TransformerEmbeddings, TransformerLayerNorms, ) -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerDecoder from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf @@ -32,7 +33,7 @@ Self = TypeVar("Self", bound="LlamaDecoder") -class LlamaDecoder(TransformerDecoder[LlamaConfig], FromHFHub[LlamaConfig]): +class LlamaDecoder(TransformerDecoder[LlamaConfig], FromHF[LlamaConfig]): """ Llama (`Touvron et al., 2023 [a]`_, `Touvron et al., 2023 [b]`_) decoder. @@ -73,6 +74,7 @@ def __init__( attention_heads = AttentionHeads.key_value_broadcast( n_query_heads=n_query_heads, n_key_value_heads=config.layer.attention.n_key_value_heads, + qkv_split=QkvSplitGroupedByKVHeads(), ) layer_norm = partial( RMSNorm, diff --git a/curated_transformers/models/mpt/causal_lm.py b/curated_transformers/models/mpt/causal_lm.py index 31964fff..0624a6d3 100644 --- a/curated_transformers/models/mpt/causal_lm.py +++ b/curated_transformers/models/mpt/causal_lm.py @@ -8,7 +8,7 @@ from ...layers.attention import AttentionMask from ...layers.cache import KeyValueCache from ...quantization import Quantizable -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..output import CausalLMOutputWithCache from ..transformer import TransformerCausalLM @@ -20,7 +20,7 @@ Self = TypeVar("Self", bound="MPTCausalLM") -class MPTCausalLM(TransformerCausalLM[MPTConfig], FromHFHub[MPTConfig], Quantizable): +class MPTCausalLM(TransformerCausalLM[MPTConfig], FromHF[MPTConfig], Quantizable): """ `MosaicML MPT`_ causal language model. diff --git a/curated_transformers/models/mpt/decoder.py b/curated_transformers/models/mpt/decoder.py index 4a77953c..95fc6fcb 100644 --- a/curated_transformers/models/mpt/decoder.py +++ b/curated_transformers/models/mpt/decoder.py @@ -21,7 +21,7 @@ TransformerEmbeddings, TransformerLayerNorms, ) -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerDecoder from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf @@ -31,7 +31,7 @@ Self = TypeVar("Self", bound="MPTDecoder") -class MPTDecoder(TransformerDecoder[MPTConfig], FromHFHub[MPTConfig]): +class MPTDecoder(TransformerDecoder[MPTConfig], FromHF[MPTConfig]): """ `MosaicML MPT`_ decoder. 
diff --git a/curated_transformers/models/roberta/encoder.py b/curated_transformers/models/roberta/encoder.py index b79f3068..ee83fd5e 100644 --- a/curated_transformers/models/roberta/encoder.py +++ b/curated_transformers/models/roberta/encoder.py @@ -8,6 +8,7 @@ from ...layers.attention import ( AttentionHeads, QkvMode, + QkvSplitGroupedByKVHeads, ScaledDotProductAttention, SelfAttention, ) @@ -19,7 +20,7 @@ TransformerDropouts, TransformerLayerNorms, ) -from ..hf_hub import FromHFHub +from ..hf_hub import FromHF from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf from ..transformer import TransformerEncoder from ._hf import HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf @@ -30,7 +31,7 @@ Self = TypeVar("Self", bound="RoBERTaEncoder") -class RoBERTaEncoder(TransformerEncoder[RoBERTaConfig], FromHFHub[RoBERTaConfig]): +class RoBERTaEncoder(TransformerEncoder[RoBERTaConfig], FromHF[RoBERTaConfig]): """ RoBERTa (`Liu et al., 2019`_) encoder. @@ -80,7 +81,8 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No EncoderLayer( attention_layer=SelfAttention( attention_heads=AttentionHeads.uniform( - config.layer.attention.n_query_heads + config.layer.attention.n_query_heads, + QkvSplitGroupedByKVHeads(), ), attention_scorer=ScaledDotProductAttention( dropout_prob=config.layer.attention.dropout_prob, diff --git a/curated_transformers/tests/models/util.py b/curated_transformers/tests/models/util.py index 74bfdef2..e99e5739 100644 --- a/curated_transformers/tests/models/util.py +++ b/curated_transformers/tests/models/util.py @@ -8,7 +8,7 @@ from curated_transformers.layers.attention import AttentionMask, enable_torch_sdp from curated_transformers.layers.cache import KeyValueCache -from curated_transformers.models.hf_hub import FromHFHub +from curated_transformers.models.hf_hub import FromHF from curated_transformers.models.module import ( CausalLMModule, DecoderModule, @@ -71,7 +71,7 @@ def convert(self, model: Module, with_torch_sdp: bool, *args) -> Tuple[ def assert_causal_lm_output_equals_hf( - model_class: Type[FromHFHub], + model_class: Type[FromHF], model_name: str, torch_device: torch.device, *, @@ -122,7 +122,7 @@ def assert_causal_lm_output_equals_hf( def assert_decoder_output_equals_hf( - model_class: Type[FromHFHub], + model_class: Type[FromHF], model_name: str, torch_device: torch.device, *, @@ -186,7 +186,7 @@ def assert_decoder_output_equals_hf( def assert_encoder_output_equals_hf( - model_class: Type[FromHFHub], + model_class: Type[FromHF], model_name: str, torch_device: torch.device, *, @@ -358,7 +358,7 @@ def assert_model_config(model: TransformerModule, model_output: Tensor): def assert_model_hf_serialization_roundtrip( - model_class: Type[FromHFHub], + model_class: Type[FromHF], model_name: str, torch_device: torch.device, *, diff --git a/curated_transformers/tokenizers/__init__.py b/curated_transformers/tokenizers/__init__.py index 6bf6f622..d115c54d 100644 --- a/curated_transformers/tokenizers/__init__.py +++ b/curated_transformers/tokenizers/__init__.py @@ -1,11 +1,11 @@ from .auto_tokenizer import AutoTokenizer from .chunks import InputChunks, SpecialPieceChunk, TextChunk -from .hf_hub import FromHFHub +from .hf_hub import FromHF from .tokenizer import PiecesWithIds, Tokenizer, TokenizerBase __all__ = [ "AutoTokenizer", - "FromHFHub", + "FromHF", "InputChunks", "PiecesWithIds", "SpecialPieceChunk", diff --git a/curated_transformers/tokenizers/auto_tokenizer.py b/curated_transformers/tokenizers/auto_tokenizer.py index 
a1e9a92c..f5c1994d 100644 --- a/curated_transformers/tokenizers/auto_tokenizer.py +++ b/curated_transformers/tokenizers/auto_tokenizer.py @@ -5,7 +5,7 @@ from ..repository.fsspec import FsspecArgs, FsspecRepository from ..repository.hf_hub import HfHubRepository from ..repository.repository import Repository, TokenizerRepository -from .hf_hub import FromHFHub +from .hf_hub import FromHF from .legacy.bert_tokenizer import BERTTokenizer from .legacy.camembert_tokenizer import CamemBERTTokenizer from .legacy.llama_tokenizer import LlamaTokenizer @@ -13,7 +13,7 @@ from .legacy.xlmr_tokenizer import XLMRTokenizer from .tokenizer import Tokenizer, TokenizerBase -HF_TOKENIZER_MAPPING: Dict[str, Type[FromHFHub]] = { +HF_TOKENIZER_MAPPING: Dict[str, Type[FromHF]] = { "BertTokenizer": BERTTokenizer, "BertTokenizerFast": BERTTokenizer, "CamembertTokenizer": CamemBERTTokenizer, @@ -26,7 +26,7 @@ "XLMRobertaTokenizerFast": XLMRTokenizer, } -HF_MODEL_MAPPING: Dict[str, Type[FromHFHub]] = { +HF_MODEL_MAPPING: Dict[str, Type[FromHF]] = { "bert": BERTTokenizer, "camembert": CamemBERTTokenizer, "llama": LlamaTokenizer, @@ -40,7 +40,7 @@ class AutoTokenizer: Tokenizer loaded from the Hugging Face Model Hub. """ - # NOTE: We do not inherit from FromHFHub, because its from_hf_hub method + # NOTE: We do not inherit from FromHF, because its from_hf_hub method # requires that the return type is Self. @classmethod @@ -122,7 +122,7 @@ def from_hf_hub(cls, *, name: str, revision: str = "main") -> TokenizerBase: def _get_tokenizer_class_from_config( tokenizer_config: Dict[str, Any] -) -> Optional[Type[FromHFHub]]: +) -> Optional[Type[FromHF]]: """ Infer the tokenizer class from the tokenizer configuration. @@ -138,12 +138,12 @@ def _get_tokenizer_class_from_config( def _resolve_tokenizer_class( repo: TokenizerRepository, -) -> Type[FromHFHub]: +) -> Type[FromHF]: tokenizer_file = repo.tokenizer_json() if tokenizer_file.exists(): return Tokenizer - cls: Optional[Type[FromHFHub]] = None + cls: Optional[Type[FromHF]] = None try: tokenizer_config = repo.tokenizer_config() cls = _get_tokenizer_class_from_config(tokenizer_config) diff --git a/curated_transformers/tokenizers/hf_hub.py b/curated_transformers/tokenizers/hf_hub.py index 790cf8a9..d22aa9eb 100644 --- a/curated_transformers/tokenizers/hf_hub.py +++ b/curated_transformers/tokenizers/hf_hub.py @@ -9,10 +9,10 @@ from ..repository.hf_hub import HfHubRepository from ..repository.repository import Repository, TokenizerRepository -SelfFromHFHub = TypeVar("SelfFromHFHub", bound="FromHFHub") +SelfFromHF = TypeVar("SelfFromHF", bound="FromHF") -class FromHFHub(ABC): +class FromHF(ABC): """ Mixin class for downloading tokenizers from Hugging Face Hub. @@ -23,7 +23,7 @@ class FromHFHub(ABC): @classmethod @abstractmethod def from_hf_hub_to_cache( - cls: Type[SelfFromHFHub], + cls: Type[SelfFromHF], *, name: str, revision: str = "main", @@ -43,12 +43,12 @@ def from_hf_hub_to_cache( @classmethod def from_fsspec( - cls: Type[SelfFromHFHub], + cls: Type[SelfFromHF], *, fs: AbstractFileSystem, model_path: str, fsspec_args: Optional[FsspecArgs] = None, - ) -> SelfFromHFHub: + ) -> SelfFromHF: """ Construct a tokenizer and load its parameters from an fsspec filesystem. @@ -68,8 +68,8 @@ def from_fsspec( @classmethod def from_hf_hub( - cls: Type[SelfFromHFHub], *, name: str, revision: str = "main" - ) -> SelfFromHFHub: + cls: Type[SelfFromHF], *, name: str, revision: str = "main" + ) -> SelfFromHF: """ Construct a tokenizer and load its parameters from Hugging Face Hub. 
@@ -87,9 +87,9 @@ def from_hf_hub( @classmethod @abstractmethod def from_repo( - cls: Type[SelfFromHFHub], + cls: Type[SelfFromHF], repo: Repository, - ) -> SelfFromHFHub: + ) -> SelfFromHF: """ Construct and load a tokenizer from a repository. @@ -101,12 +101,12 @@ def from_repo( ... -SelfLegacyFromHFHub = TypeVar("SelfLegacyFromHFHub", bound="LegacyFromHFHub") +SelfLegacyFromHF = TypeVar("SelfLegacyFromHF", bound="LegacyFromHF") -class LegacyFromHFHub(FromHFHub): +class LegacyFromHF(FromHF): """ - Subclass of :class:`.FromHFHub` for legacy tokenizers. This subclass + Subclass of :class:`.FromHF` for legacy tokenizers. This subclass implements the ``from_hf_hub`` method and provides through the abstract ``_load_from_vocab_files`` method: @@ -120,11 +120,11 @@ class LegacyFromHFHub(FromHFHub): @classmethod @abstractmethod def _load_from_vocab_files( - cls: Type[SelfLegacyFromHFHub], + cls: Type[SelfLegacyFromHF], *, vocab_files: Mapping[str, RepositoryFile], tokenizer_config: Optional[Dict[str, Any]], - ) -> SelfLegacyFromHFHub: + ) -> SelfLegacyFromHF: """ Construct a tokenizer from its vocabulary files and optional configuration. @@ -140,7 +140,7 @@ def _load_from_vocab_files( @classmethod def from_hf_hub_to_cache( - cls: Type[SelfLegacyFromHFHub], + cls: Type[SelfLegacyFromHF], *, name: str, revision: str = "main", @@ -156,9 +156,9 @@ def from_hf_hub_to_cache( @classmethod def from_repo( - cls: Type[SelfLegacyFromHFHub], + cls: Type[SelfLegacyFromHF], repo: Repository, - ) -> SelfLegacyFromHFHub: + ) -> SelfLegacyFromHF: repo = TokenizerRepository(repo) vocab_files = {} for vocab_file, filename in cls.vocab_files.items(): diff --git a/curated_transformers/tokenizers/legacy/bert_tokenizer.py b/curated_transformers/tokenizers/legacy/bert_tokenizer.py index 98aeb4d1..bbca368c 100644 --- a/curated_transformers/tokenizers/legacy/bert_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/bert_tokenizer.py @@ -12,7 +12,7 @@ SpecialPieceChunk, TextChunk, ) -from ..hf_hub import LegacyFromHFHub +from ..hf_hub import LegacyFromHF from ..tokenizer import PiecesWithIds from ..util import remove_pieces_from_sequence from .legacy_tokenizer import ( @@ -192,7 +192,7 @@ def __call__(self, chunks: Iterable[InputChunks]) -> List[InputChunks]: return chunks -class BERTTokenizer(WordPieceTokenizer, LegacyFromHFHub): +class BERTTokenizer(WordPieceTokenizer, LegacyFromHF): """ Legacy tokenizer for BERT (`Devlin et al., 2018`_) models. diff --git a/curated_transformers/tokenizers/legacy/camembert_tokenizer.py b/curated_transformers/tokenizers/legacy/camembert_tokenizer.py index b54b0642..61ffe20b 100644 --- a/curated_transformers/tokenizers/legacy/camembert_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/camembert_tokenizer.py @@ -3,7 +3,7 @@ from curated_tokenizers import SentencePieceProcessor from ...repository.file import RepositoryFile -from ..hf_hub import LegacyFromHFHub +from ..hf_hub import LegacyFromHF from ._fairseq import FAIRSEQ_PIECE_IDS, FairSeqPostEncoder, FairSeqPreDecoder from .legacy_tokenizer import AddBosEosPreEncoder from .sentencepiece_tokenizer import SentencePieceTokenizer @@ -65,7 +65,7 @@ def _fairseq_to_sentencepiece(piece_id: int): return piece_id - _CAMEMBERT_FAIRSEQ_OFFSET -class CamemBERTTokenizer(SentencePieceTokenizer, LegacyFromHFHub): +class CamemBERTTokenizer(SentencePieceTokenizer, LegacyFromHF): """ Legacy tokenizer for CamemBERT (`Martin et al., 2020`_) models. 
diff --git a/curated_transformers/tokenizers/legacy/llama_tokenizer.py b/curated_transformers/tokenizers/legacy/llama_tokenizer.py index 7a6f2a33..ab837f14 100644 --- a/curated_transformers/tokenizers/legacy/llama_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/llama_tokenizer.py @@ -3,7 +3,7 @@ from curated_tokenizers import SentencePieceProcessor from ...repository.file import RepositoryFile -from ..hf_hub import LegacyFromHFHub +from ..hf_hub import LegacyFromHF from .legacy_tokenizer import AddBosEosPreEncoder from .sentencepiece_tokenizer import SentencePieceTokenizer @@ -13,7 +13,7 @@ DEFAULT_BOS_PIECE = "" -class LlamaTokenizer(SentencePieceTokenizer, LegacyFromHFHub): +class LlamaTokenizer(SentencePieceTokenizer, LegacyFromHF): """ Legacy tokenizer for Llama (`Touvron et al., 2023 [a]`_, `Touvron et al., 2023 [b]`_) models. diff --git a/curated_transformers/tokenizers/legacy/roberta_tokenizer.py b/curated_transformers/tokenizers/legacy/roberta_tokenizer.py index b296f7f1..fe2ebdec 100644 --- a/curated_transformers/tokenizers/legacy/roberta_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/roberta_tokenizer.py @@ -3,7 +3,7 @@ from curated_tokenizers import ByteBPEProcessor from ...repository.file import RepositoryFile -from ..hf_hub import LegacyFromHFHub +from ..hf_hub import LegacyFromHF from ..util import remove_pieces_from_sequence from .bbpe_tokenizer import ByteBPETokenizer from .legacy_tokenizer import AddBosEosPreEncoder, PreDecoder @@ -37,7 +37,7 @@ def __call__(self, input: Iterable[Iterable[int]]) -> List[List[int]]: ] -class RoBERTaTokenizer(ByteBPETokenizer, LegacyFromHFHub): +class RoBERTaTokenizer(ByteBPETokenizer, LegacyFromHF): """ Legacy tokenizer for RoBERTa (`Liu et al., 2019`_) models. diff --git a/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py b/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py index 86d4eca7..7716ef67 100644 --- a/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py +++ b/curated_transformers/tokenizers/legacy/xlmr_tokenizer.py @@ -3,7 +3,7 @@ from curated_tokenizers import SentencePieceProcessor from ...repository.file import RepositoryFile -from ..hf_hub import LegacyFromHFHub +from ..hf_hub import LegacyFromHF from ._fairseq import FAIRSEQ_PIECE_IDS, FairSeqPostEncoder, FairSeqPreDecoder from .legacy_tokenizer import AddBosEosPreEncoder from .sentencepiece_tokenizer import SentencePieceTokenizer @@ -73,7 +73,7 @@ def _fairseq_to_sentencepiece(piece_id: int): return piece_id - _XLMR_FAIRSEQ_OFFSET -class XLMRTokenizer(SentencePieceTokenizer, LegacyFromHFHub): +class XLMRTokenizer(SentencePieceTokenizer, LegacyFromHF): """ Legacy tokenizer for XLM-RoBERTa (`Conneau et al., 2019`_) models. diff --git a/curated_transformers/tokenizers/tokenizer.py b/curated_transformers/tokenizers/tokenizer.py index 88c3951d..0e6e3898 100644 --- a/curated_transformers/tokenizers/tokenizer.py +++ b/curated_transformers/tokenizers/tokenizer.py @@ -17,7 +17,7 @@ from ..repository.repository import Repository, TokenizerRepository from ._hf_compat import clean_up_decoded_string_like_hf from .chunks import InputChunks, MergedSpecialPieceChunk -from .hf_hub import FromHFHub +from .hf_hub import FromHF # Only provided as typing.Self in Python 3.11+. Self = TypeVar("Self", bound="Tokenizer") @@ -176,7 +176,7 @@ def eos_piece(self) -> Optional[str]: ... -class Tokenizer(TokenizerBase, FromHFHub): +class Tokenizer(TokenizerBase, FromHF): """ Wraps the tokenizers from the ``tokenizers`` package. 
It supports a wide range of piece tokenizers, including word piece, byte pair encoding, and diff --git a/docs/source/api-compat.rst b/docs/source/api-compat.rst index 65580524..237c1eeb 100644 --- a/docs/source/api-compat.rst +++ b/docs/source/api-compat.rst @@ -109,7 +109,7 @@ Version 1 to 2 * The factory methods of :py:class:`~curated_transformers.layers.AttentionHeads` add a new ``qkv_split`` argument which is mandatory in future versions. -* The ``FromHFHub`` mixins will be renamed to ``FromHF``. -* The ``convert_hf_state_dict`` method in ``FromHFHub`` will be removed +* The ``FromHFHub`` mixins are renamed to ``FromHF``. +* The ``convert_hf_state_dict`` method in ``FromHF`` is removed in favour of ``state_dict_from_hf``. -* The ``SelfAttention`` class will take an additional ``AttentionScorer`` argument. \ No newline at end of file +* The ``SelfAttention`` class takes an additional ``AttentionScorer`` argument. \ No newline at end of file diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 81bffb24..289782ac 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -23,11 +23,11 @@ tokenizers directly from Hugging Face Hub. .. _Hugging Face Hub client: https://huggingface.co/docs/huggingface_hub/quick-start#login -.. autoclass:: curated_transformers.models.FromHFHub +.. autoclass:: curated_transformers.models.FromHF :members: -.. autoclass:: curated_transformers.generation.FromHFHub +.. autoclass:: curated_transformers.generation.FromHF :members: -.. autoclass:: curated_transformers.tokenizers.FromHFHub +.. autoclass:: curated_transformers.tokenizers.FromHF :members:
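
For reference, below is a minimal usage sketch of the renamed API after applying this patch. It is illustrative only and not part of the diff: the repository name, head count, and state-dict variable are placeholders, and the calls assume the post-rename signatures shown in the hunks above (``FromHF`` instead of ``FromHFHub``, ``state_dict_from_hf`` instead of ``convert_hf_state_dict``, and a mandatory ``qkv_split`` argument on the ``AttentionHeads`` factories).

import torch

from curated_transformers.layers.attention import (
    AttentionHeads,
    QkvSplitGroupedByKVHeads,
)
from curated_transformers.models import FromHF, LlamaCausalLM

# The loader entry points are unchanged; only the mixin is renamed to FromHF.
model = LlamaCausalLM.from_hf_hub(
    name="example-org/example-llama",  # placeholder repository name
    revision="main",
    device=torch.device("cpu"),
)
assert isinstance(model, FromHF)

# convert_hf_state_dict() is removed; call state_dict_from_hf() directly to
# convert parameters in Hugging Face layout to curated-transformers layout.
hf_state_dict = {}  # placeholder: parameters loaded from an HF checkpoint
converted = LlamaCausalLM.state_dict_from_hf(hf_state_dict)

# qkv_split no longer has a default; pass a QkvSplit strategy explicitly.
heads = AttentionHeads.uniform(12, QkvSplitGroupedByKVHeads())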