From e046cb7be9fbf5c024c0c1af9df46d32b14444b5 Mon Sep 17 00:00:00 2001 From: leoguillaume Date: Fri, 20 Dec 2024 17:43:51 +0100 Subject: [PATCH] feat: add model aliases --- CONTRIBUTING.md | 4 +- app/clients/_modelclients.py | 51 +++++++++-- app/endpoints/audio.py | 5 +- app/endpoints/chat.py | 13 ++- app/endpoints/chunks.py | 2 +- app/endpoints/collections.py | 6 +- app/endpoints/completions.py | 15 ++-- app/endpoints/documents.py | 4 +- app/endpoints/embeddings.py | 9 +- app/endpoints/models.py | 6 +- app/endpoints/rerank.py | 11 +-- app/endpoints/search.py | 6 +- app/helpers/_clientsmanager.py | 20 ++--- app/schemas/chat.py | 7 +- app/schemas/completions.py | 11 +-- app/schemas/settings.py | 81 +++++++++++------- app/tests/conftest.py | 2 +- app/tests/test_chat.py | 19 +++++ app/tests/test_embeddings.py | 149 +++++++++++++++++++++++++++++++++ app/tests/test_models.py | 14 +++- app/tests/test_search.py | 4 +- app/utils/lifespan.py | 4 +- app/utils/security.py | 2 +- docs/deployment.md | 68 +++++++++------ 24 files changed, 380 insertions(+), 133 deletions(-) create mode 100644 app/tests/test_embeddings.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4a17d80a..f771e830 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -55,13 +55,13 @@ Merci, avant chaque pull request, de vérifier le bon déploiement de votre API 1. Lancez l'API en local avec la commande suivante: ```bash - DEFAULT_RATE_LIMIT="30/minute" uvicorn app.main:app --port 8080 --log-level debug --reload + uvicorn app.main:app --port 8080 --log-level debug --reload ``` 2. Exécutez les tests unitaires à la racine du projet ```bash - DEFAULT_RATE_LIMIT="30/minute" PYTHONPATH=. pytest --config-file=pyproject.toml --base-url http://localhost:8080/v1 --api-key-user API_KEY_USER --api-key-admin API_KEY_ADMIN --log-cli-level=INFO + PYTHONPATH=. 
pytest --config-file=pyproject.toml --base-url http://localhost:8080/v1 --api-key-user API_KEY_USER --api-key-admin API_KEY_ADMIN --log-cli-level=INFO ``` # Notebooks diff --git a/app/clients/_modelclients.py b/app/clients/_modelclients.py index c923fe2f..b1587254 100644 --- a/app/clients/_modelclients.py +++ b/app/clients/_modelclients.py @@ -8,6 +8,7 @@ import requests from app.schemas.embeddings import Embeddings +from app.schemas.chat import ChatCompletion from app.schemas.models import Model, Models from app.schemas.rerank import Rerank from app.schemas.settings import Settings @@ -69,6 +70,7 @@ def get_models_list(self, *args, **kwargs) -> Models: owned_by=self.owned_by, created=self.created, max_context_length=self.max_context_length, + aliases=self.aliases, type=self.type, status=self.status, ) @@ -76,8 +78,28 @@ def get_models_list(self, *args, **kwargs) -> Models: return Models(data=[data]) +def create_chat_completions(self, *args, **kwargs): + """ + Custom method to overwrite OpenAI's create method to raise HTTPException from model API. + """ + try: + url = f"{self.base_url}chat/completions" + headers = {"Authorization": f"Bearer {self.api_key}"} + response = requests.post(url=url, headers=headers, json=kwargs) + response.raise_for_status() + data = response.json() + + return ChatCompletion(**data) + + except Exception as e: + raise HTTPException(status_code=e.response.status_code, detail=json.loads(e.response.text)["message"]) + + # @TODO : useless ? def create_embeddings(self, *args, **kwargs): + """ + Custom method to overwrite OpenAI's create method to raise HTTPException from model API. 
+ """ try: url = f"{self.base_url}embeddings" headers = {"Authorization": f"Bearer {self.api_key}"} @@ -104,6 +126,7 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDI # set attributes for unavailable models self.id = "" self.owned_by = "" + self.aliases = [] self.created = round(number=time.time()) self.max_context_length = None @@ -111,6 +134,9 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, AUDI self.models.list = partial(get_models_list, self) response = self.models.list() + if self.type == LANGUAGE_MODEL_TYPE: + self.chat.completions.create = partial(create_chat_completions, self) + if self.type == EMBEDDINGS_MODEL_TYPE: response = self.embeddings.create(model=self.id, input="hello world") self.vector_size = len(response.data[0].embedding) @@ -145,25 +171,40 @@ class ModelClients(dict): """ def __init__(self, settings: Settings) -> None: - for model_config in settings.models: - model = ModelClient(base_url=model_config.url, api_key=model_config.key, type=model_config.type) + self.aliases = {alias: model_id for model_id, aliases in settings.models.aliases.items() for alias in aliases} + + for model_settings in settings.clients.models: + model = ModelClient( + base_url=model_settings.url, + api_key=model_settings.key, + type=model_settings.type, + ) if model.status == "unavailable": - logger.error(msg=f"unavailable model API on {model_config.url}, skipping.") + logger.error(msg=f"unavailable model API on {model_settings.url}, skipping.") continue try: - logger.info(msg=f"Adding model API {model_config.url} to the client...") + logger.info(msg=f"Adding model API {model_settings.url} to the client...") self.__setitem__(key=model.id, value=model) logger.info(msg="done.") except Exception as e: logger.error(msg=e) + model.aliases = settings.models.aliases.get(model.id, []) + + for alias in self.aliases.keys(): + assert alias not in self.keys(), "Alias is already used by another model." 
+ + assert settings.internet.default_language_model in self.keys(), "Default internet language model not found." + assert settings.internet.default_embeddings_model in self.keys(), "Default internet embeddings model not found." + def __setitem__(self, key: str, value) -> None: - if any(key == k for k in self.keys()): + if key in self.keys(): raise ValueError(f"duplicated model ID {key}, skipping.") else: super().__setitem__(key, value) def __getitem__(self, key: str) -> Any: + key = self.aliases.get(key, key) try: item = super().__getitem__(key) assert item.status == "available", "Model not available." diff --git a/app/endpoints/audio.py b/app/endpoints/audio.py index 7ece0726..1762d169 100644 --- a/app/endpoints/audio.py +++ b/app/endpoints/audio.py @@ -6,12 +6,11 @@ import httpx from app.schemas.audio import AudioTranscription -from app.schemas.settings import AUDIO_MODEL_TYPE from app.utils.exceptions import ModelNotFoundException from app.utils.lifespan import clients, limiter from app.utils.security import User, check_api_key, check_rate_limit from app.utils.settings import settings -from app.utils.variables import DEFAULT_TIMEOUT +from app.utils.variables import DEFAULT_TIMEOUT, AUDIO_MODEL_TYPE router = APIRouter() @@ -135,7 +134,7 @@ @router.post("/audio/transcriptions") -@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def audio_transcriptions( request: Request, file: UploadFile = File(...), diff --git a/app/endpoints/chat.py b/app/endpoints/chat.py index e9d048a3..6db34958 100644 --- a/app/endpoints/chat.py +++ b/app/endpoints/chat.py @@ -11,23 +11,28 @@ from app.schemas.search import Search from app.schemas.security import User from app.schemas.settings import Settings +from app.utils.exceptions import WrongModelTypeException from app.utils.lifespan import clients, limiter from 
app.utils.security import check_api_key, check_rate_limit from app.utils.settings import settings -from app.utils.variables import DEFAULT_TIMEOUT +from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE router = APIRouter() @router.post(path="/chat/completions") -@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def chat_completions( request: Request, body: ChatCompletionRequest, user: User = Security(dependency=check_api_key) ) -> Union[ChatCompletion, ChatCompletionChunk]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create for the API specification. """ + client = clients.models[body.model] + if client.type != LANGUAGE_MODEL_TYPE: + raise WrongModelTypeException() + body.model = client.id # replace alias by model id url = f"{client.base_url}chat/completions" headers = {"Authorization": f"Bearer {client.api_key}"} @@ -42,8 +47,8 @@ def retrieval_augmentation_generation( internet_manager=InternetManager( model_clients=clients.models, internet_client=clients.internet, - default_language_model_id=settings.internet.args.default_language_model, - default_embeddings_model_id=settings.internet.args.default_embeddings_model, + default_language_model_id=settings.internet.default_language_model, + default_embeddings_model_id=settings.internet.default_embeddings_model, ), ) searches = search_manager.query( diff --git a/app/endpoints/chunks.py b/app/endpoints/chunks.py index ef9df6ff..6bb0e9b8 100644 --- a/app/endpoints/chunks.py +++ b/app/endpoints/chunks.py @@ -13,7 +13,7 @@ @router.get("/chunks/{collection}/{document}") -@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: 
check_rate_limit(request=request)) async def get_chunks( request: Request, collection: UUID, diff --git a/app/endpoints/collections.py b/app/endpoints/collections.py index d5a2ce97..b8b2b7c5 100644 --- a/app/endpoints/collections.py +++ b/app/endpoints/collections.py @@ -16,7 +16,7 @@ @router.post("/collections") -@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def create_collection(request: Request, body: CollectionRequest, user: User = Security(check_api_key)) -> Response: """ Create a new collection. @@ -35,7 +35,7 @@ async def create_collection(request: Request, body: CollectionRequest, user: Use @router.get("/collections") -@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def get_collections(request: Request, user: User = Security(check_api_key)) -> Union[Collection, Collections]: """ Get list of collections. @@ -54,7 +54,7 @@ async def get_collections(request: Request, user: User = Security(check_api_key) @router.delete("/collections/{collection}") -@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def delete_collections(request: Request, collection: UUID, user: User = Security(check_api_key)) -> Response: """ Delete a collection. 
diff --git a/app/endpoints/completions.py b/app/endpoints/completions.py index 0df1759a..8ae3e12d 100644 --- a/app/endpoints/completions.py +++ b/app/endpoints/completions.py @@ -1,25 +1,30 @@ -from fastapi import APIRouter, Request, Security, HTTPException -import httpx import json +from fastapi import APIRouter, HTTPException, Request, Security +import httpx + from app.schemas.completions import CompletionRequest, Completions from app.schemas.security import User -from app.utils.settings import settings +from app.utils.exceptions import WrongModelTypeException from app.utils.lifespan import clients, limiter from app.utils.security import check_api_key, check_rate_limit -from app.utils.variables import DEFAULT_TIMEOUT +from app.utils.settings import settings +from app.utils.variables import DEFAULT_TIMEOUT, LANGUAGE_MODEL_TYPE router = APIRouter() @router.post(path="/completions") -@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def completions(request: Request, body: CompletionRequest, user: User = Security(dependency=check_api_key)) -> Completions: """ Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/completions/create for the API specification. 
""" client = clients.models[body.model] + if client.type != LANGUAGE_MODEL_TYPE: + raise WrongModelTypeException() + body.model = client.id # replace alias by model id url = f"{client.base_url}completions" headers = {"Authorization": f"Bearer {client.api_key}"} diff --git a/app/endpoints/documents.py b/app/endpoints/documents.py index c4caccb0..2332a8d9 100644 --- a/app/endpoints/documents.py +++ b/app/endpoints/documents.py @@ -13,7 +13,7 @@ @router.get("/documents/{collection}") -@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def get_documents( request: Request, collection: UUID, @@ -31,7 +31,7 @@ async def get_documents( @router.delete("/documents/{collection}/{document}") -@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def delete_document(request: Request, collection: UUID, document: UUID, user: User = Security(check_api_key)) -> Response: """ Delete a document and relative collections. diff --git a/app/endpoints/embeddings.py b/app/endpoints/embeddings.py index 180b33c3..9fca1267 100644 --- a/app/endpoints/embeddings.py +++ b/app/endpoints/embeddings.py @@ -14,7 +14,7 @@ @router.post(path="/embeddings") -@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Security(dependency=check_api_key)) -> Embeddings: """ Embedding API similar to OpenAI's API. 
@@ -24,19 +24,14 @@ async def embeddings(request: Request, body: EmbeddingsRequest, user: User = Sec client = clients.models[body.model] if client.type != EMBEDDINGS_MODEL_TYPE: raise WrongModelTypeException() - + body.model = client.id # replace alias by model id url = f"{client.base_url}embeddings" headers = {"Authorization": f"Bearer {client.api_key}"} try: async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as async_client: response = await async_client.request(method="POST", url=url, headers=headers, json=body.model_dump()) - # try: response.raise_for_status() - # except httpx.HTTPStatusError as e: - # if "`inputs` must have less than" in e.response.text: - # raise ContextLengthExceededException() - # raise e data = response.json() return Embeddings(**data) except Exception as e: diff --git a/app/endpoints/models.py b/app/endpoints/models.py index fea2a2a7..dc5a9a8b 100644 --- a/app/endpoints/models.py +++ b/app/endpoints/models.py @@ -13,15 +13,15 @@ @router.get("/models/{model:path}") @router.get("/models") -@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def models(request: Request, model: Optional[str] = None, user: User = Security(check_api_key)) -> Union[Models, Model]: """ Model API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/models/list for the API specification. 
""" if model is not None: - client = clients.models[model] - response = [row for row in client.models.list().data if row.id == model][0] + model = clients.models[model] + response = [row for row in model.models.list().data if row.id == model.id][0] else: response = {"object": "list", "data": []} for model_id, client in clients.models.items(): diff --git a/app/endpoints/rerank.py b/app/endpoints/rerank.py index 010114b6..d48a8319 100644 --- a/app/endpoints/rerank.py +++ b/app/endpoints/rerank.py @@ -14,17 +14,18 @@ @router.post("/rerank") -@limiter.limit(settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) async def rerank(request: Request, body: RerankRequest, user: User = Security(check_api_key)): """ Rerank a list of inputs with a language model or reranker model. """ + model = clients.models[body.model] - if clients.models[body.model].type == LANGUAGE_MODEL_TYPE: - reranker = LanguageModelReranker(model=clients.models[body.model]) + if model.type == LANGUAGE_MODEL_TYPE: + reranker = LanguageModelReranker(model=model) data = reranker.create(prompt=body.prompt, input=body.input) - elif clients.models[body.model].type == RERANK_MODEL_TYPE: - data = clients.models[body.model].rerank.create(prompt=body.prompt, input=body.input, model=body.model) + elif model.type == RERANK_MODEL_TYPE: + data = model.rerank.create(prompt=body.prompt, input=body.input, model=model.id) else: raise WrongModelTypeException() diff --git a/app/endpoints/search.py b/app/endpoints/search.py index f2b1db36..a4fcdb3f 100644 --- a/app/endpoints/search.py +++ b/app/endpoints/search.py @@ -11,7 +11,7 @@ @router.post(path="/search") -@limiter.limit(limit_value=settings.default_rate_limit, key_func=lambda request: check_rate_limit(request=request)) +@limiter.limit(limit_value=settings.rate_limit.by_key, key_func=lambda request: check_rate_limit(request=request)) 
async def search(request: Request, body: SearchRequest, user: User = Security(dependency=check_api_key)) -> Searches: """ Endpoint to search on the internet or with our search client. @@ -26,8 +26,8 @@ async def search(request: Request, body: SearchRequest, user: User = Security(de internet_manager=InternetManager( model_clients=clients.models, internet_client=clients.internet, - default_language_model_id=settings.internet.args.default_language_model, - default_embeddings_model_id=settings.internet.args.default_embeddings_model, + default_language_model_id=settings.internet.default_language_model, + default_embeddings_model_id=settings.internet.default_embeddings_model, ), ) diff --git a/app/helpers/_clientsmanager.py b/app/helpers/_clientsmanager.py index e1776e59..74424622 100644 --- a/app/helpers/_clientsmanager.py +++ b/app/helpers/_clientsmanager.py @@ -15,19 +15,19 @@ def __init__(self, settings: Settings) -> None: def set(self): self.models = ModelClients(settings=self.settings) - self.cache = CacheManager(connection_pool=ConnectionPool(**self.settings.cache.args)) + self.cache = CacheManager(connection_pool=ConnectionPool(**self.settings.clients.cache.args)) - if self.settings.search.type == SEARCH_CLIENT_ELASTIC_TYPE: - self.search = ElasticSearchClient(models=self.models, **self.settings.search.args) - elif self.settings.search.type == SEARCH_CLIENT_QDRANT_TYPE: - self.search = QdrantSearchClient(models=self.models, **self.settings.search.args) + if self.settings.clients.search.type == SEARCH_CLIENT_ELASTIC_TYPE: + self.search = ElasticSearchClient(models=self.models, **self.settings.clients.search.args) + elif self.settings.clients.search.type == SEARCH_CLIENT_QDRANT_TYPE: + self.search = QdrantSearchClient(models=self.models, **self.settings.clients.search.args) - if self.settings.internet.type == INTERNET_CLIENT_DUCKDUCKGO_TYPE: - self.internet = DuckDuckGoInternetClient(**self.settings.internet.args.model_dump()) - elif self.settings.internet.type == 
INTERNET_CLIENT_BRAVE_TYPE: - self.internet = BraveInternetClient(**self.settings.internet.args.model_dump()) + if self.settings.clients.internet.type == INTERNET_CLIENT_DUCKDUCKGO_TYPE: + self.internet = DuckDuckGoInternetClient(**self.settings.clients.internet.args) + elif self.settings.clients.internet.type == INTERNET_CLIENT_BRAVE_TYPE: + self.internet = BraveInternetClient(**self.settings.clients.internet.args) - self.auth = AuthenticationClient(cache=self.cache, **self.settings.auth.args) if self.settings.auth else None + self.auth = AuthenticationClient(cache=self.cache, **self.settings.clients.auth.args) if self.settings.clients.auth else None def clear(self): self.search.close() diff --git a/app/schemas/chat.py b/app/schemas/chat.py index 09789445..cf6d11b2 100644 --- a/app/schemas/chat.py +++ b/app/schemas/chat.py @@ -4,9 +4,7 @@ from pydantic import BaseModel, Field, model_validator, field_validator from app.schemas.search import SearchArgs, Search -from app.utils.exceptions import WrongModelTypeException -from app.utils.lifespan import clients -from app.utils.variables import LANGUAGE_MODEL_TYPE + DEFAULT_RAG_TEMPLATE = "Réponds à la question suivante en te basant sur les documents ci-dessous : {prompt}\n\nDocuments :\n{chunks}" @@ -55,9 +53,6 @@ class Config: @model_validator(mode="after") def validate_model(cls, values): - if clients.models[values.model].type != LANGUAGE_MODEL_TYPE: - raise WrongModelTypeException() - if values.search: if not values.search_args: raise ValueError("search_args is required when search is true") diff --git a/app/schemas/completions.py b/app/schemas/completions.py index 0013091a..4e726280 100644 --- a/app/schemas/completions.py +++ b/app/schemas/completions.py @@ -1,11 +1,7 @@ from typing import Dict, Iterable, List, Optional, Union from openai.types import Completion -from pydantic import BaseModel, Field, model_validator - -from app.utils.lifespan import clients -from app.utils.variables import LANGUAGE_MODEL_TYPE -from 
app.utils.exceptions import WrongModelTypeException +from pydantic import BaseModel, Field class CompletionRequest(BaseModel): @@ -27,11 +23,6 @@ class CompletionRequest(BaseModel): top_p: Optional[float] = 1.0 user: Optional[str] = None - @model_validator(mode="after") - def validate_model(cls, values): - if clients.models[values.model].type != LANGUAGE_MODEL_TYPE: - raise WrongModelTypeException() - class Completions(Completion): pass diff --git a/app/schemas/settings.py b/app/schemas/settings.py index dc5c3fdd..c4ab4829 100644 --- a/app/schemas/settings.py +++ b/app/schemas/settings.py @@ -1,17 +1,17 @@ import os -from typing import List, Literal, Optional +from typing import Dict, List, Literal, Optional from pydantic import BaseModel, Field, field_validator, model_validator from pydantic_settings import BaseSettings import yaml from app.utils.variables import ( + AUDIO_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, + INTERNET_CLIENT_BRAVE_TYPE, + INTERNET_CLIENT_DUCKDUCKGO_TYPE, LANGUAGE_MODEL_TYPE, - AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE, - INTERNET_CLIENT_DUCKDUCKGO_TYPE, - INTERNET_CLIENT_BRAVE_TYPE, SEARCH_CLIENT_ELASTIC_TYPE, SEARCH_CLIENT_QDRANT_TYPE, ) @@ -22,6 +22,29 @@ class Config: extra = "allow" +class RateLimit(ConfigBaseModel): + by_key: str = "10/minute" + by_ip: str = "100/minute" + + +class Internet(ConfigBaseModel): + default_language_model: str + default_embeddings_model: str + + +class Models(ConfigBaseModel): + aliases: Dict[str, List[str]] = {} + + @field_validator("aliases", mode="before") + def validate_aliases(cls, aliases): + unique_aliases = list() + for key, values in aliases.items(): + unique_aliases.extend(values) + + assert len(unique_aliases) == len(set(unique_aliases)), "Duplicated aliases found." 
+ return aliases + + class Key(ConfigBaseModel): key: str @@ -31,45 +54,37 @@ class Auth(ConfigBaseModel): args: dict -class Model(ConfigBaseModel): +class ModelClient(ConfigBaseModel): url: str type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE, AUDIO_MODEL_TYPE, RERANK_MODEL_TYPE] key: Optional[str] = "EMPTY" -class SearchDB(BaseModel): +class SearchDatabase(ConfigBaseModel): type: Literal[SEARCH_CLIENT_ELASTIC_TYPE, SEARCH_CLIENT_QDRANT_TYPE] = SEARCH_CLIENT_QDRANT_TYPE args: dict -class CacheDB(ConfigBaseModel): +class CacheDatabase(ConfigBaseModel): type: Literal["redis"] = "redis" args: dict -class Databases(ConfigBaseModel): - cache: CacheDB - search: SearchDB - - -class InternetArgs(ConfigBaseModel): - default_language_model: str - default_embeddings_model: str - - class Config: - extra = "allow" +class DatabasesClient(ConfigBaseModel): + cache: CacheDatabase + search: SearchDatabase -class Internet(ConfigBaseModel): +class InternetClient(ConfigBaseModel): type: Literal[INTERNET_CLIENT_DUCKDUCKGO_TYPE, INTERNET_CLIENT_BRAVE_TYPE] = INTERNET_CLIENT_DUCKDUCKGO_TYPE - args: InternetArgs + args: dict -class Config(ConfigBaseModel): +class Clients(ConfigBaseModel): auth: Optional[Auth] = None - models: List[Model] = Field(..., min_length=1) - databases: Databases - internet: Optional[Internet] = None + models: List[ModelClient] = Field(..., min_length=1) + databases: DatabasesClient + internet: InternetClient @model_validator(mode="after") def validate_models(cls, values): @@ -89,6 +104,13 @@ def validate_models(cls, values): return values +class Config(ConfigBaseModel): + rate_limit: RateLimit = Field(default_factory=RateLimit) + internet: Internet + models: Models = Field(default_factory=Models) + clients: Clients + + class Settings(BaseSettings): # logging log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" @@ -103,10 +125,6 @@ class Settings(BaseSettings): app_version: str = "0.0.0" app_description: str = "[See 
documentation](https://github.com/etalab-ia/albert-api/blob/main/README.md)" - # rate_limit - global_rate_limit: str = "100/minute" - default_rate_limit: str = "10/minute" - class Config: extra = "allow" @@ -119,10 +137,11 @@ def config_file_exists(cls, config_file): def setup_config(cls, values): config = Config(**yaml.safe_load(stream=open(file=values.config_file, mode="r"))) - values.auth = config.auth - values.cache = config.databases.cache + values.rate_limit = config.rate_limit values.internet = config.internet values.models = config.models - values.search = config.databases.search + values.clients = config.clients + values.clients.cache = config.clients.databases.cache + values.clients.search = config.clients.databases.search return values diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 9b07d25e..d9c4fced 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -76,4 +76,4 @@ def cleanup_collections(args, session_user, session_admin): def sleep_between_tests(): # Sleep between tests to avoid rate limit errors yield - time.sleep(10) + time.sleep(20) diff --git a/app/tests/test_chat.py b/app/tests/test_chat.py index 178795d1..11dad851 100644 --- a/app/tests/test_chat.py +++ b/app/tests/test_chat.py @@ -7,6 +7,7 @@ from app.schemas.chat import ChatCompletion, ChatCompletionChunk from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE +from app.utils.settings import settings @pytest.fixture(scope="module") @@ -260,3 +261,21 @@ def test_chat_completions_search_wrong_collection(self, args, session_user, setu } response = session_user.post(f"{args["base_url"]}/chat/completions", json=params) assert response.status_code == 404, f"error: retrieve chat completions ({response.status_code})" + + def test_chat_completions_model_alias(self, args, session_user, setup): + """Test the GET /chat/completions model alias.""" + MODEL_ID, _, _, _ = setup + + model_id = list(settings.models.aliases.keys())[0] + aliases = 
settings.models.aliases[model_id] + + params = { + "model": aliases[0], + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "stream": False, + "n": 1, + "max_tokens": 10, + } + + response = session_user.post(f"{args["base_url"]}/chat/completions", json=params) + assert response.status_code == 200, f"error: retrieve chat completions ({response.status_code}" diff --git a/app/tests/test_embeddings.py b/app/tests/test_embeddings.py new file mode 100644 index 00000000..471ecac8 --- /dev/null +++ b/app/tests/test_embeddings.py @@ -0,0 +1,149 @@ +import pytest + +from app.utils.settings import settings +from app.utils.variables import EMBEDDINGS_MODEL_TYPE + + +@pytest.fixture(scope="module") +def setup(args, session_user): + # Get an embeddings model + response = session_user.get(f"{args['base_url']}/models") + assert response.status_code == 200, f"error: retrieve models ({response.status_code})" + response_json = response.json() + model = [model for model in response_json["data"] if model["type"] == EMBEDDINGS_MODEL_TYPE][0] + MODEL_ID = model["id"] + yield MODEL_ID + + +@pytest.mark.usefixtures("args", "session_user", "setup") +class TestEmbeddings: + def test_embeddings_single_input(self, args, session_user, setup): + """Test the POST /embeddings endpoint with a single input.""" + MODEL_ID = setup + params = { + "model": MODEL_ID, + "input": "Hello, this is a test.", + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 200, f"error: create embeddings ({response.status_code})" + + response_json = response.json() + assert "data" in response_json + assert len(response_json["data"]) == 1 + assert "embedding" in response_json["data"][0] + assert isinstance(response_json["data"][0]["embedding"], list) + assert all(isinstance(x, float) for x in response_json["data"][0]["embedding"]) + + def test_embeddings_token_integers_input(self, args, session_user, setup): + """Test the POST /embeddings 
endpoint with token integers input.""" + MODEL_ID = setup + params = { + "model": MODEL_ID, + "input": [1, 2, 3, 4, 5], # List[int] + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 200, f"error: create embeddings ({response.status_code})" + + def test_embeddings_token_integers_batch_input(self, args, session_user, setup): + """Test the POST /embeddings endpoint with batch of token integers input.""" + MODEL_ID = setup + params = { + "model": MODEL_ID, + "input": [[1, 2, 3], [4, 5, 6]], # List[List[int]] + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 200, f"error: create embeddings ({response.status_code})" + + def test_embeddings_with_encoding_format(self, args, session_user, setup): + """Test the POST /embeddings endpoint with encoding format.""" + MODEL_ID = setup + params = { + "model": MODEL_ID, + "input": "Test text", + "encoding_format": "float", + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 200, f"error: create embeddings ({response.status_code})" + + def test_embeddings_invalid_encoding_format(self, args, session_user, setup): + """Test the POST /embeddings endpoint with invalid encoding format.""" + MODEL_ID = setup + params = { + "model": MODEL_ID, + "input": "Test text", + "encoding_format": "invalid_format", + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 422, f"error: invalid encoding format should return 422 ({response.status_code})" + + def test_embeddings_wrong_model_type(self, args, session_user): + """Test the POST /embeddings endpoint with wrong model type.""" + # Get a non-embeddings model (e.g., language model) + response = session_user.get(f"{args['base_url']}/models") + models = response.json()["data"] + non_embeddings_model = [m for m in models if m["type"] != 
EMBEDDINGS_MODEL_TYPE][0] + + params = { + "model": non_embeddings_model["id"], + "input": "Test text", + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 422, f"error: wrong model type should return 400 ({response.status_code})" + + def test_embeddings_batch_input(self, args, session_user, setup): + """Test the POST /embeddings endpoint with batch input.""" + MODEL_ID = setup + params = { + "model": MODEL_ID, + "input": ["Hello, this is a test.", "This is another test."], + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 200, f"error: create embeddings ({response.status_code})" + + response_json = response.json() + assert "data" in response_json + assert len(response_json["data"]) == 2 + for item in response_json["data"]: + assert "embedding" in item + assert isinstance(item["embedding"], list) + assert all(isinstance(x, float) for x in item["embedding"]) + + def test_embeddings_empty_input(self, args, session_user, setup): + """Test the POST /embeddings endpoint with empty input.""" + MODEL_ID = setup + params = { + "model": MODEL_ID, + "input": "", + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 413, f"error: empty input should return 422 ({response.status_code})" + + def test_embeddings_invalid_model(self, args, session_user): + """Test the POST /embeddings endpoint with invalid model.""" + params = { + "model": "invalid_model_id", + "input": "Hello, this is a test.", + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 404, f"error: invalid model should return 404 ({response.status_code})" + + def test_embeddings_missing_input(self, args, session_user, setup): + """Test the POST /embeddings endpoint with missing input.""" + MODEL_ID = setup + params = { + "model": MODEL_ID, + } + response = 
session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 422, f"error: missing input should return 422 ({response.status_code})" + + def test_embeddings_model_alias(self, args, session_user, setup): + """Test the POST /embeddings endpoint with a model alias.""" + MODEL_ID = setup + aliases = settings.models.aliases[MODEL_ID] + + params = { + "model": aliases[0], + "input": "Hello, this is a test.", + } + response = session_user.post(f"{args['base_url']}/embeddings", json=params) + assert response.status_code == 200, f"error: create embeddings ({response.status_code})" diff --git a/app/tests/test_models.py b/app/tests/test_models.py index 752b98dd..ee5980ba 100644 --- a/app/tests/test_models.py +++ b/app/tests/test_models.py @@ -29,10 +29,22 @@ def test_get_models_non_existing_model(self, args, session_admin): response = session_admin.get(f"{args["base_url"]}/models/non-existing-model") assert response.status_code == 404, f"error: retrieve non-existing model ({response.status_code})" + def test_get_models_aliases(self, args, session_admin): + """Test the GET /models response status code for a non-existing model.""" + + model_id = list(settings.models.aliases.keys())[0] + aliases = settings.models.aliases[model_id] + + response = session_admin.get(f"{args["base_url"]}/models/{model_id}") + assert response.json()["aliases"] == aliases + + response = session_admin.get(f"{args["base_url"]}/models/{aliases[0]}") + assert response.json()["id"] == model_id + def test_get_models_rate_limit(self, args, session_user): """Test the GET /models rate limiting.""" start = time.time() - limit = int(settings.default_rate_limit.replace("/minute", "")) + limit = int(settings.rate_limit.by_key.replace("/minute", "")) i = 0 check = False while time.time() - start < 60: diff --git a/app/tests/test_search.py b/app/tests/test_search.py index 21ce0077..a150425c 100644 --- a/app/tests/test_search.py +++ b/app/tests/test_search.py @@ -112,7 +112,7 @@ 
def test_lexical_search(self, args, session_user, setup): response = session_user.post(f"{args["base_url"]}/search", json=data) result = response.json() - if settings.search.type == SEARCH_CLIENT_ELASTIC_TYPE: + if settings.clients.search.type == SEARCH_CLIENT_ELASTIC_TYPE: assert response.status_code == 200 assert "Albert" in result["data"][0]["chunk"]["content"] else: @@ -136,7 +136,7 @@ def test_hybrid_search(self, args, session_user, setup): data = {"prompt": "Erasmus", "collections": [COLLECTION_ID], "k": 3, "method": "hybrid"} response = session_user.post(f"{args["base_url"]}/search", json=data) result = response.json() - if settings.search.type == SEARCH_CLIENT_ELASTIC_TYPE: + if settings.clients.search.type == SEARCH_CLIENT_ELASTIC_TYPE: assert response.status_code == 200 assert "Erasmus" in result["data"][0]["chunk"]["content"] else: diff --git a/app/utils/lifespan.py b/app/utils/lifespan.py index 41bcc21c..2523f7ae 100644 --- a/app/utils/lifespan.py +++ b/app/utils/lifespan.py @@ -10,8 +10,8 @@ clients = ClientsManager(settings=settings) limiter = Limiter( key_func=get_ipaddr, - storage_uri=f"redis://{settings.cache.args.get("username", "")}:{settings.cache.args.get("password", "")}@{settings.cache.args["host"]}:{settings.cache.args["port"]}", - default_limits=[settings.global_rate_limit], + storage_uri=f"redis://{settings.clients.cache.args.get("username", "")}:{settings.clients.cache.args.get("password", "")}@{settings.clients.cache.args["host"]}:{settings.clients.cache.args["port"]}", + default_limits=[settings.rate_limit.by_ip], ) diff --git a/app/utils/security.py b/app/utils/security.py index 78fb7fa6..a3a94be2 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -10,7 +10,7 @@ from app.schemas.security import Role -if settings.auth: +if settings.clients.auth: async def check_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(scheme_name="API key"))]) -> User: """ diff --git a/docs/deployment.md 
b/docs/deployment.md index a675f167..e2f057e0 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -14,20 +14,22 @@ ### Variables d'environnements +Les variables d'environnement sont celles propres à FastAPI. + | Variable | Description | | --- | --- | | APP_CONTACT_URL | URL pour les informations de contact de l'application (par défaut : None) | | APP_CONTACT_EMAIL | Email de contact pour l'application (par défaut : None) | | APP_VERSION | Version de l'application (par défaut : "0.0.0") | | APP_DESCRIPTION | Description de l'application (par défaut : None) | -| GLOBAL_RATE_LIMIT | Limite de taux global pour les requêtes API par adresse IP (par défaut : "100/minute") | -| DEFAULT_RATE_LIMIT | Limite de taux par défaut pour les requêtes API par utilisateur (par défaut : "10/minute") | | CONFIG_FILE | Chemin vers le fichier de configuration (par défaut : "config.yml") | | LOG_LEVEL | Niveau de journalisation (par défaut : DEBUG) | -### Clients tiers +### Fichier de configuration (config.yml) + +Pour fonctionner, l'API Albert nécessite de configurer le fichier de configuration (config.yml). Celui-ci définit les clients tiers et des paramètres de configuration. -Pour fonctionner, l'API Albert nécessite des clients tiers : +Voici les clients tiers nécessaires : * Auth (optionnel) : [Grist](https://www.getgrist.com/)* * Cache : [Redis](https://redis.io/) @@ -47,38 +49,52 @@ Vous devez à minima à disposer d'un modèle language (text-generation) et d'un Ces clients sont déclarés dans un fichier de configuration qui doit respecter les spécifications suivantes (voir *[config.example.yml](./config.example.yml)* pour un exemple) : ```yaml -auth: [optional] - type: grist - args: [optional] - [arg_name]: [value] - ... +rate_limit: + by_ip: [optional] + by_key: [optional] internet: - type: duckduckgo|brave - args: - default_language_model: [required] - default_embeddings_model: [required] - [arg_name]: [value] - ... 
+ default_language_model: [required] # alias not allowed + default_embeddings_model: [required] # alias not allowed models: - - url: text-generation|text-embeddings-inference|automatic-speech-recognition|text-classification - key: [optional] - type: [required] # at least one of embedding model (text-embeddings-inference) + aliases: [optional] + - [model_name]: [[value, ...]] # duplicate alias not allowed or model_id not allowed ... -databases: - cache: [required] - type: redis - args: [required] +clients: + auth: [optional] + type: grist + args: [optional] [arg_name]: [value] ... - - search: [required] - type: elastic|qdrant - args: [required] + + internet: + type: duckduckgo|brave + args: + default_language_model: [required] + default_embeddings_model: [required] [arg_name]: [value] ... + + models: + - url: text-generation|text-embeddings-inference|automatic-speech-recognition|text-classification + key: [optional] + type: [required] + ... + + databases: + cache: [required] + type: redis + args: [required] + [arg_name]: [value] + ... + + search: [required] + type: elastic|qdrant + args: [required] + [arg_name]: [value] + ... ``` Pour avoir un détail des arguments de configuration, vous pouvez consulter le schéma Pydantic de la configuration [ici](../app/schemas/config.py).