Finalize the API changes for 2.0 #374

Merged · 3 commits · Apr 16, 2024
Changes from all commits
4 changes: 2 additions & 2 deletions curated_transformers/generation/__init__.py
@@ -5,7 +5,7 @@
from .falcon import FalconGenerator
from .generator import Generator
from .generator_wrapper import GeneratorWrapper
from .hf_hub import FromHFHub
from .hf_hub import FromHF
from .llama import LlamaGenerator
from .logits import (
CompoundLogitsTransform,
@@ -32,7 +32,7 @@
"DollyV2Generator",
"EndOfSequenceCondition",
"FalconGenerator",
"FromHFHub",
"FromHF",
"Generator",
"GeneratorConfig",
"GeneratorWrapper",
6 changes: 3 additions & 3 deletions curated_transformers/generation/auto_generator.py
@@ -7,13 +7,13 @@
from .dolly_v2 import DollyV2Generator
from .falcon import FalconGenerator
from .generator_wrapper import GeneratorWrapper
from .hf_hub import FromHFHub
from .hf_hub import FromHF
from .llama import LlamaGenerator
from .mpt import MPTGenerator

# For the time being, we enable support for a generator on a case-by-case basis.
# In the future we might defer all unknown generators to DefaultGenerator.
GENERATOR_MAP: Dict[str, Type[FromHFHub]] = {
GENERATOR_MAP: Dict[str, Type[FromHF]] = {
"dolly-v2": DollyV2Generator,
"falcon": FalconGenerator,
"llama": LlamaGenerator,
@@ -70,7 +70,7 @@ def from_hf_hub(
return generator


def _resolve_generator_class(name: str) -> Type[FromHFHub]:
def _resolve_generator_class(name: str) -> Type[FromHF]:
for substring, generator_cls in GENERATOR_MAP.items():
if substring in name.lower():
return generator_cls
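Note: a minimal usage sketch of the resolution above, assuming AutoGenerator keeps its existing from_hf_hub classmethod and is still exported from curated_transformers.generation. The repository name is matched case-insensitively against the substrings in GENERATOR_MAP, so a Falcon checkpoint resolves to FalconGenerator; an unmatched name currently raises, since support is enabled case by case per the comment above.

import torch

from curated_transformers.generation import AutoGenerator

# "falcon" is a substring of the lowercased repo name, so
# _resolve_generator_class picks FalconGenerator.
generator = AutoGenerator.from_hf_hub(
    name="tiiuae/falcon-7b-instruct",
    device=torch.device("cuda", index=0),
)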
4 changes: 2 additions & 2 deletions curated_transformers/generation/default_generator.py
@@ -13,14 +13,14 @@
from .config import GeneratorConfig, SampleGeneratorConfig
from .generator import Generator
from .generator_wrapper import GeneratorWrapper
from .hf_hub import FromHFHub
from .hf_hub import FromHF
from .string_generator import StringGenerator

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="DefaultGenerator")


class DefaultGenerator(Generic[CacheT], GeneratorWrapper, FromHFHub):
class DefaultGenerator(Generic[CacheT], GeneratorWrapper, FromHF):
"""
Generator wrapper for models that do not need specific prompting.
"""
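Note: for downstream code the subclassing pattern stays the same apart from the mixin name; a hypothetical generator following the bundled ones in this PR (FalconGenerator, LlamaGenerator, MPTGenerator) would look like this sketch, with the constructor omitted.

from curated_transformers.generation.default_generator import DefaultGenerator
from curated_transformers.generation.hf_hub import FromHF


class MyGenerator(DefaultGenerator, FromHF):
    """Hypothetical generator for a model family that needs no special prompting."""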
4 changes: 2 additions & 2 deletions curated_transformers/generation/falcon.py
@@ -4,13 +4,13 @@
from ..tokenizers.chunks import InputChunks, TextChunk
from ..tokenizers.tokenizer import Tokenizer
from .default_generator import DefaultGenerator
from .hf_hub import FromHFHub
from .hf_hub import FromHF

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="FalconGenerator")


class FalconGenerator(DefaultGenerator, FromHFHub):
class FalconGenerator(DefaultGenerator, FromHF):
"""
Generator for Falcon model variants.
"""
4 changes: 2 additions & 2 deletions curated_transformers/generation/hf_hub.py
@@ -6,10 +6,10 @@
from ..quantization.bnb.config import BitsAndBytesConfig

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="FromHFHub")
Self = TypeVar("Self", bound="FromHF")


class FromHFHub(ABC):
class FromHF(ABC):
"""
Mixin class for downloading generators from Hugging Face Hub.

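Note: the rename is mechanical for callers; a minimal before/after sketch, assuming from_hf_hub itself keeps its current signature.

# 1.x:
#   from curated_transformers.generation import FromHFHub
#   def load_generator(name: str) -> FromHFHub: ...

# 2.0:
from curated_transformers.generation import FalconGenerator, FromHF


def load_generator(name: str) -> FromHF:
    # FalconGenerator implements the FromHF mixin, so it satisfies the annotation.
    return FalconGenerator.from_hf_hub(name=name)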
4 changes: 2 additions & 2 deletions curated_transformers/generation/llama.py
@@ -3,13 +3,13 @@
from ..models.llama import LlamaCausalLM
from ..tokenizers.tokenizer import Tokenizer
from .default_generator import DefaultGenerator
from .hf_hub import FromHFHub
from .hf_hub import FromHF

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="LlamaGenerator")


class LlamaGenerator(DefaultGenerator, FromHFHub):
class LlamaGenerator(DefaultGenerator, FromHF):
"""
Generator for Llama and Llama 2 model variants.
"""
4 changes: 2 additions & 2 deletions curated_transformers/generation/mpt.py
@@ -3,10 +3,10 @@
from ..models.mpt import MPTCausalLM
from ..tokenizers.tokenizer import Tokenizer
from .default_generator import DefaultGenerator
from .hf_hub import FromHFHub
from .hf_hub import FromHF


class MPTGenerator(DefaultGenerator, FromHFHub):
class MPTGenerator(DefaultGenerator, FromHF):
"""
Generator for MPT model variants.
"""
12 changes: 4 additions & 8 deletions curated_transformers/layers/attention.py
@@ -11,7 +11,6 @@
from torch import Tensor
from torch.nn import Dropout, Linear, Module

from ..semver import Default, FutureMandatory
from .cache import KeyValueCache
from .embeddings import QueryKeyRotaryEmbeddings

@@ -346,7 +345,7 @@ def __init__(
*,
n_query_heads: int,
n_key_value_heads: int,
qkv_split: FutureMandatory[QkvSplit] = Default,
qkv_split: QkvSplit,
):
"""
Construct an attention head configuration. This constructor must
@@ -366,16 +365,13 @@
"""
self._n_query_heads = n_query_heads
self._n_key_value_heads = n_key_value_heads

qkv_split = QkvSplitGroupedByKVHeads() if qkv_split is Default else qkv_split
assert isinstance(qkv_split, QkvSplit)
self._qkv_split = qkv_split

@classmethod
def uniform(
cls,
n_attention_heads: int,
qkv_split: FutureMandatory[QkvSplit] = Default,
qkv_split: QkvSplit,
) -> "AttentionHeads":
"""
Construct a head configuration where query, key, and value have the
@@ -398,7 +394,7 @@ def uniform(
def multi_query(
cls,
n_query_heads: int,
qkv_split: FutureMandatory[QkvSplit] = Default,
qkv_split: QkvSplit,
) -> "AttentionHeads":
"""
Construct a multi-query attention configuration: key has one head,
@@ -425,7 +421,7 @@
*,
n_query_heads: int,
n_key_value_heads: int,
qkv_split: FutureMandatory[QkvSplit] = Default,
qkv_split: QkvSplit,
) -> "AttentionHeads":
"""
Construct a head configuration where query has a larger number
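Note: with the FutureMandatory/Default markers gone, qkv_split is now a required argument on every AttentionHeads constructor. A sketch of the 2.0 calling convention, using QkvSplitGroupedByKVHeads as the updated call sites in this PR do:

from curated_transformers.layers.attention import (
    AttentionHeads,
    QkvSplitGroupedByKVHeads,
)

# 1.x accepted AttentionHeads.uniform(12); 2.0 requires an explicit split strategy.
uniform = AttentionHeads.uniform(12, QkvSplitGroupedByKVHeads())
multi_query = AttentionHeads.multi_query(12, QkvSplitGroupedByKVHeads())
grouped = AttentionHeads.key_value_broadcast(
    n_query_heads=12,
    n_key_value_heads=4,
    qkv_split=QkvSplitGroupedByKVHeads(),
)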
4 changes: 2 additions & 2 deletions curated_transformers/models/__init__.py
@@ -12,7 +12,7 @@
)
from .falcon import FalconCausalLM, FalconConfig, FalconDecoder
from .gpt_neox import GPTNeoXCausalLM, GPTNeoXConfig, GPTNeoXDecoder
from .hf_hub import FromHFHub
from .hf_hub import FromHF
from .llama import LlamaCausalLM, LlamaConfig, LlamaDecoder
from .module import CausalLMModule, DecoderModule, EncoderModule
from .mpt import MPTCausalLM, MPTConfig, MPTDecoder
@@ -37,7 +37,7 @@
"FalconCausalLM",
"FalconConfig",
"FalconDecoder",
"FromHFHub",
"FromHF",
"GPTNeoXCausalLM",
"GPTNeoXConfig",
"GPTNeoXDecoder",
4 changes: 2 additions & 2 deletions curated_transformers/models/albert/encoder.py
@@ -10,7 +10,7 @@
EmbeddingLayerNorms,
TransformerEmbeddings,
)
from ..hf_hub import FromHFHub
from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..module import EncoderModule
from ..output import ModelOutput
@@ -22,7 +22,7 @@
Self = TypeVar("Self", bound="ALBERTEncoder")


class ALBERTEncoder(EncoderModule[ALBERTConfig], FromHFHub[ALBERTConfig]):
class ALBERTEncoder(EncoderModule[ALBERTConfig], FromHF[ALBERTConfig]):
"""
ALBERT (`Lan et al., 2022`_) encoder.

4 changes: 3 additions & 1 deletion curated_transformers/models/albert/layer_group.py
@@ -9,6 +9,7 @@
AttentionHeads,
AttentionMask,
QkvMode,
QkvSplitGroupedByKVHeads,
ScaledDotProductAttention,
SelfAttention,
)
@@ -45,7 +46,8 @@ def __init__(
EncoderLayer(
attention_layer=SelfAttention(
attention_heads=AttentionHeads.uniform(
attention_config.n_query_heads
attention_config.n_query_heads,
QkvSplitGroupedByKVHeads(),
),
attention_scorer=ScaledDotProductAttention(
dropout_prob=attention_config.dropout_prob,
10 changes: 5 additions & 5 deletions curated_transformers/models/auto_model.py
@@ -14,7 +14,7 @@
from ..repository.hf_hub import HfHubRepository
from ..repository.repository import ModelRepository, Repository
from .config import TransformerConfig
from .hf_hub import FromHFHub
from .hf_hub import FromHF
from .module import CausalLMModule, DecoderModule, EncoderModule, TransformerModule

ModelT = TypeVar("ModelT")
@@ -33,14 +33,14 @@ class AutoModel(ABC, Generic[ModelT]):
def _resolve_model_cls(
cls,
repo: ModelRepository,
) -> Type[FromHFHub]:
) -> Type[FromHF]:
config = repo.model_config()

for entrypoint, module_cls in cls._registry.get_entry_points().items():
if not issubclass(module_cls, FromHFHub):
if not issubclass(module_cls, FromHF):
warnings.warn(
f"Entry point `{entrypoint}` cannot load from Hugging Face Hub "
"since the FromHFHub mixin is not implemented"
"since the FromHF mixin is not implemented"
)
continue

@@ -70,7 +70,7 @@ def _instantiate_model(
repo: Repository,
device: Optional[torch.device],
quantization_config: Optional[BitsAndBytesConfig],
) -> FromHFHub:
) -> FromHF:
module_cls = cls._resolve_model_cls(ModelRepository(repo))
module = module_cls.from_repo(
repo=repo,
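Note: the issubclass check above is what gates entry points after the rename; a class that does not implement the FromHF mixin is skipped with a warning rather than loaded from the Hub. A tiny sketch with a hypothetical class name:

from curated_transformers.models import FromHF


class NotLoadable:
    """Hypothetical entry-point class that does not implement FromHF."""


# AutoModel._resolve_model_cls warns and skips classes that fail this check.
assert not issubclass(NotLoadable, FromHF)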
8 changes: 5 additions & 3 deletions curated_transformers/models/bert/encoder.py
@@ -8,6 +8,7 @@
from ...layers.attention import (
AttentionHeads,
QkvMode,
QkvSplitGroupedByKVHeads,
ScaledDotProductAttention,
SelfAttention,
)
@@ -20,7 +21,7 @@
TransformerEmbeddings,
TransformerLayerNorms,
)
from ..hf_hub import FromHFHub
from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerEncoder
from ._hf import HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf
@@ -30,7 +31,7 @@
Self = TypeVar("Self", bound="BERTEncoder")


class BERTEncoder(TransformerEncoder[BERTConfig], FromHFHub[BERTConfig]):
class BERTEncoder(TransformerEncoder[BERTConfig], FromHF[BERTConfig]):
"""
BERT (`Devlin et al., 2018`_) encoder.

@@ -85,7 +86,8 @@ def __init__(
EncoderLayer(
attention_layer=SelfAttention(
attention_heads=AttentionHeads.uniform(
config.layer.attention.n_query_heads
config.layer.attention.n_query_heads,
QkvSplitGroupedByKVHeads(),
),
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
4 changes: 2 additions & 2 deletions curated_transformers/models/falcon/causal_lm.py
@@ -5,7 +5,7 @@
from torch.nn import Linear

from ...quantization.quantizable import Quantizable
from ..hf_hub import FromHFHub
from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerCausalLM
from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf
@@ -17,7 +17,7 @@


class FalconCausalLM(
TransformerCausalLM[FalconConfig], FromHFHub[FalconConfig], Quantizable
TransformerCausalLM[FalconConfig], FromHF[FalconConfig], Quantizable
):
"""
Falcon (`Penedo et al., 2019`_) causal language model.
6 changes: 4 additions & 2 deletions curated_transformers/models/falcon/decoder.py
@@ -9,6 +9,7 @@
AttentionHeads,
AttentionLinearBiases,
QkvMode,
QkvSplitGroupedByKVHeads,
ScaledDotProductAttention,
SelfAttention,
)
@@ -22,7 +23,7 @@
TransformerEmbeddings,
TransformerLayerNorms,
)
from ..hf_hub import FromHFHub
from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerDecoder
from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf
@@ -33,7 +34,7 @@
Self = TypeVar("Self", bound="FalconDecoder")


class FalconDecoder(TransformerDecoder[FalconConfig], FromHFHub[FalconConfig]):
class FalconDecoder(TransformerDecoder[FalconConfig], FromHF[FalconConfig]):
"""
Falcon (`Penedo et al., 2019`_) decoder.

@@ -166,6 +167,7 @@ def _create_new_decoder_architecture_layer(
attention_heads=AttentionHeads.key_value_broadcast(
n_query_heads=n_attention_heads,
n_key_value_heads=config.layer.attention.n_key_value_heads,
qkv_split=QkvSplitGroupedByKVHeads(),
),
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
2 changes: 2 additions & 0 deletions curated_transformers/models/falcon/layer.py
@@ -10,6 +10,7 @@
AttentionMask,
KeyValueCache,
QkvMode,
QkvSplitGroupedByKVHeads,
ScaledDotProductAttention,
SelfAttention,
)
@@ -71,6 +72,7 @@ def __init__(
attention_heads=AttentionHeads.key_value_broadcast(
n_query_heads=attention_config.n_query_heads,
n_key_value_heads=attention_config.n_key_value_heads,
qkv_split=QkvSplitGroupedByKVHeads(),
),
rotary_embeds=rotary_embeds,
qkv_mode=(
4 changes: 2 additions & 2 deletions curated_transformers/models/gpt_neox/causal_lm.py
@@ -5,7 +5,7 @@
from torch.nn import Linear

from ...quantization import Quantizable
from ..hf_hub import FromHFHub
from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerCausalLM
from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf
@@ -17,7 +17,7 @@


class GPTNeoXCausalLM(
TransformerCausalLM[GPTNeoXConfig], FromHFHub[GPTNeoXConfig], Quantizable
TransformerCausalLM[GPTNeoXConfig], FromHF[GPTNeoXConfig], Quantizable
):
"""
GPT-NeoX (`Black et al., 2022`_) causal language model.