Commit
Various type annotation fixes.
KerfuffleV2 committed Nov 8, 2023
1 parent 8047aa1 commit d7688dc
Showing 3 changed files with 178 additions and 175 deletions.
196 changes: 98 additions & 98 deletions gguf-py/gguf/constants.py
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import sys
-from enum import Enum, IntEnum, auto
-from typing import Any, NamedTuple
+from enum import Enum, IntEnum, auto, StrEnum
+from typing import Any, NamedTuple, Type
 
 #
 # constants
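A note on the import change above: `enum.StrEnum` was added in Python 3.11, so the new import line raises `ImportError` on 3.10 and earlier. A minimal compatibility sketch, purely illustrative and not part of this commit, would gate the import on the interpreter version:

    # Hypothetical shim, not part of this commit: enum.StrEnum only exists
    # on Python 3.11+. On older interpreters, a str/Enum mixin whose
    # __str__ returns the value behaves the same for these key constants.
    import sys
    from enum import Enum

    if sys.version_info >= (3, 11):
        from enum import StrEnum
    else:
        class StrEnum(str, Enum):
            def __str__(self) -> str:
                return self.value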
@@ -16,63 +16,63 @@
 # metadata keys
 #

-class GeneralKeys(NamedTuple):
-    ARCHITECTURE = "general.architecture"
-    QUANTIZATION_VERSION = "general.quantization_version"
-    ALIGNMENT = "general.alignment"
-    NAME = "general.name"
-    AUTHOR = "general.author"
-    URL = "general.url"
-    DESCRIPTION = "general.description"
-    LICENSE = "general.license"
-    SOURCE_URL = "general.source.url"
-    SOURCE_HF_REPO = "general.source.huggingface.repository"
-    FILE_TYPE = "general.file_type"
-
-class AttentionKeys(NamedTuple):
-    HEAD_COUNT = "{arch}.attention.head_count"
-    HEAD_COUNT_KV = "{arch}.attention.head_count_kv"
-    MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias"
-    CLAMP_KQV = "{arch}.attention.clamp_kqv"
-    LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
-    LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
-
-class RopeKeys(NamedTuple):
-    DIMENSION_COUNT = "{arch}.rope.dimension_count"
-    FREQ_BASE = "{arch}.rope.freq_base"
-    SCALING_TYPE = "{arch}.rope.scaling.type"
-    SCALING_FACTOR = "{arch}.rope.scaling.factor"
-    SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
-    SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
-
-class TokenizerKeys(NamedTuple):
-    MODEL = "tokenizer.ggml.model"
-    LIST = "tokenizer.ggml.tokens"
-    TOKEN_TYPE = "tokenizer.ggml.token_type"
-    SCORES = "tokenizer.ggml.scores"
-    MERGES = "tokenizer.ggml.merges"
-    BOS_ID = "tokenizer.ggml.bos_token_id"
-    EOS_ID = "tokenizer.ggml.eos_token_id"
-    UNK_ID = "tokenizer.ggml.unknown_token_id"
-    SEP_ID = "tokenizer.ggml.seperator_token_id"
-    PAD_ID = "tokenizer.ggml.padding_token_id"
-    HF_JSON = "tokenizer.huggingface.json"
-    RWKV = "tokenizer.rwkv.world"
-
-class LLMKeys(NamedTuple):
-    CONTEXT_LENGTH = "{arch}.context_length"
-    EMBEDDING_LENGTH = "{arch}.embedding_length"
-    BLOCK_COUNT = "{arch}.block_count"
-    FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
-    USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
-    TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
+class GeneralKeys(StrEnum):
+    ARCHITECTURE        : str = "general.architecture"
+    QUANTIZATION_VERSION: str = "general.quantization_version"
+    ALIGNMENT           : str = "general.alignment"
+    NAME                : str = "general.name"
+    AUTHOR              : str = "general.author"
+    URL                 : str = "general.url"
+    DESCRIPTION         : str = "general.description"
+    LICENSE             : str = "general.license"
+    SOURCE_URL          : str = "general.source.url"
+    SOURCE_HF_REPO      : str = "general.source.huggingface.repository"
+    FILE_TYPE           : str = "general.file_type"
+
+class AttentionKeys(StrEnum):
+    HEAD_COUNT       : str = "{arch}.attention.head_count"
+    HEAD_COUNT_KV    : str = "{arch}.attention.head_count_kv"
+    MAX_ALIBI_BIAS   : str = "{arch}.attention.max_alibi_bias"
+    CLAMP_KQV        : str = "{arch}.attention.clamp_kqv"
+    LAYERNORM_EPS    : str = "{arch}.attention.layer_norm_epsilon"
+    LAYERNORM_RMS_EPS: str = "{arch}.attention.layer_norm_rms_epsilon"
+
+class RopeKeys(StrEnum):
+    DIMENSION_COUNT     : str = "{arch}.rope.dimension_count"
+    FREQ_BASE           : str = "{arch}.rope.freq_base"
+    SCALING_TYPE        : str = "{arch}.rope.scaling.type"
+    SCALING_FACTOR      : str = "{arch}.rope.scaling.factor"
+    SCALING_ORIG_CTX_LEN: str = "{arch}.rope.scaling.original_context_length"
+    SCALING_FINETUNED   : str = "{arch}.rope.scaling.finetuned"
+
+class TokenizerKeys(StrEnum):
+    MODEL     : str = "tokenizer.ggml.model"
+    LIST      : str = "tokenizer.ggml.tokens"
+    TOKEN_TYPE: str = "tokenizer.ggml.token_type"
+    SCORES    : str = "tokenizer.ggml.scores"
+    MERGES    : str = "tokenizer.ggml.merges"
+    BOS_ID    : str = "tokenizer.ggml.bos_token_id"
+    EOS_ID    : str = "tokenizer.ggml.eos_token_id"
+    UNK_ID    : str = "tokenizer.ggml.unknown_token_id"
+    SEP_ID    : str = "tokenizer.ggml.seperator_token_id"
+    PAD_ID    : str = "tokenizer.ggml.padding_token_id"
+    HF_JSON   : str = "tokenizer.huggingface.json"
+    RWKV      : str = "tokenizer.rwkv.world"
+
+class LLMKeys(StrEnum):
+    CONTEXT_LENGTH       : str = "{arch}.context_length"
+    EMBEDDING_LENGTH     : str = "{arch}.embedding_length"
+    BLOCK_COUNT          : str = "{arch}.block_count"
+    FEED_FORWARD_LENGTH  : str = "{arch}.feed_forward_length"
+    USE_PARALLEL_RESIDUAL: str = "{arch}.use_parallel_residual"
+    TENSOR_DATA_LAYOUT   : str = "{arch}.tensor_data_layout"

 class Keys(NamedTuple):
-    GENERAL = GeneralKeys()
-    LLM = LLMKeys()
-    ATTENTION = AttentionKeys()
-    ROPE = RopeKeys()
-    TOKENIZER = TokenizerKeys()
+    GENERAL  : Type[GeneralKeys  ] = GeneralKeys
+    LLM      : Type[LLMKeys      ] = LLMKeys
+    ATTENTION: Type[AttentionKeys] = AttentionKeys
+    ROPE     : Type[RopeKeys     ] = RopeKeys
+    TOKENIZER: Type[TokenizerKeys] = TokenizerKeys

 KEY = Keys()
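After this hunk, each `Keys` field holds an enum class itself (hence the `Type[...]` annotations) instead of a `NamedTuple` instance, and every enum member is a `str` subclass, so the `{arch}` placeholders can still be filled with plain `str.format`. A small usage sketch, illustrative only (the `llama` architecture name is just an example, not something this commit adds):

    # Illustrative only: StrEnum members are str instances, so str.format
    # and equality against plain strings work on them directly.
    from gguf.constants import KEY

    head_count_key = KEY.ATTENTION.HEAD_COUNT.format(arch="llama")
    assert head_count_key == "llama.attention.head_count"

    # Enum classes also support member iteration:
    all_general_keys = [k.value for k in KEY.GENERAL]
    assert "general.name" in all_general_keys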

@@ -418,52 +418,52 @@ def get_type(val: Any) -> GGUFValueType:
 # Aliases for backward compatibility.
 
 # general
-KEY_GENERAL_ARCHITECTURE = KEY.GENERAL.ARCHITECTURE
-KEY_GENERAL_QUANTIZATION_VERSION = KEY.GENERAL.QUANTIZATION_VERSION
-KEY_GENERAL_ALIGNMENT = KEY.GENERAL.ALIGNMENT
-KEY_GENERAL_NAME = KEY.GENERAL.NAME
-KEY_GENERAL_AUTHOR = KEY.GENERAL.AUTHOR
-KEY_GENERAL_URL = KEY.GENERAL.URL
-KEY_GENERAL_DESCRIPTION = KEY.GENERAL.DESCRIPTION
-KEY_GENERAL_LICENSE = KEY.GENERAL.LICENSE
-KEY_GENERAL_SOURCE_URL = KEY.GENERAL.SOURCE_URL
-KEY_GENERAL_SOURCE_HF_REPO = KEY.GENERAL.SOURCE_HF_REPO
-KEY_GENERAL_FILE_TYPE = KEY.GENERAL.FILE_TYPE
+KEY_GENERAL_ARCHITECTURE        : str = KEY.GENERAL.ARCHITECTURE
+KEY_GENERAL_QUANTIZATION_VERSION: str = KEY.GENERAL.QUANTIZATION_VERSION
+KEY_GENERAL_ALIGNMENT           : str = KEY.GENERAL.ALIGNMENT
+KEY_GENERAL_NAME                : str = KEY.GENERAL.NAME
+KEY_GENERAL_AUTHOR              : str = KEY.GENERAL.AUTHOR
+KEY_GENERAL_URL                 : str = KEY.GENERAL.URL
+KEY_GENERAL_DESCRIPTION         : str = KEY.GENERAL.DESCRIPTION
+KEY_GENERAL_LICENSE             : str = KEY.GENERAL.LICENSE
+KEY_GENERAL_SOURCE_URL          : str = KEY.GENERAL.SOURCE_URL
+KEY_GENERAL_SOURCE_HF_REPO      : str = KEY.GENERAL.SOURCE_HF_REPO
+KEY_GENERAL_FILE_TYPE           : str = KEY.GENERAL.FILE_TYPE

 # LLM
-KEY_CONTEXT_LENGTH = KEY.LLM.CONTEXT_LENGTH
-KEY_EMBEDDING_LENGTH = KEY.LLM.EMBEDDING_LENGTH
-KEY_BLOCK_COUNT = KEY.LLM.BLOCK_COUNT
-KEY_FEED_FORWARD_LENGTH = KEY.LLM.FEED_FORWARD_LENGTH
-KEY_USE_PARALLEL_RESIDUAL = KEY.LLM.USE_PARALLEL_RESIDUAL
-KEY_TENSOR_DATA_LAYOUT = KEY.LLM.TENSOR_DATA_LAYOUT
+KEY_CONTEXT_LENGTH       : str = KEY.LLM.CONTEXT_LENGTH
+KEY_EMBEDDING_LENGTH     : str = KEY.LLM.EMBEDDING_LENGTH
+KEY_BLOCK_COUNT          : str = KEY.LLM.BLOCK_COUNT
+KEY_FEED_FORWARD_LENGTH  : str = KEY.LLM.FEED_FORWARD_LENGTH
+KEY_USE_PARALLEL_RESIDUAL: str = KEY.LLM.USE_PARALLEL_RESIDUAL
+KEY_TENSOR_DATA_LAYOUT   : str = KEY.LLM.TENSOR_DATA_LAYOUT

 # attention
-KEY_ATTENTION_HEAD_COUNT = KEY.ATTENTION.HEAD_COUNT
-KEY_ATTENTION_HEAD_COUNT_KV = KEY.ATTENTION.HEAD_COUNT_KV
-KEY_ATTENTION_MAX_ALIBI_BIAS = KEY.ATTENTION.MAX_ALIBI_BIAS
-KEY_ATTENTION_CLAMP_KQV = KEY.ATTENTION.CLAMP_KQV
-KEY_ATTENTION_LAYERNORM_EPS = KEY.ATTENTION.LAYERNORM_EPS
-KEY_ATTENTION_LAYERNORM_RMS_EPS = KEY.ATTENTION.LAYERNORM_RMS_EPS
+KEY_ATTENTION_HEAD_COUNT       : str = KEY.ATTENTION.HEAD_COUNT
+KEY_ATTENTION_HEAD_COUNT_KV    : str = KEY.ATTENTION.HEAD_COUNT_KV
+KEY_ATTENTION_MAX_ALIBI_BIAS   : str = KEY.ATTENTION.MAX_ALIBI_BIAS
+KEY_ATTENTION_CLAMP_KQV        : str = KEY.ATTENTION.CLAMP_KQV
+KEY_ATTENTION_LAYERNORM_EPS    : str = KEY.ATTENTION.LAYERNORM_EPS
+KEY_ATTENTION_LAYERNORM_RMS_EPS: str = KEY.ATTENTION.LAYERNORM_RMS_EPS

 # RoPE
-KEY_ROPE_DIMENSION_COUNT = KEY.ROPE.DIMENSION_COUNT
-KEY_ROPE_FREQ_BASE = KEY.ROPE.FREQ_BASE
-KEY_ROPE_SCALING_TYPE = KEY.ROPE.SCALING_TYPE
-KEY_ROPE_SCALING_FACTOR = KEY.ROPE.SCALING_FACTOR
-KEY_ROPE_SCALING_ORIG_CTX_LEN = KEY.ROPE.SCALING_ORIG_CTX_LEN
-KEY_ROPE_SCALING_FINETUNED = KEY.ROPE.SCALING_FINETUNED
+KEY_ROPE_DIMENSION_COUNT     : str = KEY.ROPE.DIMENSION_COUNT
+KEY_ROPE_FREQ_BASE           : str = KEY.ROPE.FREQ_BASE
+KEY_ROPE_SCALING_TYPE        : str = KEY.ROPE.SCALING_TYPE
+KEY_ROPE_SCALING_FACTOR      : str = KEY.ROPE.SCALING_FACTOR
+KEY_ROPE_SCALING_ORIG_CTX_LEN: str = KEY.ROPE.SCALING_ORIG_CTX_LEN
+KEY_ROPE_SCALING_FINETUNED   : str = KEY.ROPE.SCALING_FINETUNED

 # tokenization
-KEY_TOKENIZER_MODEL = KEY.TOKENIZER.MODEL
-KEY_TOKENIZER_LIST = KEY.TOKENIZER.LIST
-KEY_TOKENIZER_TOKEN_TYPE = KEY.TOKENIZER.TOKEN_TYPE
-KEY_TOKENIZER_SCORES = KEY.TOKENIZER.SCORES
-KEY_TOKENIZER_MERGES = KEY.TOKENIZER.MERGES
-KEY_TOKENIZER_BOS_ID = KEY.TOKENIZER.BOS_ID
-KEY_TOKENIZER_EOS_ID = KEY.TOKENIZER.EOS_ID
-KEY_TOKENIZER_UNK_ID = KEY.TOKENIZER.UNK_ID
-KEY_TOKENIZER_SEP_ID = KEY.TOKENIZER.SEP_ID
-KEY_TOKENIZER_PAD_ID = KEY.TOKENIZER.PAD_ID
-KEY_TOKENIZER_HF_JSON = KEY.TOKENIZER.HF_JSON
-KEY_TOKENIZER_RWKV = KEY.TOKENIZER.RWKV
+KEY_TOKENIZER_MODEL     : str = KEY.TOKENIZER.MODEL
+KEY_TOKENIZER_LIST      : str = KEY.TOKENIZER.LIST
+KEY_TOKENIZER_TOKEN_TYPE: str = KEY.TOKENIZER.TOKEN_TYPE
+KEY_TOKENIZER_SCORES    : str = KEY.TOKENIZER.SCORES
+KEY_TOKENIZER_MERGES    : str = KEY.TOKENIZER.MERGES
+KEY_TOKENIZER_BOS_ID    : str = KEY.TOKENIZER.BOS_ID
+KEY_TOKENIZER_EOS_ID    : str = KEY.TOKENIZER.EOS_ID
+KEY_TOKENIZER_UNK_ID    : str = KEY.TOKENIZER.UNK_ID
+KEY_TOKENIZER_SEP_ID    : str = KEY.TOKENIZER.SEP_ID
+KEY_TOKENIZER_PAD_ID    : str = KEY.TOKENIZER.PAD_ID
+KEY_TOKENIZER_HF_JSON   : str = KEY.TOKENIZER.HF_JSON
+KEY_TOKENIZER_RWKV      : str = KEY.TOKENIZER.RWKV
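Since every alias above now points at an enum member, and `StrEnum` members compare equal to their string values, callers importing the flat `KEY_*` names, using the nested `KEY.*` members, or comparing against raw strings all keep working. A quick sanity check, illustrative only and not part of the commit:

    # Illustrative: the legacy alias, the enum member, and the raw string
    # all compare equal, so either spelling of the key keeps working.
    from gguf.constants import KEY, KEY_GENERAL_ARCHITECTURE

    assert KEY_GENERAL_ARCHITECTURE == KEY.GENERAL.ARCHITECTURE
    assert KEY_GENERAL_ARCHITECTURE == "general.architecture"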