diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 8381a8c..253e168 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -32,6 +32,7 @@ from ruamel.yaml import YAML +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common.health import HealthManager from backends.exllamav2.grammar import ( @@ -43,6 +44,7 @@ hardware_supports_flash_attn, supports_paged_attn, ) +from common.tabby_config import config from common.concurrency import iterate_in_threadpool from common.gen_logging import ( log_generation_params, @@ -103,7 +105,12 @@ class ExllamaV2Container: load_condition: asyncio.Condition = asyncio.Condition() @classmethod - async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): + async def create( + cls, + model: ModelInstanceConfig, + draft: DraftModelInstanceConfig, + quiet=False, + ): """ Primary asynchronous initializer for model container. @@ -117,8 +124,15 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Initialize config self.config = ExLlamaV2Config() - self.model_dir = model_directory - self.config.model_dir = str(model_directory.resolve()) + + model_path = pathlib.Path(config.model.model_dir) + model_path = model_path / model.model_name + model_path = model_path.resolve() + if not model_path.exists(): + raise FileNotFoundError(f"Model path {model_path} does not exist.") + + self.model_dir = model_path + self.config.model_dir = str(model_path) # Make the max seq len 4096 before preparing the config # This is a better default than 2048 @@ -130,35 +144,23 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): self.config.arch_compat_overrides() # Prepare the draft model config if necessary - draft_args = unwrap(kwargs.get("draft"), {}) - draft_model_name = draft_args.get("draft_model_name") - enable_draft = draft_args and draft_model_name - - # Always disable draft if params are incorrectly configured - if draft_args and draft_model_name is None: - logger.warning( - "Draft model is disabled because a model name " - "wasn't provided. Please check your config.yml!" - ) - enable_draft = False - - if enable_draft: + if draft.draft_model_name: self.draft_config = ExLlamaV2Config() self.draft_config.no_flash_attn = self.config.no_flash_attn - draft_model_path = pathlib.Path( - unwrap(draft_args.get("draft_model_dir"), "models") + + draft_model_path = ( + config.draft_model.draft_model_dir / draft.draft_model_name ) - draft_model_path = draft_model_path / draft_model_name self.draft_model_dir = draft_model_path self.draft_config.model_dir = str(draft_model_path.resolve()) self.draft_config.prepare() # Create the hf_config - self.hf_config = await HuggingFaceConfig.from_file(model_directory) + self.hf_config = await HuggingFaceConfig.from_file(model_path) # Load generation config overrides - generation_config_path = model_directory / "generation_config.json" + generation_config_path = model_path / "generation_config.json" if generation_config_path.exists(): try: self.generation_config = await GenerationConfig.from_file( @@ -171,18 +173,20 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): ) # Apply a model's config overrides while respecting user settings - kwargs = await self.set_model_overrides(**kwargs) + + # FIXME: THIS IS BROKEN!!! 
+ # kwargs do not exist now + # should be investigated after the models have pydantic stuff + # kwargs = await self.set_model_overrides(**kwargs) # MARK: User configuration # Get cache mode - self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16") + self.cache_mode = model.cache_mode # Turn off GPU split if the user is using 1 GPU gpu_count = torch.cuda.device_count() - gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True) - use_tp = unwrap(kwargs.get("tensor_parallel"), False) - gpu_split = kwargs.get("gpu_split") + gpu_split_auto = model.gpu_split_auto gpu_device_list = list(range(0, gpu_count)) # Set GPU split options @@ -191,16 +195,16 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): logger.info("Disabling GPU split because one GPU is in use.") else: # Set tensor parallel - if use_tp: + if model.tensor_parallel: self.use_tp = True # TP has its own autosplit loader self.gpu_split_auto = False # Enable manual GPU split if provided - if gpu_split: + if model.gpu_split: self.gpu_split_auto = False - self.gpu_split = gpu_split + self.gpu_split = model.gpu_split gpu_device_list = [ device_idx @@ -211,9 +215,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Otherwise fallback to autosplit settings self.gpu_split_auto = gpu_split_auto - autosplit_reserve_megabytes = unwrap( - kwargs.get("autosplit_reserve"), [96] - ) + autosplit_reserve_megabytes = model.autosplit_reserve # Reserve VRAM for each GPU self.autosplit_reserve = [ @@ -225,37 +227,34 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): self.config.max_output_len = 16 # Then override the base_seq_len if present - override_base_seq_len = kwargs.get("override_base_seq_len") - if override_base_seq_len: - self.config.max_seq_len = override_base_seq_len + if model.override_base_seq_len: + self.config.max_seq_len = model.override_base_seq_len # Grab the base model's sequence length before overrides for # rope calculations base_seq_len = self.config.max_seq_len # Set the target seq len if present - target_max_seq_len = kwargs.get("max_seq_len") + target_max_seq_len = model.max_seq_len if target_max_seq_len: self.config.max_seq_len = target_max_seq_len # Set the rope scale - self.config.scale_pos_emb = unwrap( - kwargs.get("rope_scale"), self.config.scale_pos_emb - ) + self.config.scale_pos_emb = unwrap(model.rope_scale, self.config.scale_pos_emb) # Sets rope alpha value. # Automatically calculate if unset or defined as an "auto" literal. 
- rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto") + rope_alpha = unwrap(model.rope_alpha, "auto") if rope_alpha == "auto": self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len) else: self.config.scale_alpha_value = rope_alpha # Enable fasttensors loading if present - self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False) + self.config.fasttensors = config.model.fasttensors # Set max batch size to the config override - self.max_batch_size = unwrap(kwargs.get("max_batch_size")) + self.max_batch_size = model.max_batch_size # Check whether the user's configuration supports flash/paged attention # Also check if exl2 has disabled flash attention @@ -272,7 +271,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Set k/v cache size # cache_size is only relevant when paged mode is enabled if self.paged: - cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len) + cache_size = unwrap(model.cache_size, self.config.max_seq_len) if cache_size < self.config.max_seq_len: logger.warning( @@ -314,7 +313,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): # Try to set prompt template self.prompt_template = await self.find_prompt_template( - kwargs.get("prompt_template"), model_directory + model.prompt_template, model.model_name ) # Catch all for template lookup errors @@ -329,29 +328,26 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): ) # Set num of experts per token if provided - num_experts_override = kwargs.get("num_experts_per_token") - if num_experts_override: - self.config.num_experts_per_token = kwargs.get("num_experts_per_token") + if model.num_experts_per_token: + self.config.num_experts_per_token = model.num_experts_per_token # Make sure chunk size is >= 16 and <= max seq length - user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048) + user_chunk_size = unwrap(model.chunk_size, 2048) chunk_size = sorted((16, user_chunk_size, self.config.max_seq_len))[1] self.config.max_input_len = chunk_size self.config.max_attention_size = chunk_size**2 # Set user-configured draft model values - if enable_draft: - # Fetch from the updated kwargs - draft_args = unwrap(kwargs.get("draft"), {}) + if draft.draft_model_name: self.draft_config.max_seq_len = self.config.max_seq_len self.draft_config.scale_pos_emb = unwrap( - draft_args.get("draft_rope_scale"), 1.0 + draft.draft_rope_scale, 1.0 ) # Set draft rope alpha. Follows same behavior as model rope alpha. 
- draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto") + draft_rope_alpha = unwrap(draft.draft_rope_alpha, "auto") if draft_rope_alpha == "auto": self.draft_config.scale_alpha_value = self.calculate_rope_alpha( self.draft_config.max_seq_len @@ -360,7 +356,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): self.draft_config.scale_alpha_value = draft_rope_alpha # Set draft cache mode - self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16") + self.draft_cache_mode = draft.draft_cache_mode if chunk_size: self.draft_config.max_input_len = chunk_size @@ -524,7 +520,7 @@ def progress(loaded_modules: int, total_modules: int) async for _ in self.load_gen(progress_callback): pass - async def load_gen(self, progress_callback=None, **kwargs): + async def load_gen(self, progress_callback=None, skip_wait=False): """Loads a model and streams progress via a generator.""" # Indicate that model load has started @@ -534,7 +530,7 @@ async def load_gen(self, progress_callback=None, **kwargs): self.model_is_loading = True # Wait for existing generation jobs to finish - await self.wait_for_jobs(kwargs.get("skip_wait")) + await self.wait_for_jobs(skip_wait) # Streaming gen for model load progress model_load_generator = self.load_model_sync(progress_callback) @@ -1130,19 +1126,19 @@ async def generate_gen( grammar_handler = ExLlamaV2Grammar() # Add JSON schema filter if it exists - json_schema = unwrap(kwargs.get("json_schema")) + json_schema = kwargs.get("json_schema") if json_schema: grammar_handler.add_json_schema_filter( json_schema, self.model, self.tokenizer ) # Add regex filter if it exists - regex_pattern = unwrap(kwargs.get("regex_pattern")) + regex_pattern = kwargs.get("regex_pattern") if regex_pattern: grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer) # Add EBNF filter if it exists - grammar_string = unwrap(kwargs.get("grammar_string")) + grammar_string = kwargs.get("grammar_string") if grammar_string: grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer) diff --git a/backends/exllamav2/types.py b/backends/exllamav2/types.py new file mode 100644 index 0000000..10a1ee0 --- /dev/null +++ b/backends/exllamav2/types.py @@ -0,0 +1,187 @@ +from typing import List, Literal, Optional, Union +from pydantic import BaseModel, ConfigDict, Field + +CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] + + +class DraftModelInstanceConfig(BaseModel): + draft_model_name: Optional[str] = Field( + None, + description=( + "An initial draft model to load.\n" + "Ensure the model is in the model directory." + ), + ) + draft_rope_scale: float = Field( + 1.0, + description=( + "Rope scale for draft models (default: 1.0).\n" + "Same as compress_pos_emb.\n" + "Use if the draft model was trained on long context with rope." + ), + ) + draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( + None, + description=( + "Rope alpha for draft models (default: None).\n" + 'Same as alpha_value. Set to "auto" to auto-calculate.\n' + "Leaving this value blank will either pull from the model " + "or auto-calculate." + ), + examples=[1.0], + ) + draft_cache_mode: CACHE_SIZES = Field( + "FP16", + description=( + "Cache mode for draft models to save VRAM (default: FP16).\n" + f"Possible values: {str(CACHE_SIZES)[15:-1]}." 
+ ), + ) + + model_config = ConfigDict(revalidate_instances="always") + + +class ModelInstanceConfig(BaseModel): + """ + Options for model overrides and loading + Please read the comments to understand how arguments are handled + between initial and API loads + """ + + model_name: Optional[str] = Field( + None, + description=( + "An initial model to load.\n" + "Make sure the model is located in the model directory!\n" + "REQUIRED: This must be filled out to load a model on startup." + ), + ) + max_seq_len: Optional[int] = Field( + None, + description=( + "Max sequence length (default: Empty).\n" + "Fetched from the model's base sequence length in config.json by default." + ), + ge=0, + examples=[16384, 4096, 2048], + ) + override_base_seq_len: Optional[int] = Field( + None, + description=( + "Overrides base model context length (default: Empty).\n" + "WARNING: Don't set this unless you know what you're doing!\n" + "Again, do NOT use this for configuring context length, " + "use max_seq_len above ^" + ), + ge=0, + examples=[4096], + ) + tensor_parallel: bool = Field( + False, + description=( + "Load model with tensor parallelism.\n" + "Falls back to autosplit if GPU split isn't provided.\n" + "This ignores the gpu_split_auto value." + ), + ) + gpu_split_auto: bool = Field( + True, + description=( + "Automatically allocate resources to GPUs (default: True).\n" + "Not parsed for single GPU users." + ), + ) + autosplit_reserve: List[int] = Field( + [96], + description=( + "Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n" + "Represented as an array of MB per GPU." + ), + ) + gpu_split: List[float] = Field( + default_factory=list, + description=( + "An integer array of GBs of VRAM to split between GPUs (default: []).\n" + "Used with tensor parallelism." + ), + ) + rope_scale: float = Field( + 1.0, + description=( + "Rope scale (default: 1.0).\n" + "Same as compress_pos_emb.\n" + "Use if the model was trained on long context with rope.\n" + "Leave blank to pull the value from the model." + ), + examples=[1.0], + ) + rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( + "auto", + description=( + "Rope alpha (default: None).\n" + 'Same as alpha_value. Set to "auto" to auto-calculate.\n' + "Leaving this value blank will either pull from the model " + "or auto-calculate." + ), + examples=["auto", 1.0], + ) + cache_mode: CACHE_SIZES = Field( + "FP16", + description=( + "Enable different cache modes for VRAM savings (default: FP16).\n" + f"Possible values: {str(CACHE_SIZES)[15:-1]}." + ), + ) + cache_size: Optional[int] = Field( + None, + description=( + "Size of the prompt cache to allocate (default: max_seq_len).\n" + "Must be a multiple of 256 and can't be less than max_seq_len.\n" + "For CFG, set this to 2 * max_seq_len." + ), + multiple_of=256, + gt=0, + examples=[4096], + ) + chunk_size: int = Field( + 2048, + description=( + "Chunk size for prompt ingestion (default: 2048).\n" + "A lower value reduces VRAM usage but decreases ingestion speed.\n" + "NOTE: Effects vary depending on the model.\n" + "An ideal value is between 512 and 4096." + ), + gt=0, + ) + max_batch_size: Optional[int] = Field( + None, + description=( + "Set the maximum number of prompts to process at one time " + "(default: None/Automatic).\n" + "Automatically calculated if left blank.\n" + "NOTE: Only available for Nvidia ampere (30 series) and above GPUs." + ), + ge=1, + ) + prompt_template: Optional[str] = Field( + None, + description=( + "Set the prompt template for this model. 
(default: None)\n" + "If empty, attempts to look for the model's chat template.\n" + "If a model contains multiple templates in its tokenizer_config.json,\n" + "set prompt_template to the name of the template you want to use.\n" + "NOTE: Only works with chat completion message lists!" + ), + ) + num_experts_per_token: Optional[int] = Field( + None, + description=( + "Number of experts to use per token.\n" + "Fetched from the model's config.json if empty.\n" + "NOTE: For MoE models only.\n" + "WARNING: Don't set this unless you know what you're doing!" + ), + ge=1, + ) + + model_config = ConfigDict(protected_namespaces=(), revalidate_instances="always") \ No newline at end of file diff --git a/common/auth.py b/common/auth.py index b02cdd0..3a1f26e 100644 --- a/common/auth.py +++ b/common/auth.py @@ -3,16 +3,18 @@ application, it should be fine. """ +from functools import partial import aiofiles import io import secrets from ruamel.yaml import YAML from fastapi import Header, HTTPException, Request -from pydantic import BaseModel +from pydantic import BaseModel, Field, SecretStr from loguru import logger from typing import Optional from common.utils import coalesce +from common.tabby_config import config class AuthKeys(BaseModel): @@ -24,32 +26,33 @@ class AuthKeys(BaseModel): to verify if a given key matches the stored 'api_key' or 'admin_key'. """ - api_key: str - admin_key: str + api_key: SecretStr = Field(default_factory=partial(secrets.token_hex, 16)) + admin_key: SecretStr = Field(default_factory=partial(secrets.token_hex, 16)) def verify_key(self, test_key: str, key_type: str): """Verify if a given key matches the stored key.""" + if key_type == "admin_key": - return test_key == self.admin_key + return test_key == self.admin_key.get_secret_value() if key_type == "api_key": # Admin keys are valid for all API calls - return test_key == self.api_key or test_key == self.admin_key + return ( + test_key == self.api_key.get_secret_value() + or test_key == self.admin_key.get_secret_value() + ) return False # Global auth constants AUTH_KEYS: Optional[AuthKeys] = None -DISABLE_AUTH: bool = False -async def load_auth_keys(disable_from_config: bool): +async def load_auth_keys(): """Load the authentication keys from api_tokens.yml. If the file does not exist, generate new keys and save them to api_tokens.yml.""" global AUTH_KEYS - global DISABLE_AUTH - DISABLE_AUTH = disable_from_config - if disable_from_config: + if config.network.disable_auth: logger.warning( "Disabling authentication makes your instance vulnerable. " "Set the `disable_auth` flag to False in config.yml if you " @@ -67,9 +70,7 @@ async def load_auth_keys(disable_from_config: bool): auth_keys_dict = yaml.load(contents) AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) except FileNotFoundError: - new_auth_keys = AuthKeys( - api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16) - ) + new_auth_keys = AuthKeys() AUTH_KEYS = new_auth_keys async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file: @@ -79,8 +80,8 @@ async def load_auth_keys(disable_from_config: bool): await auth_file.write(string_stream.getvalue()) logger.info( - f"Your API key is: {AUTH_KEYS.api_key}\n" - f"Your admin key is: {AUTH_KEYS.admin_key}\n\n" + f"Your API key is: {AUTH_KEYS.api_key.get_secret_value()}\n" + f"Your admin key is: {AUTH_KEYS.admin_key.get_secret_value()}\n\n" "If these keys get compromised, make sure to delete api_tokens.yml " "and restart the server. Have fun!" 
) @@ -94,7 +95,7 @@ def get_key_permission(request: Request): """ # Give full admin permissions if auth is disabled - if DISABLE_AUTH: + if config.network.disable_auth: return "admin" # Hyphens are okay here @@ -124,7 +125,7 @@ async def check_api_key( """Check if the API key is valid.""" # Allow request if auth is disabled - if DISABLE_AUTH: + if config.network.disable_auth: return if x_api_key: @@ -152,7 +153,7 @@ async def check_admin_key( """Check if the admin key is valid.""" # Allow request if auth is disabled - if DISABLE_AUTH: + if config.network.disable_auth: return if x_admin_key: diff --git a/common/config_models.py b/common/config_models.py index 79d774f..13d6de3 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -1,21 +1,24 @@ -from pathlib import Path from pydantic import ( BaseModel, ConfigDict, + DirectoryPath, Field, + FilePath, PrivateAttr, field_validator, ) -from typing import List, Literal, Optional, Union - +from typing import List, Literal, Optional +from pathlib import Path -CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig class Metadata(BaseModel): """metadata model for config options""" - include_in_config: Optional[bool] = Field(True) + include_in_config: bool = Field( + True, description="if the model is included by the config file generator" + ) class BaseConfigModel(BaseModel): @@ -27,8 +30,7 @@ class BaseConfigModel(BaseModel): class ConfigOverrideConfig(BaseConfigModel): """Model for overriding a provided config file.""" - # TODO: convert this to a pathlib.path? - config: Optional[str] = Field( + config: Optional[FilePath] = Field( None, description=("Path to an overriding config.yml file") ) @@ -39,18 +41,14 @@ class UtilityActions(BaseConfigModel): """Model used for arg actions.""" # YAML export options - export_config: Optional[str] = Field( - None, description="generate a template config file" - ) - config_export_path: Optional[Path] = Field( + export_config: bool = Field(False, description="generate a template config file") + config_export_path: Path = Field( "config_sample.yml", description="path to export configuration file to" ) # OpenAPI JSON export options - export_openapi: Optional[bool] = Field( - False, description="export openapi schema files" - ) - openapi_export_path: Optional[Path] = Field( + export_openapi: bool = Field(False, description="export openapi schema files") + openapi_export_path: Path = Field( "openapi.json", description="path to export openapi schema to" ) @@ -60,17 +58,16 @@ class UtilityActions(BaseConfigModel): class NetworkConfig(BaseConfigModel): """Options for networking""" - host: Optional[str] = Field( + # TODO: convert to IPvAnyAddress? + host: str = Field( "127.0.0.1", description=( "The IP to host on (default: 127.0.0.1).\n" "Use 0.0.0.0 to expose on all network adapters." ), ) - port: Optional[int] = Field( - 5000, description=("The port to host on (default: 5000).") - ) - disable_auth: Optional[bool] = Field( + port: int = Field(5000, description=("The port to host on (default: 5000).")) + disable_auth: bool = Field( False, description=( "Disable HTTP token authentication with requests.\n" @@ -78,14 +75,14 @@ class NetworkConfig(BaseConfigModel): "Turn on this option if you are ONLY connecting from localhost." ), ) - send_tracebacks: Optional[bool] = Field( + send_tracebacks: bool = Field( False, description=( "Send tracebacks over the API (default: False).\n" "NOTE: Only enable this for debug purposes." 
), ) - api_servers: Optional[List[Literal["oai", "kobold"]]] = Field( + api_servers: List[Literal["oai", "kobold"]] = Field( ["OAI"], description=( 'Select API servers to enable (default: ["OAI"]).\n' @@ -105,15 +102,15 @@ def api_server_validator(cls, api_servers): class LoggingConfig(BaseConfigModel): """Options for logging""" - log_prompt: Optional[bool] = Field( + log_prompt: bool = Field( False, description=("Enable prompt logging (default: False)."), ) - log_generation_params: Optional[bool] = Field( + log_generation_params: bool = Field( False, description=("Enable generation parameter logging (default: False)."), ) - log_requests: Optional[bool] = Field( + log_requests: bool = Field( False, description=( "Enable request logging (default: False).\n" @@ -122,43 +119,34 @@ class LoggingConfig(BaseConfigModel): ) -class ModelConfig(BaseConfigModel): +class ModelConfig(BaseConfigModel, ModelInstanceConfig): """ Options for model overrides and loading Please read the comments to understand how arguments are handled between initial and API loads """ - # TODO: convert this to a pathlib.path? - model_dir: str = Field( + model_dir: DirectoryPath = Field( "models", description=( "Directory to look for models (default: models).\n" "Windows users, do NOT put this path in quotes!" ), ) - inline_model_loading: Optional[bool] = Field( + inline_model_loading: bool = Field( False, description=( "Allow direct loading of models " "from a completion or chat completion request (default: False)." ), ) - use_dummy_models: Optional[bool] = Field( + use_dummy_models: bool = Field( False, description=( "Sends dummy model names when the models endpoint is queried.\n" "Enable this if the client is looking for specific OAI models." ), ) - model_name: Optional[str] = Field( - None, - description=( - "An initial model to load.\n" - "Make sure the model is located in the model directory!\n" - "REQUIRED: This must be filled out to load a model on startup." - ), - ) use_as_default: List[str] = Field( default_factory=list, description=( @@ -168,129 +156,7 @@ class ModelConfig(BaseConfigModel): "Example: ['max_seq_len', 'cache_mode']." ), ) - max_seq_len: Optional[int] = Field( - None, - description=( - "Max sequence length (default: Empty).\n" - "Fetched from the model's base sequence length in config.json by default." - ), - ge=0, - ) - override_base_seq_len: Optional[int] = Field( - None, - description=( - "Overrides base model context length (default: Empty).\n" - "WARNING: Don't set this unless you know what you're doing!\n" - "Again, do NOT use this for configuring context length, " - "use max_seq_len above ^" - ), - ge=0, - ) - tensor_parallel: Optional[bool] = Field( - False, - description=( - "Load model with tensor parallelism.\n" - "Falls back to autosplit if GPU split isn't provided.\n" - "This ignores the gpu_split_auto value." - ), - ) - gpu_split_auto: Optional[bool] = Field( - True, - description=( - "Automatically allocate resources to GPUs (default: True).\n" - "Not parsed for single GPU users." - ), - ) - autosplit_reserve: List[int] = Field( - [96], - description=( - "Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n" - "Represented as an array of MB per GPU." - ), - ) - gpu_split: List[float] = Field( - default_factory=list, - description=( - "An integer array of GBs of VRAM to split between GPUs (default: []).\n" - "Used with tensor parallelism." 
- ), - ) - rope_scale: Optional[float] = Field( - 1.0, - description=( - "Rope scale (default: 1.0).\n" - "Same as compress_pos_emb.\n" - "Use if the model was trained on long context with rope.\n" - "Leave blank to pull the value from the model." - ), - ) - rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( - None, - description=( - "Rope alpha (default: None).\n" - 'Same as alpha_value. Set to "auto" to auto-calculate.\n' - "Leaving this value blank will either pull from the model " - "or auto-calculate." - ), - ) - cache_mode: Optional[CACHE_SIZES] = Field( - "FP16", - description=( - "Enable different cache modes for VRAM savings (default: FP16).\n" - f"Possible values: {str(CACHE_SIZES)[15:-1]}." - ), - ) - cache_size: Optional[int] = Field( - None, - description=( - "Size of the prompt cache to allocate (default: max_seq_len).\n" - "Must be a multiple of 256 and can't be less than max_seq_len.\n" - "For CFG, set this to 2 * max_seq_len." - ), - multiple_of=256, - gt=0, - ) - chunk_size: Optional[int] = Field( - 2048, - description=( - "Chunk size for prompt ingestion (default: 2048).\n" - "A lower value reduces VRAM usage but decreases ingestion speed.\n" - "NOTE: Effects vary depending on the model.\n" - "An ideal value is between 512 and 4096." - ), - gt=0, - ) - max_batch_size: Optional[int] = Field( - None, - description=( - "Set the maximum number of prompts to process at one time " - "(default: None/Automatic).\n" - "Automatically calculated if left blank.\n" - "NOTE: Only available for Nvidia ampere (30 series) and above GPUs." - ), - ge=1, - ) - prompt_template: Optional[str] = Field( - None, - description=( - "Set the prompt template for this model. (default: None)\n" - "If empty, attempts to look for the model's chat template.\n" - "If a model contains multiple templates in its tokenizer_config.json,\n" - "set prompt_template to the name of the template you want to use.\n" - "NOTE: Only works with chat completion message lists!" - ), - ) - num_experts_per_token: Optional[int] = Field( - None, - description=( - "Number of experts to use per token.\n" - "Fetched from the model's config.json if empty.\n" - "NOTE: For MoE models only.\n" - "WARNING: Don't set this unless you know what you're doing!" - ), - ge=1, - ) - fasttensors: Optional[bool] = Field( + fasttensors: bool = Field( False, description=( "Enables fasttensors to possibly increase model loading speeds " @@ -302,48 +168,16 @@ class ModelConfig(BaseConfigModel): model_config = ConfigDict(protected_namespaces=()) -class DraftModelConfig(BaseConfigModel): +class DraftModelConfig(BaseConfigModel, DraftModelInstanceConfig): """ Options for draft models (speculative decoding) This will use more VRAM! """ - # TODO: convert this to a pathlib.path? - draft_model_dir: Optional[str] = Field( + draft_model_dir: DirectoryPath = Field( "models", description=("Directory to look for draft models (default: models)"), ) - draft_model_name: Optional[str] = Field( - None, - description=( - "An initial draft model to load.\n" - "Ensure the model is in the model directory." - ), - ) - draft_rope_scale: Optional[float] = Field( - 1.0, - description=( - "Rope scale for draft models (default: 1.0).\n" - "Same as compress_pos_emb.\n" - "Use if the draft model was trained on long context with rope." - ), - ) - draft_rope_alpha: Optional[float] = Field( - None, - description=( - "Rope alpha for draft models (default: None).\n" - 'Same as alpha_value. 
Set to "auto" to auto-calculate.\n' - "Leaving this value blank will either pull from the model " - "or auto-calculate." - ), - ) - draft_cache_mode: Optional[CACHE_SIZES] = Field( - "FP16", - description=( - "Cache mode for draft models to save VRAM (default: FP16).\n" - f"Possible values: {str(CACHE_SIZES)[15:-1]}." - ), - ) class LoraInstanceModel(BaseConfigModel): @@ -357,7 +191,7 @@ class LoraConfig(BaseConfigModel): """Options for Loras""" # TODO: convert this to a pathlib.path? - lora_dir: Optional[str] = Field( + lora_dir: DirectoryPath = Field( "loras", description=("Directory to look for LoRAs (default: loras).") ) loras: Optional[List[LoraInstanceModel]] = Field( @@ -379,12 +213,11 @@ class EmbeddingsConfig(BaseConfigModel): Install it via "pip install .[extras]" """ - # TODO: convert this to a pathlib.path? - embedding_model_dir: Optional[str] = Field( + embedding_model_dir: DirectoryPath = Field( "models", description=("Directory to look for embedding models (default: models)."), ) - embeddings_device: Optional[Literal["cpu", "auto", "cuda"]] = Field( + embeddings_device: Literal["cpu", "auto", "cuda"] = Field( "cpu", description=( "Device to load embedding models on (default: cpu).\n" @@ -416,7 +249,7 @@ class SamplingConfig(BaseConfigModel): class DeveloperConfig(BaseConfigModel): """Options for development and experimentation""" - unsafe_launch: Optional[bool] = Field( + unsafe_launch: bool = Field( False, description=( "Skip Exllamav2 version check (default: False).\n" @@ -424,13 +257,13 @@ class DeveloperConfig(BaseConfigModel): "than enabling this flag." ), ) - disable_request_streaming: Optional[bool] = Field( + disable_request_streaming: bool = Field( False, description=("Disable API request streaming (default: False).") ) - cuda_malloc_backend: Optional[bool] = Field( + cuda_malloc_backend: bool = Field( False, description=("Enable the torch CUDA malloc backend (default: False).") ) - uvloop: Optional[bool] = Field( + uvloop: bool = Field( False, description=( "Run asyncio using Uvloop or Winloop which can improve performance.\n" @@ -438,7 +271,7 @@ class DeveloperConfig(BaseConfigModel): "turn this off." 
), ) - realtime_process_priority: Optional[bool] = Field( + realtime_process_priority: bool = Field( False, description=( "Set process to use a higher priority.\n" @@ -451,31 +284,15 @@ class DeveloperConfig(BaseConfigModel): class TabbyConfigModel(BaseModel): """Base model for a TabbyConfig.""" - config: Optional[ConfigOverrideConfig] = Field( - default_factory=ConfigOverrideConfig.model_construct - ) - network: Optional[NetworkConfig] = Field( - default_factory=NetworkConfig.model_construct - ) - logging: Optional[LoggingConfig] = Field( - default_factory=LoggingConfig.model_construct - ) - model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct) - draft_model: Optional[DraftModelConfig] = Field( - default_factory=DraftModelConfig.model_construct - ) - lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct) - embeddings: Optional[EmbeddingsConfig] = Field( - default_factory=EmbeddingsConfig.model_construct - ) - sampling: Optional[SamplingConfig] = Field( - default_factory=SamplingConfig.model_construct - ) - developer: Optional[DeveloperConfig] = Field( - default_factory=DeveloperConfig.model_construct - ) - actions: Optional[UtilityActions] = Field( - default_factory=UtilityActions.model_construct - ) + config: ConfigOverrideConfig = Field(default_factory=ConfigOverrideConfig) + network: NetworkConfig = Field(default_factory=NetworkConfig) + logging: LoggingConfig = Field(default_factory=LoggingConfig) + model: ModelConfig = Field(default_factory=ModelConfig) + draft_model: DraftModelConfig = Field(default_factory=DraftModelConfig) + lora: LoraConfig = Field(default_factory=LoraConfig) + embeddings: EmbeddingsConfig = Field(default_factory=EmbeddingsConfig) + sampling: SamplingConfig = Field(default_factory=SamplingConfig) + developer: DeveloperConfig = Field(default_factory=DeveloperConfig) + actions: UtilityActions = Field(default_factory=UtilityActions) model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) diff --git a/common/downloader.py b/common/downloader.py index 6813e0d..2915842 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -76,9 +76,9 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str """Gets the download folder for the repo.""" if repo_type == "lora": - download_path = pathlib.Path(config.lora.lora_dir) + download_path = config.lora.lora_dir else: - download_path = pathlib.Path(config.model.model_dir) + download_path = config.model.model_dir download_path = download_path / (folder_name or repo_id.split("/")[-1]) return download_path diff --git a/common/logger.py b/common/logger.py index f21ab09..c4d5795 100644 --- a/common/logger.py +++ b/common/logger.py @@ -61,7 +61,7 @@ def _log_formatter(record: dict): message = unwrap(record.get("message"), "") # Replace once loguru allows for turning off str.format - message = message.replace("{", "{{").replace("}", "}}").replace("<", "\<") + message = message.replace(r"{", r"{{").replace(r"}", r"}}").replace(r"<", r"\<") # Escape markup tags from Rich message = escape(message) diff --git a/common/model.py b/common/model.py index 87b06ad..9937537 100644 --- a/common/model.py +++ b/common/model.py @@ -10,10 +10,11 @@ from loguru import logger from typing import Optional +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common.logger import get_loading_progress_bar from common.networking import handle_request_error -from common.tabby_config import config from 
common.optional_dependencies import dependencies +from common.tabby_config import config if dependencies.exllamav2: from backends.exllamav2.model import ExllamaV2Container @@ -48,7 +49,11 @@ async def unload_model(skip_wait: bool = False, shutdown: bool = False): container = None -async def load_model_gen(model_path: pathlib.Path, **kwargs): +async def load_model_gen( + model: ModelInstanceConfig, + draft: Optional[DraftModelInstanceConfig] = None, + skip_wait: bool = False, +): """Generator to load a model""" global container @@ -56,7 +61,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): if container and container.model: loaded_model_name = container.model_dir.name - if loaded_model_name == model_path.name and container.model_loaded: + if loaded_model_name == model.model_name and container.model_loaded: raise ValueError( f'Model "{loaded_model_name}" is already loaded! Aborting.' ) @@ -65,13 +70,18 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): await unload_model() # Merge with config defaults - kwargs = {**config.model_defaults, **kwargs} + model = model.model_copy(update=config.model_defaults) + model.model_validate(model, strict=True) # Create a new container - container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs) + draft = draft or DraftModelInstanceConfig() + + container = await ExllamaV2Container.create( + model=model, draft=draft, quiet=False + ) model_type = "draft" if container.draft_config else "model" - load_status = container.load_gen(load_progress, **kwargs) + load_status = container.load_gen(load_progress, skip_wait) progress = get_loading_progress_bar() progress.start() @@ -97,8 +107,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): progress.stop() -async def load_model(model_path: pathlib.Path, **kwargs): - async for _ in load_model_gen(model_path, **kwargs): +async def load_model( + model: ModelInstanceConfig, draft: Optional[DraftModelInstanceConfig] = None +): + async for _ in load_model_gen(model=model, draft=draft): pass diff --git a/common/tabby_config.py b/common/tabby_config.py index d41cc64..bca0802 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -37,9 +37,10 @@ def load(self, arguments: Optional[dict] = None): # This should be less expensive than pruning the entire merged dictionary configs = filter_none_values(configs) merged_config = merge_dicts(*configs) + merged_config = filter_none_values(merged_config) # validate and update config - merged_config_model = TabbyConfigModel.model_validate(merged_config) + merged_config_model = TabbyConfigModel(**merged_config) for field in TabbyConfigModel.model_fields.keys(): value = getattr(merged_config_model, field) setattr(self, field, value) @@ -106,7 +107,8 @@ def _from_file(self, config_path: pathlib.Path): ) # Create a temporary base config model - new_cfg = TabbyConfigModel.model_validate(cfg) + cfg = filter_none_values(cfg) + new_cfg = TabbyConfigModel(**cfg) try: config_path.rename(f"{config_path}.bak") @@ -175,11 +177,12 @@ def _from_environment(self): def generate_config_file( - model: BaseModel = None, - filename: str = "config_sample.yml", + model: Optional[BaseModel] = None, + filename: Optional[pathlib.Path] = None, ) -> None: """Creates a config.yml file from Pydantic models.""" + file = unwrap(filename, "config_sample.yml") schema = unwrap(model, TabbyConfigModel()) preamble = """ # Sample YAML file for configuration. 
@@ -193,7 +196,7 @@ def generate_config_file( yaml_content = pydantic_model_to_yaml(schema) - with open(filename, "w") as f: + with open(file, "w") as f: f.write(dedent(preamble).lstrip()) yaml.dump(yaml_content, f) diff --git a/common/utils.py b/common/utils.py index 97ecaf7..17c74b5 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,12 +1,14 @@ """Common utility functions""" from types import NoneType -from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar +from typing import Dict, Type, Union, get_args, get_origin, TypeVar +from pydantic import BaseModel T = TypeVar("T") +M = TypeVar("M", bound=BaseModel) -def unwrap(wrapped: Optional[T], default: T = None) -> T: +def unwrap(wrapped: Type[T], default: Type[T]) -> T: """Unwrap function for Optionals.""" if wrapped is None: return default @@ -85,3 +87,7 @@ def unwrap_optional_type(type_hint) -> Type: return arg return type_hint + + +def cast_model(model: BaseModel, new: Type[M]) -> M: + return new(**model.model_dump()) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index c8b02c8..8ec83af 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -13,6 +13,7 @@ from loguru import logger +from backends.exllamav2.types import ModelInstanceConfig from common import model from common.auth import get_key_permission from common.networking import ( @@ -138,19 +139,17 @@ async def load_inline_model(model_name: str, request: Request): return - model_path = pathlib.Path(config.model.model_dir) - model_path = model_path / model_name - # Model path doesn't exist - if not model_path.exists(): - logger.warning( - f"Could not find model path {str(model_path)}. Skipping inline model load." - ) + # if not model_path.exists(): + # logger.warning( + # f"Could not find model path {str(model_path)}." + + # "Skipping inline model load." + # ) - return + # return # Load the model - await model.load_model(model_path) + await model.load_model(ModelInstanceConfig(model_name=model_name)) async def stream_generate_completion( diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 2c60cd7..94de914 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -71,13 +71,10 @@ async def list_models(request: Request) -> ModelList: Requires an admin key to see all models. """ - model_dir = config.model.model_dir - model_path = pathlib.Path(model_dir) - - draft_model_dir = config.draft_model.draft_model_dir - if get_key_permission(request) == "admin": - models = get_model_list(model_path.resolve(), draft_model_dir) + models = get_model_list( + config.model.model_dir, config.draft_model.draft_model_dir + ) else: models = await get_current_model_list() @@ -110,7 +107,7 @@ async def list_draft_models(request: Request) -> ModelList: draft_model_dir = config.draft_model.draft_model_dir draft_model_path = pathlib.Path(draft_model_dir) - models = get_model_list(draft_model_path.resolve()) + models = get_model_list(draft_model_path) else: models = await get_current_model_list(model_type="draft") @@ -123,7 +120,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: """Loads a model into the model container. 
This returns an SSE stream.""" # Verify request parameters - if not data.name: + if not data.model_name: error_message = handle_request_error( "A model name was not provided for load.", exc_info=False, @@ -131,10 +128,6 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: raise HTTPException(400, error_message) - model_path = pathlib.Path(config.model.model_dir) - model_path = model_path / data.name - - draft_model_path = None if data.draft: if not data.draft.draft_model_name: error_message = handle_request_error( @@ -144,19 +137,15 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: raise HTTPException(400, error_message) - draft_model_path = config.draft_model.draft_model_dir + # if not model_path.exists(): + # error_message = handle_request_error( + # "Could not find the model path for load. Check model name or config.yml?", + # exc_info=False, + # ).error.message - if not model_path.exists(): - error_message = handle_request_error( - "Could not find the model path for load. Check model name or config.yml?", - exc_info=False, - ).error.message + # raise HTTPException(400, error_message) - raise HTTPException(400, error_message) - - return EventSourceResponse( - stream_model_load(data, model_path, draft_model_path), ping=maxsize - ) + return EventSourceResponse(stream_model_load(data), ping=maxsize) # Unload model endpoint @@ -278,7 +267,7 @@ async def list_embedding_models(request: Request) -> ModelList: embedding_model_dir = config.embeddings.embedding_model_dir embedding_model_path = pathlib.Path(embedding_model_dir) - models = get_model_list(embedding_model_path.resolve()) + models = get_model_list(embedding_model_path) else: models = await get_current_model_list(model_type="embedding") diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index b169162..a887c83 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -1,9 +1,10 @@ """Contains model card types.""" -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field, ConfigDict, model_validator from time import time -from typing import List, Literal, Optional, Union +from typing import List, Optional +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common.config_models import LoggingConfig from common.tabby_config import config @@ -44,74 +45,31 @@ class ModelList(BaseModel): data: List[ModelCard] = Field(default_factory=list) -class DraftModelLoadRequest(BaseModel): - """Represents a draft model load request.""" - - # Required - draft_model_name: str - - # Config arguments - draft_rope_scale: Optional[float] = None - draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( - description='Automatically calculated if set to "auto"', - default=None, - examples=[1.0], - ) - draft_cache_mode: Optional[str] = None - - -class ModelLoadRequest(BaseModel): +class ModelLoadRequest(ModelInstanceConfig): """Represents a model load request.""" - # Required - name: str - - # Config arguments - - max_seq_len: Optional[int] = Field( - description="Leave this blank to use the model's base sequence length", - default=None, - examples=[4096], + # These Fields only exist to stop a breaking change + name: Optional[str] = Field( + None, description="model name to load", deprecated="Use model_name instead" ) - override_base_seq_len: Optional[int] = Field( - description=( - "Overrides the model's base sequence length. 
" "Leave blank if unsure" - ), - default=None, - examples=[4096], + fasttensors: Optional[bool] = Field( + None, + description="ignored, set globally from config.yml", + deprecated="Use model config instead", ) - cache_size: Optional[int] = Field( - description=("Number in tokens, must be greater than or equal to max_seq_len"), - default=None, - examples=[4096], - ) - tensor_parallel: Optional[bool] = None - gpu_split_auto: Optional[bool] = None - autosplit_reserve: Optional[List[float]] = None - gpu_split: Optional[List[float]] = Field( - default=None, - examples=[[24.0, 20.0]], - ) - rope_scale: Optional[float] = Field( - description="Automatically pulled from the model's config if not present", - default=None, - examples=[1.0], - ) - rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( - description='Automatically calculated if set to "auto"', - default=None, - examples=[1.0], - ) - cache_mode: Optional[str] = None - chunk_size: Optional[int] = None - prompt_template: Optional[str] = None - num_experts_per_token: Optional[int] = None - fasttensors: Optional[bool] = None # Non-config arguments - draft: Optional[DraftModelLoadRequest] = None + draft: Optional[DraftModelInstanceConfig] = None skip_queue: Optional[bool] = False + # for the name value + @model_validator(mode="after") + def set_model_name(self): + """Sets the model name.""" + if self.name and self.model_name is None: + self.model_name = self.name + return self + class EmbeddingModelLoadRequest(BaseModel): name: str diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index d151fdd..014b6fb 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -2,10 +2,12 @@ from asyncio import CancelledError from typing import Optional +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig from common import model from common.networking import get_generator_error, handle_request_disconnect from common.tabby_config import config -from common.utils import unwrap +from common.utils import cast_model, unwrap +from common.model import ModelType from endpoints.core.types.model import ( ModelCard, ModelCardParameters, @@ -15,13 +17,17 @@ ) -def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None): +def get_model_list( + model_path: pathlib.Path, draft_model_path: Optional[pathlib.Path] = None +): """Get the list of models from the provided path.""" # Convert the provided draft model path to a pathlib path for # equality comparisons + if model_path: + model_path = model_path.resolve() if draft_model_path: - draft_model_path = pathlib.Path(draft_model_path).resolve() + draft_model_path = draft_model_path.resolve() model_card_list = ModelList() for path in model_path.iterdir(): @@ -33,7 +39,7 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N return model_card_list -async def get_current_model_list(model_type: str = "model"): +async def get_current_model_list(model_type: ModelType): """ Gets the current model in list format and with path only. 
@@ -45,13 +51,13 @@ async def get_current_model_list(model_type: str = "model"): # Make sure the model container exists match model_type: - case "model": + case ModelType.MODEL: if model.container: model_path = model.container.model_dir - case "draft": + case ModelType.DRAFT: if model.container: model_path = model.container.draft_model_dir - case "embedding": + case ModelType.EMBEDDING: if model.embeddings_container: model_path = model.embeddings_container.model_dir @@ -94,20 +100,17 @@ def get_current_model(): async def stream_model_load( data: ModelLoadRequest, - model_path: pathlib.Path, - draft_model_path: str, ): """Request generation wrapper for the loading process.""" - # Get trimmed load data - load_data = data.model_dump(exclude_none=True) + load_config = cast_model(data, ModelInstanceConfig) - # Set the draft model path if it exists - if draft_model_path: - load_data["draft"]["draft_model_dir"] = draft_model_path + draft_load_config = ( + cast_model(data.draft, DraftModelInstanceConfig) if data.draft else None + ) load_status = model.load_model_gen( - model_path, skip_wait=data.skip_queue, **load_data + model=load_config, draft=draft_load_config, skip_wait=data.skip_queue ) try: async for module, modules, model_type in load_status: diff --git a/main.py b/main.py index 06db5d5..a25a738 100644 --- a/main.py +++ b/main.py @@ -16,9 +16,11 @@ from common.networking import is_port_in_use from common.signals import signal_handler from common.tabby_config import config +from common.utils import cast_model from endpoints.server import start_api from backends.exllamav2.version import check_exllama_version +from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig async def entrypoint_async(): @@ -47,7 +49,7 @@ async def entrypoint_async(): port = fallback_port # Initialize auth keys - await load_auth_keys(config.network.disable_auth) + await load_auth_keys() gen_logging.broadcast_status() @@ -61,16 +63,10 @@ async def entrypoint_async(): # If an initial model name is specified, create a container # and load the model - model_name = config.model.model_name - if model_name: - model_path = pathlib.Path(config.model.model_dir) - model_path = model_path / model_name - - # TODO: remove model_dump() + if config.model.model_name: await model.load_model( - model_path.resolve(), - **config.model.model_dump(), - draft=config.draft_model.model_dump(), + model=cast_model(config.model, ModelInstanceConfig), + draft=cast_model(config.draft_model, DraftModelInstanceConfig), ) # Load loras after loading the model
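
Reviewer note — a minimal usage sketch (not part of the diff) of the new typed loading path above. `ModelInstanceConfig`, `DraftModelInstanceConfig`, and `common.model.load_model` come straight from this PR; the model folder name and the field values are hypothetical.

```python
# Sketch only: exercising the pydantic-based load path from this PR.
# "MyModel-exl2" is a hypothetical folder under config.model.model_dir.
import asyncio

from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig
from common import model


async def demo_load():
    # Field names mirror the old config.yml keys, but values are now
    # validated by pydantic when the instance is constructed.
    instance = ModelInstanceConfig(
        model_name="MyModel-exl2",
        max_seq_len=8192,
        cache_mode="Q4",
    )
    draft = DraftModelInstanceConfig()  # draft_model_name stays None -> no draft model

    # load_model() resolves the folder under config.model.model_dir and the
    # container raises FileNotFoundError if it does not exist.
    await model.load_model(model=instance, draft=draft)


if __name__ == "__main__":
    asyncio.run(demo_load())
```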
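
Likewise, a sketch of the request-to-config hand-off that `stream_model_load` now performs: the deprecated `name` field is mirrored into `model_name` by the `set_model_name` validator, and `cast_model` rebuilds the backend type from a plain dump. The payload values below are made up.

```python
# Sketch only: ModelLoadRequest -> ModelInstanceConfig, as stream_model_load does it.
from backends.exllamav2.types import ModelInstanceConfig
from common.utils import cast_model
from endpoints.core.types.model import ModelLoadRequest

# Old-style payload that still uses "name"; the after-validator copies it
# into model_name so existing clients keep working.
request = ModelLoadRequest(name="MyModel-exl2", max_seq_len=4096)
assert request.model_name == "MyModel-exl2"

# cast_model() re-validates the dumped fields against the target model;
# request-only fields such as "name" and "skip_queue" are dropped because
# pydantic ignores unknown keys by default.
load_config = cast_model(request, ModelInstanceConfig)
print(load_config.max_seq_len)  # 4096
```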
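
A sketch of the `SecretStr` change to `AuthKeys`: both keys now self-generate through `default_factory`, stay masked in reprs, and are unwrapped explicitly for logging and comparison. Assumes `common.auth` imports cleanly outside the running server.

```python
# Sketch only: behavior of the SecretStr-backed AuthKeys from this PR.
from common.auth import AuthKeys

keys = AuthKeys()  # api_key/admin_key auto-generate via secrets.token_hex(16)

# SecretStr keeps the raw values out of reprs and dumps...
print(keys)  # api_key=SecretStr('**********') admin_key=SecretStr('**********')

# ...so the value is unwrapped explicitly, as load_auth_keys() does when
# printing the keys and verify_key() does when comparing them.
api_key = keys.api_key.get_secret_value()
assert keys.verify_key(api_key, "api_key") is True
assert keys.verify_key(api_key, "admin_key") is False
```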
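
Finally, a sketch of what the tightened `config_models` types mean in practice: sections are built with real validation instead of `model_construct`, and path fields such as `model_dir` are `DirectoryPath`, so explicitly supplied paths must exist on disk. The nonexistent directory below is intentional.

```python
# Sketch only: stricter config validation after dropping the Optional defaults.
from pydantic import ValidationError

from common.config_models import ModelConfig, TabbyConfigModel

defaults = TabbyConfigModel()
print(defaults.network.port)      # 5000
print(defaults.model.cache_mode)  # "FP16", inherited from ModelInstanceConfig

# DirectoryPath checks the filesystem for explicitly supplied values
# (pydantic does not re-validate field defaults).
try:
    ModelConfig(model_dir="definitely/not/a/real/dir")
except ValidationError as exc:
    print(exc.error_count(), "validation error(s)")
```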